diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 8c1629986..4d91c3fd0 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from codeflash.discovery.functions_to_optimize import FunctionToOptimize - from codeflash.models.models import CodeOptimizationContext, FunctionParent, FunctionSource + from codeflash.models.models import CodeOptimizationContext, FunctionSource @dataclass @@ -615,23 +615,29 @@ def _analyze_imports_in_optimized_code( def find_target_node( root: ast.AST, function_to_optimize: FunctionToOptimize ) -> Optional[ast.FunctionDef | ast.AsyncFunctionDef]: - def _find(node: ast.AST, parents: list[FunctionParent]) -> Optional[ast.FunctionDef | ast.AsyncFunctionDef]: - if not parents: - for child in getattr(node, "body", []): - if ( - isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) - and child.name == function_to_optimize.function_name - ): - return child + parents = function_to_optimize.parents + node = root + for parent in parents: + # Fast loop: directly look for the matching ClassDef in node.body + body = getattr(node, "body", None) + if not body: return None - - parent = parents[0] - for child in getattr(node, "body", []): + for child in body: if isinstance(child, ast.ClassDef) and child.name == parent.name: - return _find(child, parents[1:]) - return None + node = child + break + else: + return None - return _find(root, function_to_optimize.parents) + # Now node is either the root or the target parent class; look for function + body = getattr(node, "body", None) + if not body: + return None + target_name = function_to_optimize.function_name + for child in body: + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == target_name: + return child + return None def detect_unused_helper_functions(