Skip to content
8 changes: 4 additions & 4 deletions code_to_optimize/tests/pytest/test_topological_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_topological_sort():
g.addEdge(2, 3)
g.addEdge(3, 1)

assert g.topologicalSort() == [5, 4, 2, 3, 1, 0]
assert g.topologicalSort()[0] == [5, 4, 2, 3, 1, 0]


def test_topological_sort_2():
Expand All @@ -20,15 +20,15 @@ def test_topological_sort_2():
for j in range(i + 1, 10):
g.addEdge(i, j)

assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

g = Graph(10)

for i in range(10):
for j in range(i + 1, 10):
g.addEdge(i, j)

assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


def test_topological_sort_3():
Expand All @@ -38,4 +38,4 @@ def test_topological_sort_3():
for j in range(i + 1, 1000):
g.addEdge(j, i)

assert g.topologicalSort() == list(reversed(range(1000)))
assert g.topologicalSort()[0] == list(reversed(range(1000)))
96 changes: 74 additions & 22 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,25 @@ def __init__(self, function_names_to_find: set[str]) -> None:
self.wildcard_modules: set[str] = set()
# Track aliases: alias_name -> original_name
self.alias_mapping: dict[str, str] = {}
# Track instances: variable_name -> class_name
self.instance_mapping: dict[str, str] = {}

# Precompute function_names for prefix search
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
self._exact_names = function_names_to_find
self._prefix_roots: dict[str, list[str]] = {}
# Precompute sets for faster lookup during visit_Attribute()
self._dot_names: set[str] = set()
self._dot_methods: dict[str, set[str]] = {}
self._class_method_to_target: dict[tuple[str, str], str] = {}
for name in function_names_to_find:
if "." in name:
root = name.split(".", 1)[0]
self._prefix_roots.setdefault(root, []).append(name)
root, method = name.rsplit(".", 1)
self._dot_names.add(name)
self._dot_methods.setdefault(method, set()).add(root)
self._class_method_to_target[(root, method)] = name
root_prefix = name.split(".", 1)[0]
self._prefix_roots.setdefault(root_prefix, []).append(name)

def visit_Import(self, node: ast.Import) -> None:
"""Handle 'import module' statements."""
Expand All @@ -247,6 +257,24 @@ def visit_Import(self, node: ast.Import) -> None:
self.found_qualified_name = target_func
return

def visit_Assign(self, node: ast.Assign) -> None:
    """Record variables that are assigned an instance of an imported class.

    When the assigned value is a direct call to a bare name found in
    ``self.imported_modules`` (for example ``g = Graph(5)``), every
    plain-name assignment target is stored in ``self.instance_mapping``
    as ``variable_name -> class_name``. The class name is first translated
    through ``self.alias_mapping`` so an aliased import maps back to its
    original name.

    Nothing is recorded or visited once ``self.found_any_target_function``
    is already set. Otherwise the node's children are visited afterwards
    via ``self.generic_visit``.
    """
    if self.found_any_target_function:
        return

    value = node.value
    # Only a call whose callee is a bare name (``Name(...)``) is treated as
    # a possible class instantiation; attribute callees are not tracked here.
    if isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
        class_name = value.func.id
        if class_name in self.imported_modules:
            # The alias lookup does not depend on the target, so resolve it
            # once instead of once per assignment target.
            original_class = self.alias_mapping.get(class_name, class_name)
            for target in node.targets:
                # Tuple/attribute/subscript targets are skipped; only simple
                # variable names are recorded.
                if isinstance(target, ast.Name):
                    self.instance_mapping[target.id] = original_class

    self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Handle 'from module import name' statements."""
if self.found_any_target_function:
Expand Down Expand Up @@ -287,6 +315,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self.found_qualified_name = qname
return

# Check if any target function is a method of the imported class/module
# Be conservative except when an alias is used (which requires exact method matching)
for target_func in fnames:
if "." in target_func:
class_name, method_name = target_func.split(".", 1)
if aname == class_name and not alias.asname:
# If an alias is used, don't match conservatively
# The actual method usage should be detected in visit_Attribute
self.found_any_target_function = True
self.found_qualified_name = target_func
return

prefix = qname + "."
# Only bother if one of the targets startswith the prefix-root
candidates = proots.get(qname, ())
Expand All @@ -301,33 +341,45 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
if self.found_any_target_function:
return

# Check if this is accessing a target function through an imported module

node_value = node.value
node_attr = node.attr

# Check if this is accessing a target function through an imported module
if (
isinstance(node.value, ast.Name)
and node.value.id in self.imported_modules
and node.attr in self.function_names_to_find
isinstance(node_value, ast.Name)
and node_value.id in self.imported_modules
and node_attr in self.function_names_to_find
):
self.found_any_target_function = True
self.found_qualified_name = node.attr
self.found_qualified_name = node_attr
return

if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules:
for target_func in self.function_names_to_find:
if "." in target_func:
class_name, method_name = target_func.rsplit(".", 1)
if node.attr == method_name:
imported_name = node.value.id
original_name = self.alias_mapping.get(imported_name, imported_name)
if original_name == class_name:
self.found_any_target_function = True
self.found_qualified_name = target_func
return

# Check if this is accessing a target function through a dynamically imported module
# Only if we've detected dynamic imports are being used
if self.has_dynamic_imports and node.attr in self.function_names_to_find:
# Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target
if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules:
roots_possible = self._dot_methods.get(node_attr)
if roots_possible:
imported_name = node_value.id
original_name = self.alias_mapping.get(imported_name, imported_name)
if original_name in roots_possible:
self.found_any_target_function = True
self.found_qualified_name = self._class_method_to_target[(original_name, node_attr)]
return

# Check if this is accessing a method on an instance variable
if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping:
class_name = self.instance_mapping[node_value.id]
roots_possible = self._dot_methods.get(node_attr)
if roots_possible and class_name in roots_possible:
self.found_any_target_function = True
self.found_qualified_name = self._class_method_to_target[(class_name, node_attr)]
return

# Check for dynamic import match
if self.has_dynamic_imports and node_attr in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = node.attr
self.found_qualified_name = node_attr
return

self.generic_visit(node)
Expand Down
1 change: 1 addition & 0 deletions tests/scripts/end_to_end_test_topological_sort_worktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def run_test(expected_improvement_pct: int) -> bool:
expected_lines=[25, 26, 27, 28, 29, 30, 31],
)
],
expected_unit_tests=1,
)
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
return_var = run_codeflash_command(cwd, config, expected_improvement_pct)
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/end_to_end_test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
return False

if config.expected_unit_tests is not None:
unit_test_match = re.search(r"Discovered (\d+) existing unit tests", stdout)
unit_test_match = re.search(r"Discovered (\d+) existing unit test file", stdout)
if not unit_test_match:
logging.error("Could not find unit test count")
return False
Expand Down
Loading
Loading