Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions code_to_optimize/tests/pytest/test_topological_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_topological_sort():
g.addEdge(2, 3)
g.addEdge(3, 1)

assert g.topologicalSort() == [5, 4, 2, 3, 1, 0]
assert g.topologicalSort()[0] == [5, 4, 2, 3, 1, 0]


def test_topological_sort_2():
Expand All @@ -20,15 +20,15 @@ def test_topological_sort_2():
for j in range(i + 1, 10):
g.addEdge(i, j)

assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

g = Graph(10)

for i in range(10):
for j in range(i + 1, 10):
g.addEdge(i, j)

assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


def test_topological_sort_3():
Expand All @@ -38,4 +38,4 @@ def test_topological_sort_3():
for j in range(i + 1, 1000):
g.addEdge(j, i)

assert g.topologicalSort() == list(reversed(range(1000)))
assert g.topologicalSort()[0] == list(reversed(range(1000)))
43 changes: 43 additions & 0 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ def __init__(self, function_names_to_find: set[str]) -> None:
self.wildcard_modules: set[str] = set()
# Track aliases: alias_name -> original_name
self.alias_mapping: dict[str, str] = {}
# Track instances: variable_name -> class_name
self.instance_mapping: dict[str, str] = {}

# Precompute function_names for prefix search
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
Expand Down Expand Up @@ -247,6 +249,24 @@ def visit_Import(self, node: ast.Import) -> None:
self.found_qualified_name = target_func
return

def visit_Assign(self, node: ast.Assign) -> None:
"""Track variable assignments, especially class instantiations."""
if self.found_any_target_function:
return

# Check if the assignment is a class instantiation
if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name):
class_name = node.value.func.id
if class_name in self.imported_modules:
# Track all target variables as instances of the imported class
for target in node.targets:
if isinstance(target, ast.Name):
# Map the variable to the actual class name (handling aliases)
original_class = self.alias_mapping.get(class_name, class_name)
self.instance_mapping[target.id] = original_class

self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Handle 'from module import name' statements."""
if self.found_any_target_function:
Expand Down Expand Up @@ -287,6 +307,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self.found_qualified_name = qname
return

# Check if any target function is a method of the imported class/module
# Be conservative except when an alias is used (which requires exact method matching)
for target_func in fnames:
if "." in target_func:
class_name, method_name = target_func.split(".", 1)
if aname == class_name and not alias.asname:
# If an alias is used, don't match conservatively
# The actual method usage should be detected in visit_Attribute
self.found_any_target_function = True
self.found_qualified_name = target_func
return

prefix = qname + "."
# Only bother if one of the targets startswith the prefix-root
candidates = proots.get(qname, ())
Expand Down Expand Up @@ -323,6 +355,17 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
self.found_qualified_name = target_func
return

# Check if this is accessing a method on an instance variable
if isinstance(node.value, ast.Name) and node.value.id in self.instance_mapping:
class_name = self.instance_mapping[node.value.id]
for target_func in self.function_names_to_find:
if "." in target_func:
target_class, method_name = target_func.rsplit(".", 1)
if node.attr == method_name and class_name == target_class:
self.found_any_target_function = True
self.found_qualified_name = target_func
return

# Check if this is accessing a target function through a dynamically imported module
# Only if we've detected dynamic imports are being used
if self.has_dynamic_imports and node.attr in self.function_names_to_find:
Expand Down
1 change: 1 addition & 0 deletions tests/scripts/end_to_end_test_topological_sort_worktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def run_test(expected_improvement_pct: int) -> bool:
expected_lines=[25, 26, 27, 28, 29, 30, 31],
)
],
expected_unit_tests=1,
)
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
return_var = run_codeflash_command(cwd, config, expected_improvement_pct)
Expand Down
Loading
Loading