diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index 0de8ade7..33f263fa 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -68,8 +68,9 @@ def _process_function_def( j = len(line_node.body) - 1 while j >= 0: compound_line_node: ast.stmt = line_node.body[j] - internal_node: ast.AST - for internal_node in ast.walk(compound_line_node): + nodes_to_check = [compound_line_node] + nodes_to_check.extend(getattr(compound_line_node, "body", [])) + for internal_node in nodes_to_check: if isinstance(internal_node, (ast.stmt, ast.Assign)): inv_id = str(i) + "_" + str(j) match_key = key + "#" + inv_id