Skip to content

Commit 1f6cf3f

Browse files
authored
Merge pull request #868 from codeflash-ai/import-analyser-fix
Import analyser fix
2 parents d3788ec + 1077b9a commit 1f6cf3f

File tree

8 files changed

+473
-31
lines changed

8 files changed

+473
-31
lines changed

code_to_optimize/tests/pytest/test_topological_sort.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def test_topological_sort():
1010
g.addEdge(2, 3)
1111
g.addEdge(3, 1)
1212

13-
assert g.topologicalSort() == [5, 4, 2, 3, 1, 0]
13+
assert g.topologicalSort()[0] == [5, 4, 2, 3, 1, 0]
1414

1515

1616
def test_topological_sort_2():
@@ -20,15 +20,15 @@ def test_topological_sort_2():
2020
for j in range(i + 1, 10):
2121
g.addEdge(i, j)
2222

23-
assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
23+
assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
2424

2525
g = Graph(10)
2626

2727
for i in range(10):
2828
for j in range(i + 1, 10):
2929
g.addEdge(i, j)
3030

31-
assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
31+
assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
3232

3333

3434
def test_topological_sort_3():
@@ -38,4 +38,4 @@ def test_topological_sort_3():
3838
for j in range(i + 1, 1000):
3939
g.addEdge(j, i)
4040

41-
assert g.topologicalSort() == list(reversed(range(1000)))
41+
assert g.topologicalSort()[0] == list(reversed(range(1000)))

codeflash/discovery/discover_unit_tests.py

Lines changed: 78 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,25 @@ def __init__(self, function_names_to_find: set[str]) -> None:
212212
self.wildcard_modules: set[str] = set()
213213
# Track aliases: alias_name -> original_name
214214
self.alias_mapping: dict[str, str] = {}
215+
# Track instances: variable_name -> class_name
216+
self.instance_mapping: dict[str, str] = {}
215217

216218
# Precompute function_names for prefix search
217219
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
218220
self._exact_names = function_names_to_find
219221
self._prefix_roots: dict[str, list[str]] = {}
222+
# Precompute sets for faster lookup during visit_Attribute()
223+
self._dot_names: set[str] = set()
224+
self._dot_methods: dict[str, set[str]] = {}
225+
self._class_method_to_target: dict[tuple[str, str], str] = {}
220226
for name in function_names_to_find:
221227
if "." in name:
222-
root = name.split(".", 1)[0]
223-
self._prefix_roots.setdefault(root, []).append(name)
228+
root, method = name.rsplit(".", 1)
229+
self._dot_names.add(name)
230+
self._dot_methods.setdefault(method, set()).add(root)
231+
self._class_method_to_target[(root, method)] = name
232+
root_prefix = name.split(".", 1)[0]
233+
self._prefix_roots.setdefault(root_prefix, []).append(name)
224234

225235
def visit_Import(self, node: ast.Import) -> None:
226236
"""Handle 'import module' statements."""
@@ -247,6 +257,28 @@ def visit_Import(self, node: ast.Import) -> None:
247257
self.found_qualified_name = target_func
248258
return
249259

260+
def visit_Assign(self, node: ast.Assign) -> None:
    """Record variables that are bound to instances of imported classes.

    For an assignment of the form ``var = Name(...)`` where ``Name`` is in
    ``self.imported_modules``, every plain-name target is stored in
    ``self.instance_mapping`` as ``var -> class name``. The class name is
    resolved through ``self.alias_mapping`` so an aliased import maps back
    to the original name.

    Child nodes are always traversed afterwards, so that references nested
    inside the call (for example ``obj = Cls(mod.target())``) are still
    handed to ``generic_visit`` instead of being silently skipped.

    Args:
        node: The assignment statement being visited.
    """
    if self.found_any_target_function:
        return

    value = node.value
    if isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
        class_name = value.func.id
        if class_name in self.imported_modules:
            # Resolve the alias once; it is identical for every target.
            original_class = self.alias_mapping.get(class_name, class_name)
            for target in node.targets:
                # Only simple ``name = ...`` targets are tracked; tuple
                # unpacking, attribute and subscript targets are ignored.
                if isinstance(target, ast.Name):
                    self.instance_mapping[target.id] = original_class

    # Always descend: returning early after recording an instance would hide
    # any use of a target function inside the constructor's arguments.
    self.generic_visit(node)
281+
250282
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
251283
"""Handle 'from module import name' statements."""
252284
if self.found_any_target_function:
@@ -287,6 +319,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
287319
self.found_qualified_name = qname
288320
return
289321

