diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index ae3d82b57..b1cc8c7be 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -2,6 +2,7 @@ import ast import platform +from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING @@ -20,24 +21,39 @@ from codeflash.models.models import CodePosition +@dataclass(frozen=True) +class FunctionCallNodeArguments: + args: list[ast.expr] + keywords: list[ast.keyword] + + +def get_call_arguments(call_node: ast.Call) -> FunctionCallNodeArguments: + return FunctionCallNodeArguments(call_node.args, call_node.keywords) + + def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool: - if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"): - for pos in call_positions: - if ( - pos.line_no is not None - and node.end_lineno is not None - and node.lineno <= pos.line_no <= node.end_lineno - ): - if pos.line_no == node.lineno and node.col_offset <= pos.col_no: - return True - if ( - pos.line_no == node.end_lineno - and node.end_col_offset is not None - and node.end_col_offset >= pos.col_no - ): - return True - if node.lineno < pos.line_no < node.end_lineno: - return True + # Profile: The most meaningful speedup here is to reduce attribute lookup and to localize call_positions if not empty. + # Small optimizations for tight loop: + if isinstance(node, ast.Call): + node_lineno = getattr(node, "lineno", None) + node_col_offset = getattr(node, "col_offset", None) + node_end_lineno = getattr(node, "end_lineno", None) + node_end_col_offset = getattr(node, "end_col_offset", None) + if node_lineno is not None and node_col_offset is not None and node_end_lineno is not None: + # Faster loop: reduce attribute lookups, use local variables for conditionals. + for pos in call_positions: + pos_line = pos.line_no + if pos_line is not None and node_lineno <= pos_line <= node_end_lineno: + if pos_line == node_lineno and node_col_offset <= pos.col_no: + return True + if ( + pos_line == node_end_lineno + and node_end_col_offset is not None + and node_end_col_offset >= pos.col_no + ): + return True + if node_lineno < pos_line < node_end_lineno: + return True return False @@ -73,66 +89,231 @@ def __init__( def find_and_update_line_node( self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None ) -> Iterable[ast.stmt] | None: + # Major optimization: since ast.walk is *very* expensive for big trees and only checks for ast.Call, + # it's much more efficient to visit nodes manually. We'll only descend into expressions/statements. + + # Helper for manual walk + def iter_ast_calls(node): + # Generator to yield each ast.Call in test_node, preserves node identity + stack = [node] + while stack: + n = stack.pop() + if isinstance(n, ast.Call): + yield n + # Instead of using ast.walk (which calls iter_child_nodes under the hood in Python, which copy lists and stack-frames for EVERY node), + # do a specialized BFS with only the necessary attributes + for field, value in ast.iter_fields(n): + if isinstance(value, list): + for item in reversed(value): + if isinstance(item, ast.AST): + stack.append(item) + elif isinstance(value, ast.AST): + stack.append(value) + + # This change improves from O(N) stack-frames per child-node to a single stack, less python call overhead + return_statement = [test_node] call_node = None - for node in ast.walk(test_node): - if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions): - call_node = node - if isinstance(node.func, ast.Name): - function_name = node.func.id - if self.function_object.is_async: + # Minor optimization: Convert mode, function_name, test_class_name, qualified_name, etc to locals + fn_obj = self.function_object + module_path = self.module_path + mode = self.mode + qualified_name = fn_obj.qualified_name + + # Use locals for all 'current' values, only look up class/function/constant AST object once. + codeflash_loop_index = ast.Name(id="codeflash_loop_index", ctx=ast.Load()) + codeflash_cur = ast.Name(id="codeflash_cur", ctx=ast.Load()) + codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load()) + + for node in iter_ast_calls(test_node): + if not node_in_call_position(node, self.call_positions): + continue + + call_node = node + all_args = get_call_arguments(call_node) + # Two possible call types: Name and Attribute + node_func = node.func + + if isinstance(node_func, ast.Name): + function_name = node_func.id + + if fn_obj.is_async: + return [test_node] + + # Build once, reuse objects. + inspect_name = ast.Name(id="inspect", ctx=ast.Load()) + bind_call = ast.Assign( + targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Call( + func=ast.Attribute(value=inspect_name, attr="signature", ctx=ast.Load()), + args=[ast.Name(id=function_name, ctx=ast.Load())], + keywords=[], + ), + attr="bind", + ctx=ast.Load(), + ), + args=all_args.args, + keywords=all_args.keywords, + ), + lineno=test_node.lineno, + col_offset=test_node.col_offset, + ) + + apply_defaults = ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="apply_defaults", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + lineno=test_node.lineno + 1, + col_offset=test_node.col_offset, + ) + + node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) + base_args = [ + ast.Name(id=function_name, ctx=ast.Load()), + ast.Constant(value=module_path), + ast.Constant(value=test_class_name or None), + ast.Constant(value=node_name), + ast.Constant(value=qualified_name), + ast.Constant(value=index), + codeflash_loop_index, + ] + # Extend with BEHAVIOR extras if needed + if mode == TestingMode.BEHAVIOR: + base_args += [codeflash_cur, codeflash_con] + # Extend with call args (performance) or starred bound args (behavior) + if mode == TestingMode.PERFORMANCE: + base_args += call_node.args + else: + base_args.append( + ast.Starred( + value=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="args", + ctx=ast.Load(), + ), + ctx=ast.Load(), + ) + ) + node.args = base_args + # Prepare keywords + if mode == TestingMode.BEHAVIOR: + node.keywords = [ + ast.keyword( + value=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="kwargs", + ctx=ast.Load(), + ) + ) + ] + else: + node.keywords = call_node.keywords + + return_statement = ( + [bind_call, apply_defaults, test_node] if mode == TestingMode.BEHAVIOR else [test_node] + ) + break + if isinstance(node_func, ast.Attribute): + function_to_test = node_func.attr + if function_to_test == fn_obj.function_name: + if fn_obj.is_async: return [test_node] + # Create the signature binding statements + + # Unparse only once + function_name_expr = ast.parse(ast.unparse(node_func), mode="eval").body + + inspect_name = ast.Name(id="inspect", ctx=ast.Load()) + bind_call = ast.Assign( + targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Call( + func=ast.Attribute(value=inspect_name, attr="signature", ctx=ast.Load()), + args=[function_name_expr], + keywords=[], + ), + attr="bind", + ctx=ast.Load(), + ), + args=all_args.args, + keywords=all_args.keywords, + ), + lineno=test_node.lineno, + col_offset=test_node.col_offset, + ) + + apply_defaults = ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="apply_defaults", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + lineno=test_node.lineno + 1, + col_offset=test_node.col_offset, + ) + node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) - node.args = [ - ast.Name(id=function_name, ctx=ast.Load()), - ast.Constant(value=self.module_path), + base_args = [ + function_name_expr, + ast.Constant(value=module_path), ast.Constant(value=test_class_name or None), ast.Constant(value=node_name), - ast.Constant(value=self.function_object.qualified_name), + ast.Constant(value=qualified_name), ast.Constant(value=index), - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), - *( - [ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())] - if self.mode == TestingMode.BEHAVIOR - else [] - ), - *call_node.args, + codeflash_loop_index, ] - node.keywords = call_node.keywords - break - if isinstance(node.func, ast.Attribute): - function_to_test = node.func.attr - if function_to_test == self.function_object.function_name: - if self.function_object.is_async: - return [test_node] - - function_name = ast.unparse(node.func) - node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load()) - node.args = [ - ast.Name(id=function_name, ctx=ast.Load()), - ast.Constant(value=self.module_path), - ast.Constant(value=test_class_name or None), - ast.Constant(value=node_name), - ast.Constant(value=self.function_object.qualified_name), - ast.Constant(value=index), - ast.Name(id="codeflash_loop_index", ctx=ast.Load()), - *( - [ - ast.Name(id="codeflash_cur", ctx=ast.Load()), - ast.Name(id="codeflash_con", ctx=ast.Load()), - ] - if self.mode == TestingMode.BEHAVIOR - else [] - ), - *call_node.args, + if mode == TestingMode.BEHAVIOR: + base_args += [codeflash_cur, codeflash_con] + if mode == TestingMode.PERFORMANCE: + base_args += call_node.args + else: + base_args.append( + ast.Starred( + value=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="args", + ctx=ast.Load(), + ), + ctx=ast.Load(), + ) + ) + node.args = base_args + if mode == TestingMode.BEHAVIOR: + node.keywords = [ + ast.keyword( + value=ast.Attribute( + value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()), + attr="kwargs", + ctx=ast.Load(), + ) + ) ] + else: node.keywords = call_node.keywords - break + + # Return the signature binding statements along with the test_node + return_statement = ( + [bind_call, apply_defaults, test_node] if mode == TestingMode.BEHAVIOR else [test_node] + ) + break if call_node is None: return None - return [test_node] + return return_statement def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: # TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes. @@ -593,7 +774,11 @@ def inject_profiling_into_existing_test( ] if mode == TestingMode.BEHAVIOR: new_imports.extend( - [ast.Import(names=[ast.alias(name="sqlite3")]), ast.Import(names=[ast.alias(name="dill", asname="pickle")])] + [ + ast.Import(names=[ast.alias(name="inspect")]), + ast.Import(names=[ast.alias(name="sqlite3")]), + ast.Import(names=[ast.alias(name="dill", asname="pickle")]), + ] ) if test_framework == "unittest" and platform.system() != "Windows": new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")])) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 398efe461..896ada442 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -212,15 +212,25 @@ def __init__(self, function_names_to_find: set[str]) -> None: self.wildcard_modules: set[str] = set() # Track aliases: alias_name -> original_name self.alias_mapping: dict[str, str] = {} + # Track instances: variable_name -> class_name + self.instance_mapping: dict[str, str] = {} # Precompute function_names for prefix search # For prefix match, store mapping from prefix-root to candidates for O(1) matching self._exact_names = function_names_to_find self._prefix_roots: dict[str, list[str]] = {} + # Precompute sets for faster lookup during visit_Attribute() + self._dot_names: set[str] = set() + self._dot_methods: dict[str, set[str]] = {} + self._class_method_to_target: dict[tuple[str, str], str] = {} for name in function_names_to_find: if "." in name: - root = name.split(".", 1)[0] - self._prefix_roots.setdefault(root, []).append(name) + root, method = name.rsplit(".", 1) + self._dot_names.add(name) + self._dot_methods.setdefault(method, set()).add(root) + self._class_method_to_target[(root, method)] = name + root_prefix = name.split(".", 1)[0] + self._prefix_roots.setdefault(root_prefix, []).append(name) def visit_Import(self, node: ast.Import) -> None: """Handle 'import module' statements.""" @@ -247,6 +257,41 @@ def visit_Import(self, node: ast.Import) -> None: self.found_qualified_name = target_func return + def visit_Assign(self, node: ast.Assign) -> None: + """Track variable assignments, especially class instantiations.""" + if self.found_any_target_function: + return + + # Check if the assignment is a class instantiation + value = node.value + if isinstance(value, ast.Call) and isinstance(value.func, ast.Name): + class_name = value.func.id + if class_name in self.imported_modules: + # Map the variable to the actual class name (handling aliases) + original_class = self.alias_mapping.get(class_name, class_name) + # Use list comprehension for direct assignment to instance_mapping, reducing loop overhead + targets = node.targets + instance_mapping = self.instance_mapping + # since ast.Name nodes are heavily used, avoid local lookup for isinstance + # and reuse locals for faster attribute access + for target in targets: + if isinstance(target, ast.Name): + instance_mapping[target.id] = original_class + + # Replace self.generic_visit(node) with an optimized, inlined version that + # stops traversal when self.found_any_target_function is set. + # This eliminates interpretive overhead of super() and function call. + stack = [node] + append = stack.append + pop = stack.pop + found_flag = self.found_any_target_function + while stack: + current_node = pop() + if self.found_any_target_function: + break + for child in ast.iter_child_nodes(current_node): + append(child) + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: """Handle 'from module import name' statements.""" if self.found_any_target_function: @@ -287,6 +332,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: self.found_qualified_name = qname return + # Check if any target function is a method of the imported class/module + # Be conservative except when an alias is used (which requires exact method matching) + for target_func in fnames: + if "." in target_func: + class_name, method_name = target_func.split(".", 1) + if aname == class_name and not alias.asname: + # If an alias is used, don't match conservatively + # The actual method usage should be detected in visit_Attribute + self.found_any_target_function = True + self.found_qualified_name = target_func + return + prefix = qname + "." # Only bother if one of the targets startswith the prefix-root candidates = proots.get(qname, ()) @@ -301,33 +358,45 @@ def visit_Attribute(self, node: ast.Attribute) -> None: if self.found_any_target_function: return + # Check if this is accessing a target function through an imported module + + node_value = node.value + node_attr = node.attr + # Check if this is accessing a target function through an imported module if ( - isinstance(node.value, ast.Name) - and node.value.id in self.imported_modules - and node.attr in self.function_names_to_find + isinstance(node_value, ast.Name) + and node_value.id in self.imported_modules + and node_attr in self.function_names_to_find ): self.found_any_target_function = True - self.found_qualified_name = node.attr + self.found_qualified_name = node_attr return - if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules: - for target_func in self.function_names_to_find: - if "." in target_func: - class_name, method_name = target_func.rsplit(".", 1) - if node.attr == method_name: - imported_name = node.value.id - original_name = self.alias_mapping.get(imported_name, imported_name) - if original_name == class_name: - self.found_any_target_function = True - self.found_qualified_name = target_func - return - - # Check if this is accessing a target function through a dynamically imported module - # Only if we've detected dynamic imports are being used - if self.has_dynamic_imports and node.attr in self.function_names_to_find: + # Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target + if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules: + roots_possible = self._dot_methods.get(node_attr) + if roots_possible: + imported_name = node_value.id + original_name = self.alias_mapping.get(imported_name, imported_name) + if original_name in roots_possible: + self.found_any_target_function = True + self.found_qualified_name = self._class_method_to_target[(original_name, node_attr)] + return + + # Check if this is accessing a method on an instance variable + if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping: + class_name = self.instance_mapping[node_value.id] + roots_possible = self._dot_methods.get(node_attr) + if roots_possible and class_name in roots_possible: + self.found_any_target_function = True + self.found_qualified_name = self._class_method_to_target[(class_name, node_attr)] + return + + # Check for dynamic import match + if self.has_dynamic_imports and node_attr in self.function_names_to_find: self.found_any_target_function = True - self.found_qualified_name = node.attr + self.found_qualified_name = node_attr return self.generic_visit(node) diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py index 8ad1dc870..ae9d5cda6 100644 --- a/tests/test_instrument_all_and_run.py +++ b/tests/test_instrument_all_and_run.py @@ -62,6 +62,7 @@ def test_sort(): expected = ( """import gc +import inspect import os import sqlite3 import time @@ -81,10 +82,14 @@ def test_sort(): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert output == [0, 1, 2, 3, 4, 5] input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] codeflash_con.close() """ @@ -242,6 +247,7 @@ def test_sort(): expected = ( """import gc +import inspect import os import sqlite3 import time @@ -262,11 +268,15 @@ def test_sort(): codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] sort_class = BubbleSorter() - output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '2', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sort_class.sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '2', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert output == [0, 1, 2, 3, 4, 5] input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] sort_class = BubbleSorter() - output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '6', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sort_class.sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '6', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] codeflash_con.close() """ diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index ad972d7e9..e2189be01 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -91,6 +91,7 @@ def build_expected_unittest_imports(extra_imports: str = "") -> str: imports = """import gc +import inspect import os import sqlite3 import time @@ -140,6 +141,7 @@ def test_sort(self): self.assertEqual(sorter(input), list(range(5000))) """ imports = """import gc +import inspect import os import sqlite3 import time @@ -148,7 +150,7 @@ def test_sort(self): import dill as pickle""" if platform.system() != "Windows": imports += "\nimport timeout_decorator" - + imports += "\n\nfrom code_to_optimize.bubble_sort import sorter" wrapper_func = codeflash_wrap_string @@ -166,13 +168,19 @@ def test_sort(self): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) self.assertEqual(output, [0, 1, 2, 3, 4, 5]) input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) input = list(reversed(range(5000))) - self.assertEqual(codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, input), list(range(5000))) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + self.assertEqual(codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs), list(range(5000))) codeflash_con.close() """ @@ -211,6 +219,7 @@ def test_prepare_image_for_yolo(): assert compare_results(return_val_1, ret) """ expected = """import gc +import inspect import os import sqlite3 import time @@ -272,7 +281,9 @@ def test_prepare_image_for_yolo(): """ expected += """ args = pickle.loads(arg_val_pkl) return_val_1 = pickle.loads(return_val_pkl) - ret = codeflash_wrap(packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo, '{module_path}', None, 'test_prepare_image_for_yolo', 'packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo', '0_2', codeflash_loop_index, codeflash_cur, codeflash_con, **args) + _call__bound__arguments = inspect.signature(packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo).bind(**args) + _call__bound__arguments.apply_defaults() + ret = codeflash_wrap(packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo, '{module_path}', None, 'test_prepare_image_for_yolo', 'packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo', '0_2', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert compare_results(return_val_1, ret) codeflash_con.close() """ @@ -312,6 +323,7 @@ def test_sort(): expected = ( """import datetime import gc +import inspect import os import sqlite3 import time @@ -332,10 +344,14 @@ def test_sort(): codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] print(datetime.datetime.now().isoformat()) - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '2', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '2', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert output == [0, 1, 2, 3, 4, 5] input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '5', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '5', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] codeflash_con.close() """ @@ -572,6 +588,7 @@ def test_sort_parametrized(input, expected_output): """ expected = ( """import gc +import inspect import os import sqlite3 import time @@ -592,7 +609,9 @@ def test_sort_parametrized(input, expected_output): codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite') codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort_parametrized', 'sorter', '0', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort_parametrized', 'sorter', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert output == expected_output codeflash_con.close() """ @@ -841,6 +860,7 @@ def test_sort_parametrized_loop(input, expected_output): """ expected = ( """import gc +import inspect import os import sqlite3 import time @@ -862,7 +882,9 @@ def test_sort_parametrized_loop(input, expected_output): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') for i in range(2): - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort_parametrized_loop', 'sorter', '0_0', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort_parametrized_loop', 'sorter', '0_0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert output == expected_output codeflash_con.close() """ @@ -1194,6 +1216,7 @@ def test_sort(): expected = ( """import gc +import inspect import os import sqlite3 import time @@ -1217,7 +1240,9 @@ def test_sort(): for i in range(3): input = inputs[i] expected_output = expected_outputs[i] - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '2_2', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '2_2', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert output == expected_output codeflash_con.close() """ @@ -1483,6 +1508,7 @@ def test_sort(self): if is_windows: expected = ( """import gc +import inspect import os import sqlite3 import time @@ -1505,13 +1531,19 @@ def test_sort(self): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) self.assertEqual(output, [0, 1, 2, 3, 4, 5]) input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) input = list(reversed(range(50))) - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) self.assertEqual(output, list(range(50))) codeflash_con.close() """ @@ -1546,6 +1578,7 @@ def test_sort(self): else: expected = ( """import gc +import inspect import os import sqlite3 import time @@ -1570,13 +1603,19 @@ def test_sort(self): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) self.assertEqual(output, [0, 1, 2, 3, 4, 5]) input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) input = list(reversed(range(50))) - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) self.assertEqual(output, list(range(50))) codeflash_con.close() """ @@ -1839,7 +1878,9 @@ def test_sort(self, input, expected_output): codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite') codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) self.assertEqual(output, expected_output) codeflash_con.close() """ @@ -2092,11 +2133,13 @@ def test_sort(self): for i in range(3): input = inputs[i] expected_output = expected_outputs[i] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '2_2', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '2_2', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) self.assertEqual(output, expected_output) codeflash_con.close() """ - + expected_behavior = imports_behavior + "\n\n\n" + codeflash_wrap_string + "\n" + test_class_behavior # Build expected perf output with platform-aware imports @@ -2349,11 +2392,13 @@ def test_sort(self, input, expected_output): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') for i in range(2): - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0_0', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0_0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) self.assertEqual(output, expected_output) codeflash_con.close() """ - + expected_behavior = imports_behavior + "\n\n\n" + codeflash_wrap_string + "\n" + test_class_behavior # Build expected perf output with platform-aware imports imports_perf = """import gc @@ -2668,6 +2713,7 @@ def test_class_name_A_function_name(): expected = ( """import gc +import inspect import os import sqlite3 import time @@ -2685,7 +2731,9 @@ def test_class_name_A_function_name(): codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite') codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') - ret = codeflash_wrap(class_name_A.function_name, '{module_path}', None, 'test_class_name_A_function_name', 'class_name_A.function_name', '0', codeflash_loop_index, codeflash_cur, codeflash_con, **args) + _call__bound__arguments = inspect.signature(class_name_A.function_name).bind(**args) + _call__bound__arguments.apply_defaults() + ret = codeflash_wrap(class_name_A.function_name, '{module_path}', None, 'test_class_name_A_function_name', 'class_name_A.function_name', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) codeflash_con.close() """ ) @@ -2736,6 +2784,7 @@ def test_common_tags_1(): expected = ( """import gc +import inspect import os import sqlite3 import time @@ -2755,9 +2804,13 @@ def test_common_tags_1(): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') articles_1 = [1, 2, 3] - assert codeflash_wrap(find_common_tags, '{module_path}', None, 'test_common_tags_1', 'find_common_tags', '1', codeflash_loop_index, codeflash_cur, codeflash_con, articles_1) == set(1, 2) + _call__bound__arguments = inspect.signature(find_common_tags).bind(articles_1) + _call__bound__arguments.apply_defaults() + assert codeflash_wrap(find_common_tags, '{module_path}', None, 'test_common_tags_1', 'find_common_tags', '1', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) == set(1, 2) articles_2 = [1, 2] - assert codeflash_wrap(find_common_tags, '{module_path}', None, 'test_common_tags_1', 'find_common_tags', '3', codeflash_loop_index, codeflash_cur, codeflash_con, articles_2) == set(1) + _call__bound__arguments = inspect.signature(find_common_tags).bind(articles_2) + _call__bound__arguments.apply_defaults() + assert codeflash_wrap(find_common_tags, '{module_path}', None, 'test_common_tags_1', 'find_common_tags', '3', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) == set(1) codeflash_con.close() """ ) @@ -2803,6 +2856,7 @@ def test_sort(): expected = ( """import gc +import inspect import os import sqlite3 import time @@ -2823,7 +2877,9 @@ def test_sort(): codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] if len(input) > 0: - assert codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1_0', codeflash_loop_index, codeflash_cur, codeflash_con, input) == [0, 1, 2, 3, 4, 5] + _call__bound__arguments = inspect.signature(sorter).bind(input) + _call__bound__arguments.apply_defaults() + assert codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1_0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) == [0, 1, 2, 3, 4, 5] codeflash_con.close() """ ) @@ -2870,6 +2926,7 @@ def test_sort(): expected = ( """import gc +import inspect import os import sqlite3 import time @@ -2889,10 +2946,14 @@ def test_sort(): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] - output = codeflash_wrap(BubbleSorter.sorter, 'tests.pytest.test_perfinjector_bubble_sort_results_temp', None, 'test_sort', 'BubbleSorter.sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(BubbleSorter.sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(BubbleSorter.sorter, 'tests.pytest.test_perfinjector_bubble_sort_results_temp', None, 'test_sort', 'BubbleSorter.sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert output == [0, 1, 2, 3, 4, 5] input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(BubbleSorter.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input) + _call__bound__arguments = inspect.signature(BubbleSorter.sorter).bind(input) + _call__bound__arguments.apply_defaults() + output = codeflash_wrap(BubbleSorter.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] codeflash_con.close() """ @@ -2972,6 +3033,7 @@ def test_code_replacement10() -> None: expected = ( """import gc +import inspect import os import sqlite3 import time @@ -2996,9 +3058,13 @@ def test_code_replacement10() -> None: func_top_optimize = FunctionToOptimize(function_name='main_method', file_path=str(file_path), parents=[FunctionParent('MainClass', 'ClassDef')]) with open(file_path) as f: original_code = f.read() - code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_1', codeflash_loop_index, codeflash_cur, codeflash_con, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code).unwrap() + _call__bound__arguments = inspect.signature(opt.get_code_optimization_context).bind(function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code) + _call__bound__arguments.apply_defaults() + code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_1', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs).unwrap() assert code_context.testgen_context_code == get_code_output - code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_3', codeflash_loop_index, codeflash_cur, codeflash_con, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code) + _call__bound__arguments = inspect.signature(opt.get_code_optimization_context).bind(function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code) + _call__bound__arguments.apply_defaults() + code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_3', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) assert code_context.testgen_context_code == get_code_output codeflash_con.close() """