diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 3057e923a..b1cc8c7be 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -32,23 +32,28 @@ def get_call_arguments(call_node: ast.Call) -> FunctionCallNodeArguments: def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool: - if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"): - for pos in call_positions: - if ( - pos.line_no is not None - and node.end_lineno is not None - and node.lineno <= pos.line_no <= node.end_lineno - ): - if pos.line_no == node.lineno and node.col_offset <= pos.col_no: - return True - if ( - pos.line_no == node.end_lineno - and node.end_col_offset is not None - and node.end_col_offset >= pos.col_no - ): - return True - if node.lineno < pos.line_no < node.end_lineno: - return True + # Profile: The most meaningful speedup here is to reduce attribute lookup and to localize call_positions if not empty. + # Small optimizations for tight loop: + if isinstance(node, ast.Call): + node_lineno = getattr(node, "lineno", None) + node_col_offset = getattr(node, "col_offset", None) + node_end_lineno = getattr(node, "end_lineno", None) + node_end_col_offset = getattr(node, "end_col_offset", None) + if node_lineno is not None and node_col_offset is not None and node_end_lineno is not None: + # Faster loop: reduce attribute lookups, use local variables for conditionals. 
+ for pos in call_positions: + pos_line = pos.line_no + if pos_line is not None and node_lineno <= pos_line <= node_end_lineno: + if pos_line == node_lineno and node_col_offset <= pos.col_no: + return True + if ( + pos_line == node_end_lineno + and node_end_col_offset is not None + and node_end_col_offset >= pos.col_no + ): + return True + if node_lineno < pos_line < node_end_lineno: + return True return False @@ -84,28 +89,157 @@ def __init__( def find_and_update_line_node( self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None ) -> Iterable[ast.stmt] | None: + # Major optimization: since ast.walk is *very* expensive for big trees and only checks for ast.Call, + # it's much more efficient to visit nodes manually. We'll only descend into expressions/statements. + + # Helper for manual walk + def iter_ast_calls(node): + # Generator to yield each ast.Call in test_node, preserves node identity + stack = [node] + while stack: + n = stack.pop() + if isinstance(n, ast.Call): + yield n + # Instead of using ast.walk (which calls iter_child_nodes under the hood in Python, which copy lists and stack-frames for EVERY node), + # do an iterative depth-first (pre-order) walk with an explicit LIFO stack; children are pushed in reverse so they pop in source order + for field, value in ast.iter_fields(n): + if isinstance(value, list): + for item in reversed(value): + if isinstance(item, ast.AST): + stack.append(item) + elif isinstance(value, ast.AST): + stack.append(value) + + # This change improves from O(N) stack-frames per child-node to a single stack, less python call overhead return_statement = [test_node] call_node = None - for node in ast.walk(test_node): - if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions): - call_node = node - all_args = get_call_arguments(call_node) - if isinstance(node.func, ast.Name): - function_name = node.func.id - - if self.function_object.is_async: + + # Minor optimization: Convert mode, function_name, test_class_name, qualified_name, etc to locals + fn_obj = 
self.function_object + module_path = self.module_path + mode = self.mode + qualified_name = fn_obj.qualified_name + + # Use locals for all 'current' values, only look up class/function/constant AST object once. + codeflash_loop_index = ast.Name(id="codeflash_loop_index", ctx=ast.Load()) + codeflash_cur = ast.Name(id="codeflash_cur", ctx=ast.Load()) + codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load()) + + for node in iter_ast_calls(test_node): + if not node_in_call_position(node, self.call_positions): + continue + + call_node = node + all_args = get_call_arguments(call_node) + # Two possible call types: Name and Attribute + node_func = node.func + + if isinstance(node_func, ast.Name): + function_name = node_func.id + + if fn_obj.is_async: + return [test_node] + + # Build once, reuse objects. + inspect_name = ast.Name(id="inspect", ctx=ast.Load()) + bind_call = ast.Assign( + targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Call( + func=ast.Attribute(value=inspect_name, attr="signature", ctx=ast.Load()), + args=[ast.Name(id=function_name, ctx=ast.Load())], + keywords=[], + ), + attr="bind", + ctx=ast.Load(), + ), + args=all_args.args, + keywords=all_args.keywords, + ), + lineno=test_node.lineno, + col_offset=test_node.col_offset, + ) + + apply_defaults = ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="apply_defaults", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + lineno=test_node.lineno + 1, + col_offset=test_node.col_offset, + ) + + node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) + base_args = [ + ast.Name(id=function_name, ctx=ast.Load()), + ast.Constant(value=module_path), + ast.Constant(value=test_class_name or None), + ast.Constant(value=node_name), + ast.Constant(value=qualified_name), + ast.Constant(value=index), + codeflash_loop_index, + ] + # Extend with BEHAVIOR extras if needed + if 
mode == TestingMode.BEHAVIOR: + base_args += [codeflash_cur, codeflash_con] + # Extend with call args (performance) or starred bound args (behavior) + if mode == TestingMode.PERFORMANCE: + base_args += call_node.args + else: + base_args.append( + ast.Starred( + value=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="args", + ctx=ast.Load(), + ), + ctx=ast.Load(), + ) + ) + node.args = base_args + # Prepare keywords + if mode == TestingMode.BEHAVIOR: + node.keywords = [ + ast.keyword( + value=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="kwargs", + ctx=ast.Load(), + ) + ) + ] + else: + node.keywords = call_node.keywords + + return_statement = ( + [bind_call, apply_defaults, test_node] if mode == TestingMode.BEHAVIOR else [test_node] + ) + break + if isinstance(node_func, ast.Attribute): + function_to_test = node_func.attr + if function_to_test == fn_obj.function_name: + if fn_obj.is_async: return [test_node] # Create the signature binding statements + + # Unparse only once + function_name_expr = ast.parse(ast.unparse(node_func), mode="eval").body + + inspect_name = ast.Name(id="inspect", ctx=ast.Load()) bind_call = ast.Assign( targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())], value=ast.Call( func=ast.Attribute( value=ast.Call( - func=ast.Attribute( - value=ast.Name(id="inspect", ctx=ast.Load()), attr="signature", ctx=ast.Load() - ), - args=[ast.Name(id=function_name, ctx=ast.Load())], + func=ast.Attribute(value=inspect_name, attr="signature", ctx=ast.Load()), + args=[function_name_expr], keywords=[], ), attr="bind", @@ -133,36 +267,33 @@ def find_and_update_line_node( ) node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) - node.args = [ - ast.Name(id=function_name, ctx=ast.Load()), - ast.Constant(value=self.module_path), + base_args = [ + function_name_expr, + ast.Constant(value=module_path), ast.Constant(value=test_class_name or None), ast.Constant(value=node_name), 
- ast.Constant(value=self.function_object.qualified_name), + ast.Constant(value=qualified_name), ast.Constant(value=index), - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), - *( - [ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())] - if self.mode == TestingMode.BEHAVIOR - else [] - ), - *( - call_node.args - if self.mode == TestingMode.PERFORMANCE - else [ - ast.Starred( - value=ast.Attribute( - value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), - attr="args", - ctx=ast.Load(), - ), - ctx=ast.Load(), - ) - ] - ), + codeflash_loop_index, ] - node.keywords = ( - [ + if mode == TestingMode.BEHAVIOR: + base_args += [codeflash_cur, codeflash_con] + if mode == TestingMode.PERFORMANCE: + base_args += call_node.args + else: + base_args.append( + ast.Starred( + value=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="args", + ctx=ast.Load(), + ), + ctx=ast.Load(), + ) + ) + node.args = base_args + if mode == TestingMode.BEHAVIOR: + node.keywords = [ ast.keyword( value=ast.Attribute( value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), @@ -171,112 +302,14 @@ def find_and_update_line_node( ) ) ] - if self.mode == TestingMode.BEHAVIOR - else call_node.keywords - ) + else: + node.keywords = call_node.keywords # Return the signature binding statements along with the test_node return_statement = ( - [bind_call, apply_defaults, test_node] if self.mode == TestingMode.BEHAVIOR else [test_node] + [bind_call, apply_defaults, test_node] if mode == TestingMode.BEHAVIOR else [test_node] ) break - if isinstance(node.func, ast.Attribute): - function_to_test = node.func.attr - if function_to_test == self.function_object.function_name: - if self.function_object.is_async: - return [test_node] - - function_name = ast.unparse(node.func) - - # Create the signature binding statements - bind_call = ast.Assign( - targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())], - value=ast.Call( 
- func=ast.Attribute( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id="inspect", ctx=ast.Load()), - attr="signature", - ctx=ast.Load(), - ), - args=[ast.parse(function_name, mode="eval").body], - keywords=[], - ), - attr="bind", - ctx=ast.Load(), - ), - args=all_args.args, - keywords=all_args.keywords, - ), - lineno=test_node.lineno, - col_offset=test_node.col_offset, - ) - - apply_defaults = ast.Expr( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), - attr="apply_defaults", - ctx=ast.Load(), - ), - args=[], - keywords=[], - ), - lineno=test_node.lineno + 1, - col_offset=test_node.col_offset, - ) - - node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) - node.args = [ - ast.parse(function_name, mode="eval").body, - ast.Constant(value=self.module_path), - ast.Constant(value=test_class_name or None), - ast.Constant(value=node_name), - ast.Constant(value=self.function_object.qualified_name), - ast.Constant(value=index), - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), - *( - [ - ast.Name(id="codeflash_cur", ctx=ast.Load()), - ast.Name(id="codeflash_con", ctx=ast.Load()), - ] - if self.mode == TestingMode.BEHAVIOR - else [] - ), - *( - call_node.args - if self.mode == TestingMode.PERFORMANCE - else [ - ast.Starred( - value=ast.Attribute( - value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), - attr="args", - ctx=ast.Load(), - ), - ctx=ast.Load(), - ) - ] - ), - ] - node.keywords = ( - [ - ast.keyword( - value=ast.Attribute( - value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), - attr="kwargs", - ctx=ast.Load(), - ) - ) - ] - if self.mode == TestingMode.BEHAVIOR - else call_node.keywords - ) - - # Return the signature binding statements along with the test_node - return_statement = ( - [bind_call, apply_defaults, test_node] if self.mode == TestingMode.BEHAVIOR else [test_node] - ) - break if call_node is None: return None