|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import ast |
| 4 | +from collections.abc import Iterable |
4 | 5 | from pathlib import Path |
5 | 6 | from typing import TYPE_CHECKING |
6 | 7 |
|
|
9 | 10 | from codeflash.cli_cmds.console import logger |
10 | 11 | from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path |
11 | 12 | from codeflash.discovery.functions_to_optimize import FunctionToOptimize |
12 | | -from codeflash.models.models import FunctionParent, TestingMode, VerificationType |
| 13 | +from codeflash.models.models import CodePosition, FunctionParent, TestingMode, VerificationType |
13 | 14 |
|
14 | 15 | if TYPE_CHECKING: |
15 | 16 | from collections.abc import Iterable |
@@ -64,62 +65,99 @@ def __init__( |
64 | 65 | self.module_path = module_path |
65 | 66 | self.test_framework = test_framework |
66 | 67 | self.call_positions = call_positions |
| 68 | + # Pre-cache node wrappers often instantiated |
| 69 | + self.ast_codeflash_loop_index = ast.Name(id="codeflash_loop_index", ctx=ast.Load()) |
| 70 | + self.ast_codeflash_cur = ast.Name(id="codeflash_cur", ctx=ast.Load()) |
| 71 | + self.ast_codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load()) |
67 | 72 | if len(function.parents) == 1 and function.parents[0].type == "ClassDef": |
68 | 73 | self.class_name = function.top_level_parent_name |
69 | 74 |
|
70 | 75 | def find_and_update_line_node( |
71 | 76 | self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None |
72 | 77 | ) -> Iterable[ast.stmt] | None: |
| 78 | + # Optimize: Inline self._in_call_position and cache .func once |
73 | 79 | call_node = None |
| 80 | + behavior_mode = self.mode == TestingMode.BEHAVIOR |
| 81 | + function_object_name = self.function_object.function_name |
| 82 | + function_qualified_name = self.function_object.qualified_name |
| 83 | + module_path_const = ast.Constant(value=self.module_path) |
| 84 | + test_class_const = ast.Constant(value=test_class_name or None) |
| 85 | + node_name_const = ast.Constant(value=node_name) |
| 86 | + qualified_name_const = ast.Constant(value=function_qualified_name) |
| 87 | + index_const = ast.Constant(value=index) |
| 88 | + args_behavior = [self.ast_codeflash_cur, self.ast_codeflash_con] if behavior_mode else [] |
| 89 | + |
74 | 90 | for node in ast.walk(test_node): |
75 | | - if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions): |
76 | | - call_node = node |
77 | | - if isinstance(node.func, ast.Name): |
78 | | - function_name = node.func.id |
| 91 | + # Fast path: check for Call nodes only |
| 92 | + if not (isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset")): |
| 93 | + continue |
| 94 | + # Inline node_in_call_position logic (from profiler hotspot) |
| 95 | + node_lineno = getattr(node, "lineno", None) |
| 96 | + node_col_offset = getattr(node, "col_offset", None) |
| 97 | + node_end_lineno = getattr(node, "end_lineno", None) |
| 98 | + node_end_col_offset = getattr(node, "end_col_offset", None) |
| 99 | + found = False |
| 100 | + for pos in self.call_positions: |
| 101 | + pos_line = pos.line_no |
| 102 | + if pos_line is not None and node_end_lineno is not None and node_lineno <= pos_line <= node_end_lineno: |
| 103 | + if pos_line == node_lineno and node_col_offset <= pos.col_no: |
| 104 | + found = True |
| 105 | + break |
| 106 | + if ( |
| 107 | + pos_line == node_end_lineno |
| 108 | + and node_end_col_offset is not None |
| 109 | + and node_end_col_offset >= pos.col_no |
| 110 | + ): |
| 111 | + found = True |
| 112 | + break |
| 113 | + if node_lineno < pos_line < node_end_lineno: |
| 114 | + found = True |
| 115 | + break |
| 116 | + if not found: |
| 117 | + continue |
| 118 | + |
| 119 | + call_node = node |
| 120 | + func = node.func |
| 121 | + # Handle ast.Name fast path |
| 122 | + if isinstance(func, ast.Name): |
| 123 | + function_name = func.id |
| 124 | + node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) |
| 125 | + # Build ast.Name fields for use in args |
| 126 | + codeflash_func_arg = ast.Name(id=function_name, ctx=ast.Load()) |
| 127 | + # Compose argument tuple directly, for speed |
| 128 | + node.args = [ |
| 129 | + codeflash_func_arg, |
| 130 | + module_path_const, |
| 131 | + test_class_const, |
| 132 | + node_name_const, |
| 133 | + qualified_name_const, |
| 134 | + index_const, |
| 135 | + self.ast_codeflash_loop_index, |
| 136 | + *args_behavior, |
| 137 | + *call_node.args, |
| 138 | + ] |
| 139 | + node.keywords = call_node.keywords |
| 140 | + break |
| 141 | + if isinstance(func, ast.Attribute): |
| 142 | + # This path is almost never hit (profile), but handle it |
| 143 | + function_to_test = func.attr |
| 144 | + if function_to_test == function_object_name: |
| 145 | + # NOTE: ast.unparse is very slow; only call if necessary |
| 146 | + function_name = ast.unparse(func) |
79 | 147 | node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) |
80 | 148 | node.args = [ |
81 | 149 | ast.Name(id=function_name, ctx=ast.Load()), |
82 | | - ast.Constant(value=self.module_path), |
83 | | - ast.Constant(value=test_class_name or None), |
84 | | - ast.Constant(value=node_name), |
85 | | - ast.Constant(value=self.function_object.qualified_name), |
86 | | - ast.Constant(value=index), |
87 | | - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), |
88 | | - *( |
89 | | - [ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())] |
90 | | - if self.mode == TestingMode.BEHAVIOR |
91 | | - else [] |
92 | | - ), |
| 150 | + module_path_const, |
| 151 | + test_class_const, |
| 152 | + node_name_const, |
| 153 | + qualified_name_const, |
| 154 | + index_const, |
| 155 | + self.ast_codeflash_loop_index, |
| 156 | + *args_behavior, |
93 | 157 | *call_node.args, |
94 | 158 | ] |
95 | 159 | node.keywords = call_node.keywords |
96 | 160 | break |
97 | | - if isinstance(node.func, ast.Attribute): |
98 | | - function_to_test = node.func.attr |
99 | | - if function_to_test == self.function_object.function_name: |
100 | | - function_name = ast.unparse(node.func) |
101 | | - node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) |
102 | | - node.args = [ |
103 | | - ast.Name(id=function_name, ctx=ast.Load()), |
104 | | - ast.Constant(value=self.module_path), |
105 | | - ast.Constant(value=test_class_name or None), |
106 | | - ast.Constant(value=node_name), |
107 | | - ast.Constant(value=self.function_object.qualified_name), |
108 | | - ast.Constant(value=index), |
109 | | - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), |
110 | | - *( |
111 | | - [ |
112 | | - ast.Name(id="codeflash_cur", ctx=ast.Load()), |
113 | | - ast.Name(id="codeflash_con", ctx=ast.Load()), |
114 | | - ] |
115 | | - if self.mode == TestingMode.BEHAVIOR |
116 | | - else [] |
117 | | - ), |
118 | | - *call_node.args, |
119 | | - ] |
120 | | - node.keywords = call_node.keywords |
121 | | - break |
122 | | - |
123 | 161 | if call_node is None: |
124 | 162 | return None |
125 | 163 | return [test_node] |
@@ -153,6 +191,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = |
153 | 191 | while j >= 0: |
154 | 192 | compound_line_node: ast.stmt = line_node.body[j] |
155 | 193 | internal_node: ast.AST |
| 194 | + # No significant hotspot here; ast.walk used on small subtrees |
156 | 195 | for internal_node in ast.walk(compound_line_node): |
157 | 196 | if isinstance(internal_node, (ast.stmt, ast.Assign)): |
158 | 197 | updated_node = self.find_and_update_line_node( |
@@ -284,6 +323,29 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = |
284 | 323 | ] |
285 | 324 | return node |
286 | 325 |
|
| 326 | + def _in_call_position(self, node: ast.AST) -> bool: |
| 327 | + # Inline node_in_call_position for performance |
| 328 | + if not (isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset")): |
| 329 | + return False |
| 330 | + node_lineno = getattr(node, "lineno", None) |
| 331 | + node_col_offset = getattr(node, "col_offset", None) |
| 332 | + node_end_lineno = getattr(node, "end_lineno", None) |
| 333 | + node_end_col_offset = getattr(node, "end_col_offset", None) |
| 334 | + for pos in self.call_positions: |
| 335 | + pos_line = pos.line_no |
| 336 | + if pos_line is not None and node_end_lineno is not None and node_lineno <= pos_line <= node_end_lineno: |
| 337 | + if pos_line == node_lineno and node_col_offset <= pos.col_no: |
| 338 | + return True |
| 339 | + if ( |
| 340 | + pos_line == node_end_lineno |
| 341 | + and node_end_col_offset is not None |
| 342 | + and node_end_col_offset >= pos.col_no |
| 343 | + ): |
| 344 | + return True |
| 345 | + if node_lineno < pos_line < node_end_lineno: |
| 346 | + return True |
| 347 | + return False |
| 348 | + |
287 | 349 |
|
288 | 350 | class FunctionImportedAsVisitor(ast.NodeVisitor): |
289 | 351 | """Checks if a function has been imported as an alias. We only care about the alias then. |
|
0 commit comments