diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py
index db08f8afc..f7967dd1a 100644
--- a/codeflash/code_utils/instrument_existing_tests.py
+++ b/codeflash/code_utils/instrument_existing_tests.py
@@ -19,23 +19,30 @@
 
 
 def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool:
-    if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"):
-        for pos in call_positions:
-            if (
-                pos.line_no is not None
-                and node.end_lineno is not None
-                and node.lineno <= pos.line_no <= node.end_lineno
-            ):
-                if pos.line_no == node.lineno and node.col_offset <= pos.col_no:
-                    return True
-                if (
-                    pos.line_no == node.end_lineno
-                    and node.end_col_offset is not None
-                    and node.end_col_offset >= pos.col_no
-                ):
-                    return True
-                if node.lineno < pos.line_no < node.end_lineno:
-                    return True
+    # Only Call nodes can match; each location attribute is read once, outside the loop.
+    if not isinstance(node, ast.Call):
+        return False
+
+    lineno = getattr(node, "lineno", None)
+    col_offset = getattr(node, "col_offset", None)
+    end_lineno = getattr(node, "end_lineno", None)
+    end_col_offset = getattr(node, "end_col_offset", None)
+
+    if lineno is None or col_offset is None or end_lineno is None:
+        return False
+
+    # A position matches when it is on the node's first line at or after col_offset, on its
+    # last line at or before end_col_offset, or on any line strictly between the two.
+    for pos in call_positions:
+        line_no, col_no = pos.line_no, pos.col_no
+        # Positions without a line number, or outside the node's line range, cannot match.
+        if line_no is not None and lineno <= line_no <= end_lineno:
+            if line_no == lineno and col_offset <= col_no:
+                return True
+            if line_no == end_lineno and end_col_offset is not None and end_col_offset >= col_no:
+                return True
+            if lineno < line_no < end_lineno:
+                return True
     return False
 
 
@@ -71,60 +78,76 @@ def __init__(
     def find_and_update_line_node(
         self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
     ) -> Iterable[ast.stmt] | None:
+        # Bind the attributes that the loop below reads to local names once.
+        only_func_name = self.function_object.function_name
+        is_async = self.function_object.is_async
+        qualified_name = self.function_object.qualified_name
+        mode = self.mode
+        module_path = self.module_path
+        call_positions = self.call_positions
+
+        # Walk the statement once; the first call that gets rewritten ends the walk.
         call_node = None
+
         for node in ast.walk(test_node):
-            if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
+            # Only calls located at one of the recorded call positions are candidates.
+            if isinstance(node, ast.Call) and node_in_call_position(node, call_positions):
                 call_node = node
                 if isinstance(node.func, ast.Name):
                     function_name = node.func.id
 
-                    if self.function_object.is_async:
+                    if is_async:
                         return [test_node]
 
-                    node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
-                    node.args = [
+                    # Leading codeflash_wrap arguments; the original call's args are appended last.
+                    args = [
                         ast.Name(id=function_name, ctx=ast.Load()),
-                        ast.Constant(value=self.module_path),
+                        ast.Constant(value=module_path),
                         ast.Constant(value=test_class_name or None),
                         ast.Constant(value=node_name),
-                        ast.Constant(value=self.function_object.qualified_name),
+                        ast.Constant(value=qualified_name),
                         ast.Constant(value=index),
                         ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
-                        *(
-                            [ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())]
-                            if self.mode == TestingMode.BEHAVIOR
-                            else []
-                        ),
-                        *call_node.args,
                     ]
+
+                    if mode == TestingMode.BEHAVIOR:
+                        args.extend(
+                            [ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())]
+                        )
+                    args.extend(call_node.args)
+
+                    node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
+                    node.args = args
                     node.keywords = call_node.keywords
                     break
+
                 if isinstance(node.func, ast.Attribute):
-                    function_to_test = node.func.attr
-                    if function_to_test == self.function_object.function_name:
-                        if self.function_object.is_async:
+                    if node.func.attr == only_func_name:
+                        if is_async:
                             return [test_node]
 
+                        # Unparse the full attribute expression only once the attribute name matched.
                         function_name = ast.unparse(node.func)
-                        node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
-                        node.args = [
+                        args = [
                             ast.Name(id=function_name, ctx=ast.Load()),
-                            ast.Constant(value=self.module_path),
+                            ast.Constant(value=module_path),
                             ast.Constant(value=test_class_name or None),
                             ast.Constant(value=node_name),
-                            ast.Constant(value=self.function_object.qualified_name),
+                            ast.Constant(value=qualified_name),
                             ast.Constant(value=index),
                             ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
-                            *(
+                        ]
+                        if mode == TestingMode.BEHAVIOR:
+                            args.extend(
                                 [
                                     ast.Name(id="codeflash_cur", ctx=ast.Load()),
                                     ast.Name(id="codeflash_con", ctx=ast.Load()),
                                 ]
-                                if self.mode == TestingMode.BEHAVIOR
-                                else []
-                            ),
-                            *call_node.args,
-                        ]
+                            )
+                        args.extend(call_node.args)
+
+                        node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
+                        node.args = args
                         node.keywords = call_node.keywords
                         break
 