Skip to content

Commit ef010e5

Browse files
authored
Merge branch 'main' into fix-test-reporting
2 parents 66381bd + a8a591b commit ef010e5

File tree

5 files changed

+89
-25
lines changed

5 files changed

+89
-25
lines changed

code_to_optimize/code_directories/simple_tracer_e2e/workload.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,45 @@ def test_threadpool() -> None:
2121
for r in result:
2222
print(r)
2323

24+
class AlexNet:
25+
def __init__(self, num_classes=1000):
26+
self.num_classes = num_classes
27+
self.features_size = 256 * 6 * 6
28+
29+
def forward(self, x):
30+
features = self._extract_features(x)
31+
32+
output = self._classify(features)
33+
return output
34+
35+
def _extract_features(self, x):
36+
result = []
37+
for i in range(len(x)):
38+
pass
39+
40+
return result
41+
42+
def _classify(self, features):
43+
total = sum(features)
44+
return [total % self.num_classes for _ in features]
45+
46+
class SimpleModel:
47+
@staticmethod
48+
def predict(data):
49+
return [x * 2 for x in data]
50+
51+
@classmethod
52+
def create_default(cls):
53+
return cls()
54+
55+
def test_models():
56+
model = AlexNet(num_classes=10)
57+
input_data = [1, 2, 3, 4, 5]
58+
result = model.forward(input_data)
59+
60+
model2 = SimpleModel.create_default()
61+
prediction = model2.predict(input_data)
2462

2563
if __name__ == "__main__":
2664
test_threadpool()
65+
test_models()

codeflash/discovery/discover_unit_tests.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,15 @@ def __init__(self, function_names_to_find: set[str]) -> None:
151151
self.has_dynamic_imports: bool = False
152152
self.wildcard_modules: set[str] = set()
153153

154+
# Precompute function_names for prefix search
155+
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
156+
self._exact_names = function_names_to_find
157+
self._prefix_roots = {}
158+
for name in function_names_to_find:
159+
if "." in name:
160+
root = name.split(".", 1)[0]
161+
self._prefix_roots.setdefault(root, []).append(name)
162+
154163
def visit_Import(self, node: ast.Import) -> None:
155164
"""Handle 'import module' statements."""
156165
if self.found_any_target_function:
@@ -181,30 +190,46 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
181190
if self.found_any_target_function:
182191
return
183192

184-
if not node.module:
193+
mod = node.module
194+
if not mod:
185195
return
186196

197+
fnames = self._exact_names
198+
proots = self._prefix_roots
199+
187200
for alias in node.names:
188-
if alias.name == "*":
189-
self.wildcard_modules.add(node.module)
190-
else:
191-
imported_name = alias.asname if alias.asname else alias.name
192-
self.imported_modules.add(imported_name)
193-
194-
# Check for dynamic import functions
195-
if node.module == "importlib" and alias.name == "import_module":
196-
self.has_dynamic_imports = True
197-
198-
# Check if imported name is a target qualified name
199-
if alias.name in self.function_names_to_find:
200-
self.found_any_target_function = True
201-
self.found_qualified_name = alias.name
202-
return
203-
# Check if module.name forms a target qualified name
204-
qualified_name = f"{node.module}.{alias.name}"
205-
if qualified_name in self.function_names_to_find:
201+
aname = alias.name
202+
if aname == "*":
203+
self.wildcard_modules.add(mod)
204+
continue
205+
206+
imported_name = alias.asname if alias.asname else aname
207+
self.imported_modules.add(imported_name)
208+
209+
# Fast check for dynamic import
210+
if mod == "importlib" and aname == "import_module":
211+
self.has_dynamic_imports = True
212+
213+
qname = f"{mod}.{aname}"
214+
215+
# Fast exact match check
216+
if aname in fnames:
217+
self.found_any_target_function = True
218+
self.found_qualified_name = aname
219+
return
220+
if qname in fnames:
221+
self.found_any_target_function = True
222+
self.found_qualified_name = qname
223+
return
224+
225+
# Fast prefix match: only for relevant roots
226+
prefix = qname + "."
227+
# Only bother if one of the targets startswith the prefix-root
228+
candidates = proots.get(qname, ())
229+
for target_func in candidates:
230+
if target_func.startswith(prefix):
206231
self.found_any_target_function = True
207-
self.found_qualified_name = qualified_name
232+
self.found_qualified_name = target_func
208233
return
209234

210235
def visit_Attribute(self, node: ast.Attribute) -> None:

codeflash/tracer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def __exit__(
263263
if self.function_count[
264264
str(function.file_name)
265265
+ ":"
266-
+ (function.class_name + ":" if function.class_name else "")
266+
+ (function.class_name + "." if function.class_name else "")
267267
+ function.function_name
268268
]
269269
> 0
@@ -353,7 +353,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
353353
return
354354
if function_qualified_name not in self.function_count:
355355
# seeing this function for the first time
356-
self.function_count[function_qualified_name] = 0
356+
self.function_count[function_qualified_name] = 1
357357
file_valid = filter_files_optimized(
358358
file_path=file_name,
359359
tests_root=Path(self.config["tests_root"]),

tests/scripts/end_to_end_test_tracer_replay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def run_test(expected_improvement_pct: int) -> bool:
88
config = TestConfig(
99
trace_mode=True,
1010
min_improvement_x=0.1,
11-
expected_unit_tests=1,
11+
expected_unit_tests=7,
1212
coverage_expectations=[
1313
CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[5, 6, 7, 8, 10, 13])
1414
],

tests/scripts/end_to_end_test_utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p
206206
return False
207207

208208
functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout)
209-
if not functions_traced or int(functions_traced.group(1)) != 4:
210-
logging.error("Expected 4 traced functions")
209+
if not functions_traced or int(functions_traced.group(1)) != 13:
210+
logging.error("Expected 13 traced functions")
211211
return False
212212

213213
replay_test_path = pathlib.Path(functions_traced.group(2))

0 commit comments

Comments
 (0)