322+
# Check if any target function is a method of the imported class/module
323+
# Be conservative except when an alias is used (which requires exact method matching)
324+
for target_func in fnames:
325+
if "." in target_func:
326+
class_name, method_name = target_func.split(".", 1)
327+
if aname == class_name and not alias.asname:
328+
# Reached only for un-aliased imports, which are matched conservatively here.
329+
# Aliased imports are left to visit_Attribute, which checks the actual method usage
330+
self.found_any_target_function = True
331+
self.found_qualified_name = target_func
332+
return
333+
290334
prefix = qname + "."
291335
# Only bother if one of the targets startswith the prefix-root
292336
candidates = proots.get(qname, ())
@@ -301,33 +345,45 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
301345
if self.found_any_target_function:
302346
return
303347

348+
# Bind the node's fields to locals once; they are read several times below
349+
350+
node_value = node.value
351+
node_attr = node.attr
352+
304353
# Check if this is accessing a target function through an imported module
305354
if (
306-
isinstance(node.value, ast.Name)
307-
and node.value.id in self.imported_modules
308-
and node.attr in self.function_names_to_find
355+
isinstance(node_value, ast.Name)
356+
and node_value.id in self.imported_modules
357+
and node_attr in self.function_names_to_find
309358
):
310359
self.found_any_target_function = True
311-
self.found_qualified_name = node.attr
360+
self.found_qualified_name = node_attr
312361
return
313362

314-
if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules:
315-
for target_func in self.function_names_to_find:
316-
if "." in target_func:
317-
class_name, method_name = target_func.rsplit(".", 1)
318-
if node.attr == method_name:
319-
imported_name = node.value.id
320-
original_name = self.alias_mapping.get(imported_name, imported_name)
321-
if original_name == class_name:
322-
self.found_any_target_function = True
323-
self.found_qualified_name = target_func
324-
return
325-
326-
# Check if this is accessing a target function through a dynamically imported module
327-
# Only if we've detected dynamic imports are being used
328-
if self.has_dynamic_imports and node.attr in self.function_names_to_find:
363+
# Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target
364+
if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules:
365+
roots_possible = self._dot_methods.get(node_attr)
366+
if roots_possible:
367+
imported_name = node_value.id
368+
original_name = self.alias_mapping.get(imported_name, imported_name)
369+
if original_name in roots_possible:
370+
self.found_any_target_function = True
371+
self.found_qualified_name = self._class_method_to_target[(original_name, node_attr)]
372+
return
373+
374+
# Check if this is accessing a method on an instance variable
375+
if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping:
376+
class_name = self.instance_mapping[node_value.id]
377+
roots_possible = self._dot_methods.get(node_attr)
378+
if roots_possible and class_name in roots_possible:
379+
self.found_any_target_function = True
380+
self.found_qualified_name = self._class_method_to_target[(class_name, node_attr)]
381+
return
382+
383+
# Check for dynamic import match
384+
if self.has_dynamic_imports and node_attr in self.function_names_to_find:
329385
self.found_any_target_function = True
330-
self.found_qualified_name = node.attr
386+
self.found_qualified_name = node_attr
331387
return
332388

333389
self.generic_visit(node)

tests/scripts/end_to_end_test_async.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
def run_test(expected_improvement_pct: int) -> bool:
88
config = TestConfig(
99
file_path="main.py",
10-
expected_unit_tests=0,
1110
min_improvement_x=0.1,
1211
enable_async=True,
1312
coverage_expectations=[
@@ -25,4 +24,4 @@ def run_test(expected_improvement_pct: int) -> bool:
2524

2625

2726
if __name__ == "__main__":
28-
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))
27+
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))

tests/scripts/end_to_end_test_futurehouse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
def run_test(expected_improvement_pct: int) -> bool:
88
config = TestConfig(
99
file_path="src/aviary/common_tags.py",
10-
expected_unit_tests=2,
10+
expected_unit_tests=1,
1111
min_improvement_x=0.1,
1212
coverage_expectations=[
1313
CoverageExpectation(

tests/scripts/end_to_end_test_topological_sort_worktree.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def run_test(expected_improvement_pct: int) -> bool:
1818
expected_lines=[25, 26, 27, 28, 29, 30, 31],
1919
)
2020
],
21+
expected_unit_tests=1,
2122
)
2223
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
2324
return_var = run_codeflash_command(cwd, config, expected_improvement_pct)

tests/scripts/end_to_end_test_tracer_replay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def run_test(expected_improvement_pct: int) -> bool:
88
config = TestConfig(
99
trace_mode=True,
1010
min_improvement_x=0.1,
11-
expected_unit_tests=8,
11+
expected_unit_tests=0,
1212
coverage_expectations=[
1313
CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[6, 7, 8, 9, 11, 14])
1414
],

tests/scripts/end_to_end_test_utilities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
170170
return False
171171

172172
if config.expected_unit_tests is not None:
173-
unit_test_match = re.search(r"Discovered (\d+) existing unit tests", stdout)
173+
unit_test_match = re.search(r"Discovered (\d+) existing unit test file", stdout)
174174
if not unit_test_match:
175175
logging.error("Could not find unit test count")
176176
return False

0 commit comments

Comments
 (0)