Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions code_to_optimize/code_directories/simple_tracer_e2e/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,45 @@ def test_threadpool() -> None:
for r in result:
print(r)

class AlexNet:
def __init__(self, num_classes=1000):
self.num_classes = num_classes
self.features_size = 256 * 6 * 6

def forward(self, x):
features = self._extract_features(x)

output = self._classify(features)
return output

def _extract_features(self, x):
result = []
for i in range(len(x)):
pass

return result

def _classify(self, features):
total = sum(features)
return [total % self.num_classes for _ in features]

class SimpleModel:
@staticmethod
def predict(data):
return [x * 2 for x in data]

@classmethod
def create_default(cls):
return cls()

def test_models():
model = AlexNet(num_classes=10)
input_data = [1, 2, 3, 4, 5]
result = model.forward(input_data)

model2 = SimpleModel.create_default()
prediction = model2.predict(input_data)

if __name__ == "__main__":
test_threadpool()
test_models()
65 changes: 45 additions & 20 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ def __init__(self, function_names_to_find: set[str]) -> None:
self.has_dynamic_imports: bool = False
self.wildcard_modules: set[str] = set()

# Precompute function_names for prefix search
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
self._exact_names = function_names_to_find
self._prefix_roots = {}
for name in function_names_to_find:
if "." in name:
root = name.split(".", 1)[0]
self._prefix_roots.setdefault(root, []).append(name)

def visit_Import(self, node: ast.Import) -> None:
"""Handle 'import module' statements."""
if self.found_any_target_function:
Expand Down Expand Up @@ -181,30 +190,46 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if self.found_any_target_function:
return

if not node.module:
mod = node.module
if not mod:
return

fnames = self._exact_names
proots = self._prefix_roots

for alias in node.names:
if alias.name == "*":
self.wildcard_modules.add(node.module)
else:
imported_name = alias.asname if alias.asname else alias.name
self.imported_modules.add(imported_name)

# Check for dynamic import functions
if node.module == "importlib" and alias.name == "import_module":
self.has_dynamic_imports = True

# Check if imported name is a target qualified name
if alias.name in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = alias.name
return
# Check if module.name forms a target qualified name
qualified_name = f"{node.module}.{alias.name}"
if qualified_name in self.function_names_to_find:
aname = alias.name
if aname == "*":
self.wildcard_modules.add(mod)
continue

imported_name = alias.asname if alias.asname else aname
self.imported_modules.add(imported_name)

# Fast check for dynamic import
if mod == "importlib" and aname == "import_module":
self.has_dynamic_imports = True

qname = f"{mod}.{aname}"

# Fast exact match check
if aname in fnames:
self.found_any_target_function = True
self.found_qualified_name = aname
return
if qname in fnames:
self.found_any_target_function = True
self.found_qualified_name = qname
return

# Fast prefix match: only for relevant roots
prefix = qname + "."
# Only bother if one of the targets startswith the prefix-root
candidates = proots.get(qname, ())
for target_func in candidates:
if target_func.startswith(prefix):
self.found_any_target_function = True
self.found_qualified_name = qualified_name
self.found_qualified_name = target_func
return

def visit_Attribute(self, node: ast.Attribute) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/end_to_end_test_tracer_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
trace_mode=True,
min_improvement_x=0.1,
expected_unit_tests=2,
expected_unit_tests=7,
coverage_expectations=[
CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[5, 6, 7, 8, 10, 13])
],
Expand Down
4 changes: 2 additions & 2 deletions tests/scripts/end_to_end_test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p
return False

functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout)
if not functions_traced or int(functions_traced.group(1)) != 4:
logging.error("Expected 4 traced functions")
if not functions_traced or int(functions_traced.group(1)) != 13:
logging.error("Expected 13 traced functions")
return False

replay_test_path = pathlib.Path(functions_traced.group(2))
Expand Down
Loading