|
1 | 1 | import type { ContractDefinition, FunctionDefinition } from 'solidity-ast';
|
2 | 2 | import { ASTDereferencer, findAll } from 'solidity-ast/utils';
|
3 | 3 | import { SrcDecoder } from '../../src-decoder';
|
4 |
| -import { ValidationExceptionInitializer, skipCheck } from '../run'; |
| 4 | +import { ValidationExceptionInitializer, skipCheck, tryDerefFunction } from '../run'; |
5 | 5 |
|
6 | 6 | /**
|
7 | 7 | * Reports if this contract is non-abstract and any of the following are true:
|
@@ -141,6 +141,27 @@ function getParentsNotInitializedByOtherParents(
|
141 | 141 | return remainingParents;
|
142 | 142 | }
|
143 | 143 |
|
| 144 | +/** |
| 145 | + * Calls the callback if the referenced function definition is found in the AST. |
| 146 | + * Otherwise, does nothing. |
| 147 | + * |
| 148 | + * @param deref AST dereferencer |
| 149 | + * @param referencedDeclaration ID of the referenced function |
| 150 | + * @param callback Function to call if the referenced function definition is found |
| 151 | + */ |
| 152 | +function doIfReferencedFunctionFound( |
| 153 | + deref: ASTDereferencer, |
| 154 | + referencedDeclaration: number | null | undefined, |
| 155 | + callback: (functionDef: FunctionDefinition) => void, |
| 156 | +) { |
| 157 | + if (referencedDeclaration && referencedDeclaration > 0) { |
| 158 | + const functionDef = tryDerefFunction(deref, referencedDeclaration); |
| 159 | + if (functionDef !== undefined) { |
| 160 | + callback(functionDef); |
| 161 | + } |
| 162 | + } |
| 163 | +} |
| 164 | + |
144 | 165 | /**
|
145 | 166 | * Reports exceptions for missing initializer calls, duplicate initializer calls, and incorrect initializer order.
|
146 | 167 | *
|
@@ -176,10 +197,9 @@ function* getInitializerCallExceptions(
|
176 | 197 | (fnCall.expression.nodeType === 'Identifier' || fnCall.expression.nodeType === 'MemberAccess')
|
177 | 198 | ) {
|
178 | 199 | let recursiveFunctionIds: number[] = [];
|
179 |
| - const referencedFn = fnCall.expression.referencedDeclaration; |
180 |
| - if (referencedFn && referencedFn > 0) { |
181 |
| - recursiveFunctionIds = getRecursiveFunctionIds(referencedFn, deref); |
182 |
| - } |
| 200 | + doIfReferencedFunctionFound(deref, fnCall.expression.referencedDeclaration, (functionDef: FunctionDefinition) => { |
| 201 | + recursiveFunctionIds = getRecursiveFunctionIds(deref, functionDef); |
| 202 | + }); |
183 | 203 |
|
184 | 204 | // For each recursively called function, if it is a parent initializer, then:
|
185 | 205 | // - Check if it was already called (duplicate call)
|
@@ -258,38 +278,41 @@ function* getInitializerCallExceptions(
|
258 | 278 | /**
|
259 | 279 | * Gets the IDs of all functions that are recursively called by the given function, including the given function itself at the end of the list.
|
260 | 280 | *
|
261 |
| - * @param referencedFn The ID of the function to start from |
262 | 281 | * @param deref AST dereferencer
|
| 282 | + * @param functionDef The node of the function definition to start from |
263 | 283 | * @param visited Set of function IDs that have already been visited
|
264 | 284 | * @returns The IDs of all functions that are recursively called by the given function, including the given function itself at the end of the list.
|
265 | 285 | */
|
266 |
| -function getRecursiveFunctionIds(referencedFn: number, deref: ASTDereferencer, visited?: Set<number>): number[] { |
| 286 | +function getRecursiveFunctionIds( |
| 287 | + deref: ASTDereferencer, |
| 288 | + functionDef: FunctionDefinition, |
| 289 | + visited?: Set<number>, |
| 290 | +): number[] { |
267 | 291 | const result: number[] = [];
|
268 | 292 |
|
269 | 293 | if (visited === undefined) {
|
270 | 294 | visited = new Set();
|
271 | 295 | }
|
272 |
| - if (visited.has(referencedFn)) { |
| 296 | + if (visited.has(functionDef.id)) { |
273 | 297 | return result;
|
274 | 298 | } else {
|
275 |
| - visited.add(referencedFn); |
| 299 | + visited.add(functionDef.id); |
276 | 300 | }
|
277 | 301 |
|
278 |
| - const fn = deref('FunctionDefinition', referencedFn); |
279 |
| - const expressionStatements = fn.body?.statements?.filter(stmt => stmt.nodeType === 'ExpressionStatement') ?? []; |
| 302 | + const expressionStatements = |
| 303 | + functionDef.body?.statements?.filter(stmt => stmt.nodeType === 'ExpressionStatement') ?? []; |
280 | 304 | for (const stmt of expressionStatements) {
|
281 | 305 | const fnCall = stmt.expression;
|
282 | 306 | if (
|
283 | 307 | fnCall.nodeType === 'FunctionCall' &&
|
284 | 308 | (fnCall.expression.nodeType === 'Identifier' || fnCall.expression.nodeType === 'MemberAccess')
|
285 | 309 | ) {
|
286 |
| - const referencedId = fnCall.expression.referencedDeclaration; |
287 |
| - if (referencedId && referencedId > 0) { |
288 |
| - result.push(...getRecursiveFunctionIds(referencedId, deref, visited)); |
289 |
| - } |
| 310 | + doIfReferencedFunctionFound(deref, fnCall.expression.referencedDeclaration, (functionDef: FunctionDefinition) => { |
| 311 | + result.push(...getRecursiveFunctionIds(deref, functionDef, visited)); |
| 312 | + }); |
290 | 313 | }
|
291 | 314 | }
|
292 |
| - result.push(referencedFn); |
| 315 | + result.push(functionDef.id); |
293 | 316 |
|
294 | 317 | return result;
|
295 | 318 | }
|
|
0 commit comments