Skip to content

Commit 66c75a4

Browse files
committed
bug fix
1 parent 5298a68 commit 66c75a4

File tree

3 files changed

+60
-13
lines changed

3 files changed

+60
-13
lines changed

code_to_optimize/code_directories/simple_tracer_e2e/workload.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,45 @@ def test_threadpool() -> None:
2121
for r in result:
2222
print(r)
2323

24+
class AlexNet:
25+
def __init__(self, num_classes=1000):
26+
self.num_classes = num_classes
27+
self.features_size = 256 * 6 * 6
28+
29+
def forward(self, x):
30+
features = self._extract_features(x)
31+
32+
output = self._classify(features)
33+
return output
34+
35+
def _extract_features(self, x):
36+
result = []
37+
for i in range(len(x)):
38+
pass
39+
40+
return result
41+
42+
def _classify(self, features):
43+
total = sum(features)
44+
return [total % self.num_classes for _ in features]
45+
46+
class SimpleModel:
47+
@staticmethod
48+
def predict(data):
49+
return [x * 2 for x in data]
50+
51+
@classmethod
52+
def create_default(cls):
53+
return cls()
54+
55+
def test_models():
56+
model = AlexNet(num_classes=10)
57+
input_data = [1, 2, 3, 4, 5]
58+
result = model.forward(input_data)
59+
60+
model2 = SimpleModel.create_default()
61+
prediction = model2.predict(input_data)
2462

2563
if __name__ == "__main__":
2664
test_threadpool()
65+
test_models()

codeflash/discovery/discover_unit_tests.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -195,17 +195,25 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
195195
if node.module == "importlib" and alias.name == "import_module":
196196
self.has_dynamic_imports = True
197197

198-
# Check if imported name is a target qualified name
199-
if alias.name in self.function_names_to_find:
200-
self.found_any_target_function = True
201-
self.found_qualified_name = alias.name
202-
return
203-
# Check if module.name forms a target qualified name
204-
qualified_name = f"{node.module}.{alias.name}"
205-
if qualified_name in self.function_names_to_find:
206-
self.found_any_target_function = True
207-
self.found_qualified_name = qualified_name
208-
return
198+
qualified_name = f"{node.module}.{alias.name}"
199+
potential_matches = {alias.name, qualified_name}
200+
201+
if any(name in self.function_names_to_find for name in potential_matches):
202+
self.found_any_target_function = True
203+
self.found_qualified_name = next(
204+
name for name in potential_matches if name in self.function_names_to_find
205+
)
206+
return
207+
208+
qualified_prefix = qualified_name + "."
209+
if any(target_func.startswith(qualified_prefix) for target_func in self.function_names_to_find):
210+
self.found_any_target_function = True
211+
self.found_qualified_name = next(
212+
target_func
213+
for target_func in self.function_names_to_find
214+
if target_func.startswith(qualified_prefix)
215+
)
216+
return
209217

210218
def visit_Attribute(self, node: ast.Attribute) -> None:
211219
"""Handle attribute access like module.function_name."""

tests/scripts/end_to_end_test_utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p
206206
return False
207207

208208
functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout)
209-
if not functions_traced or int(functions_traced.group(1)) != 4:
210-
logging.error("Expected 4 traced functions")
209+
if not functions_traced or int(functions_traced.group(1)) != 13:
210+
logging.error("Expected 13 traced functions")
211211
return False
212212

213213
replay_test_path = pathlib.Path(functions_traced.group(2))

0 commit comments

Comments
 (0)