diff --git a/code_to_optimize/code_directories/simple_tracer_e2e/workload.py b/code_to_optimize/code_directories/simple_tracer_e2e/workload.py
index 5514a0971..db708a5c0 100644
--- a/code_to_optimize/code_directories/simple_tracer_e2e/workload.py
+++ b/code_to_optimize/code_directories/simple_tracer_e2e/workload.py
@@ -21,6 +21,45 @@ def test_threadpool() -> None:
     for r in result:
         print(r)
 
+class AlexNet:
+    def __init__(self, num_classes=1000):
+        self.num_classes = num_classes
+        self.features_size = 256 * 6 * 6
+
+    def forward(self, x):
+        features = self._extract_features(x)
+
+        output = self._classify(features)
+        return output
+
+    def _extract_features(self, x):
+        result = []
+        for i in range(len(x)):
+            pass
+
+        return result
+
+    def _classify(self, features):
+        total = sum(features)
+        return [total % self.num_classes for _ in features]
+
+class SimpleModel:
+    @staticmethod
+    def predict(data):
+        return [x * 2 for x in data]
+
+    @classmethod
+    def create_default(cls):
+        return cls()
+
+def test_models():
+    model = AlexNet(num_classes=10)
+    input_data = [1, 2, 3, 4, 5]
+    result = model.forward(input_data)
+
+    model2 = SimpleModel.create_default()
+    prediction = model2.predict(input_data)
 
 if __name__ == "__main__":
     test_threadpool()
+    test_models()
diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py
index 6bcf8156a..0e5578c57 100644
--- a/codeflash/discovery/discover_unit_tests.py
+++ b/codeflash/discovery/discover_unit_tests.py
@@ -151,6 +151,15 @@ def __init__(self, function_names_to_find: set[str]) -> None:
         self.has_dynamic_imports: bool = False
         self.wildcard_modules: set[str] = set()
 
+        # Precompute the lookup structures consulted by visit_ImportFrom.
+        # _prefix_roots maps the first dotted component of every dotted target name to those targets.
+        self._exact_names = function_names_to_find
+        self._prefix_roots = {}
+        for name in function_names_to_find:
+            if "." in name:
+                root = name.split(".", 1)[0]
+                self._prefix_roots.setdefault(root, []).append(name)
+
     def visit_Import(self, node: ast.Import) -> None:
         """Handle 'import module' statements."""
         if self.found_any_target_function:
@@ -181,30 +190,46 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
         if self.found_any_target_function:
             return
 
-        if not node.module:
+        mod = node.module
+        if not mod:
             return
 
+        fnames = self._exact_names
+        proots = self._prefix_roots
+
         for alias in node.names:
-            if alias.name == "*":
-                self.wildcard_modules.add(node.module)
-            else:
-                imported_name = alias.asname if alias.asname else alias.name
-                self.imported_modules.add(imported_name)
-
-                # Check for dynamic import functions
-                if node.module == "importlib" and alias.name == "import_module":
-                    self.has_dynamic_imports = True
-
-                # Check if imported name is a target qualified name
-                if alias.name in self.function_names_to_find:
-                    self.found_any_target_function = True
-                    self.found_qualified_name = alias.name
-                    return
-                # Check if module.name forms a target qualified name
-                qualified_name = f"{node.module}.{alias.name}"
-                if qualified_name in self.function_names_to_find:
+            aname = alias.name
+            if aname == "*":
+                self.wildcard_modules.add(mod)
+                continue
+
+            imported_name = alias.asname if alias.asname else aname
+            self.imported_modules.add(imported_name)
+
+            # Record use of importlib.import_module (dynamic imports)
+            if mod == "importlib" and aname == "import_module":
+                self.has_dynamic_imports = True
+
+            qname = f"{mod}.{aname}"
+
+            # Exact match on the bare imported name or on "module.name"
+            if aname in fnames:
+                self.found_any_target_function = True
+                self.found_qualified_name = aname
+                return
+            if qname in fnames:
+                self.found_any_target_function = True
+                self.found_qualified_name = qname
+                return
+
+            # Prefix match: a target may live under the imported name (e.g. "module.Class.method")
+            prefix = qname + "."
+            # _prefix_roots keys are first dotted components (no dots), so look up by the module's root, not by qname
+            candidates = proots.get(mod.split(".", 1)[0], ())
+            for target_func in candidates:
+                if target_func.startswith(prefix):
                     self.found_any_target_function = True
-                    self.found_qualified_name = qualified_name
+                    self.found_qualified_name = target_func
                     return
 
     def visit_Attribute(self, node: ast.Attribute) -> None:
diff --git a/tests/scripts/end_to_end_test_tracer_replay.py b/tests/scripts/end_to_end_test_tracer_replay.py
index 99ffc6dfd..7c1e820cb 100644
--- a/tests/scripts/end_to_end_test_tracer_replay.py
+++ b/tests/scripts/end_to_end_test_tracer_replay.py
@@ -8,7 +8,7 @@ def run_test(expected_improvement_pct: int) -> bool:
     config = TestConfig(
         trace_mode=True,
         min_improvement_x=0.1,
-        expected_unit_tests=2,
+        expected_unit_tests=7,
         coverage_expectations=[
             CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[5, 6, 7, 8, 10, 13])
         ],
diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py
index 3645930b0..d0c70097e 100644
--- a/tests/scripts/end_to_end_test_utilities.py
+++ b/tests/scripts/end_to_end_test_utilities.py
@@ -206,8 +206,8 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p
         return False
 
     functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout)
-    if not functions_traced or int(functions_traced.group(1)) != 4:
-        logging.error("Expected 4 traced functions")
+    if not functions_traced or int(functions_traced.group(1)) != 13:
+        logging.error("Expected 13 traced functions")
         return False
 
     replay_test_path = pathlib.Path(functions_traced.group(2))