Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions code_to_optimize/tests/pytest/test_topological_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_topological_sort():
g.addEdge(2, 3)
g.addEdge(3, 1)

assert g.topologicalSort() == [5, 4, 2, 3, 1, 0]
assert g.topologicalSort()[0] == [5, 4, 2, 3, 1, 0]


def test_topological_sort_2():
Expand All @@ -20,15 +20,15 @@ def test_topological_sort_2():
for j in range(i + 1, 10):
g.addEdge(i, j)

assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

g = Graph(10)

for i in range(10):
for j in range(i + 1, 10):
g.addEdge(i, j)

assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


def test_topological_sort_3():
Expand All @@ -38,4 +38,4 @@ def test_topological_sort_3():
for j in range(i + 1, 1000):
g.addEdge(j, i)

assert g.topologicalSort() == list(reversed(range(1000)))
assert g.topologicalSort()[0] == list(reversed(range(1000)))
96 changes: 74 additions & 22 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,25 @@ def __init__(self, function_names_to_find: set[str]) -> None:
self.wildcard_modules: set[str] = set()
# Track aliases: alias_name -> original_name
self.alias_mapping: dict[str, str] = {}
# Track instances: variable_name -> class_name
self.instance_mapping: dict[str, str] = {}

# Precompute function_names for prefix search
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
self._exact_names = function_names_to_find
self._prefix_roots: dict[str, list[str]] = {}
# Precompute sets for faster lookup during visit_Attribute()
self._dot_names: set[str] = set()
self._dot_methods: dict[str, set[str]] = {}
self._class_method_to_target: dict[tuple[str, str], str] = {}
for name in function_names_to_find:
if "." in name:
root = name.split(".", 1)[0]
self._prefix_roots.setdefault(root, []).append(name)
root, method = name.rsplit(".", 1)
self._dot_names.add(name)
self._dot_methods.setdefault(method, set()).add(root)
self._class_method_to_target[(root, method)] = name
root_prefix = name.split(".", 1)[0]
self._prefix_roots.setdefault(root_prefix, []).append(name)

def visit_Import(self, node: ast.Import) -> None:
"""Handle 'import module' statements."""
Expand All @@ -247,6 +257,24 @@ def visit_Import(self, node: ast.Import) -> None:
self.found_qualified_name = target_func
return

def visit_Assign(self, node: ast.Assign) -> None:
    """Record variables that are assigned the result of calling an imported name.

    When the right-hand side is a direct call ``Name(...)`` and ``Name`` is in
    ``self.imported_modules``, every plain-name target on the left-hand side is
    stored in ``self.instance_mapping`` as ``variable -> class name``, with the
    class name translated through ``self.alias_mapping`` when an entry exists.
    Nothing is recorded once a target function has already been found.
    """
    if self.found_any_target_function:
        return

    rhs = node.value
    callee = rhs.func if isinstance(rhs, ast.Call) else None

    # Only direct ``Name(...)`` calls are considered; the AST alone cannot tell a
    # class from a function, so any call to an imported name counts as an instantiation.
    if isinstance(callee, ast.Name) and callee.id in self.imported_modules:
        # Resolve an alias back to the original class name, falling back to the name itself.
        resolved_class = self.alias_mapping.get(callee.id, callee.id)
        # Chained assignments (a = b = Cls()) yield several targets; non-name
        # targets such as attributes or subscripts are ignored.
        self.instance_mapping.update(
            (target.id, resolved_class) for target in node.targets if isinstance(target, ast.Name)
        )

    self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Handle 'from module import name' statements."""
if self.found_any_target_function:
Expand Down Expand Up @@ -287,6 +315,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self.found_qualified_name = qname
return

# Check if any target function is a method of the imported class/module
# Be conservative except when an alias is used (which requires exact method matching)
for target_func in fnames:
if "." in target_func:
class_name, method_name = target_func.split(".", 1)
if aname == class_name and not alias.asname:
# No alias was used, so the imported name is the class itself: match conservatively here.
# Aliased imports are not matched here; their method usage should be detected in visit_Attribute.
self.found_any_target_function = True
self.found_qualified_name = target_func
return

prefix = qname + "."
# Only bother if one of the targets startswith the prefix-root
candidates = proots.get(qname, ())
Expand All @@ -301,33 +341,45 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
if self.found_any_target_function:
return

# Check if this is accessing a target function through an imported module

node_value = node.value
node_attr = node.attr

# Check if this is accessing a target function through an imported module
if (
isinstance(node.value, ast.Name)
and node.value.id in self.imported_modules
and node.attr in self.function_names_to_find
isinstance(node_value, ast.Name)
and node_value.id in self.imported_modules
and node_attr in self.function_names_to_find
):
self.found_any_target_function = True
self.found_qualified_name = node.attr
self.found_qualified_name = node_attr
return

if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules:
for target_func in self.function_names_to_find:
if "." in target_func:
class_name, method_name = target_func.rsplit(".", 1)
if node.attr == method_name:
imported_name = node.value.id
original_name = self.alias_mapping.get(imported_name, imported_name)
if original_name == class_name:
self.found_any_target_function = True
self.found_qualified_name = target_func
return

# Check if this is accessing a target function through a dynamically imported module
# Only if we've detected dynamic imports are being used
if self.has_dynamic_imports and node.attr in self.function_names_to_find:
# Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target
if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules:
roots_possible = self._dot_methods.get(node_attr)
if roots_possible:
imported_name = node_value.id
original_name = self.alias_mapping.get(imported_name, imported_name)
if original_name in roots_possible:
self.found_any_target_function = True
self.found_qualified_name = self._class_method_to_target[(original_name, node_attr)]
return

# Check if this is accessing a method on an instance variable
if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping:
class_name = self.instance_mapping[node_value.id]
roots_possible = self._dot_methods.get(node_attr)
if roots_possible and class_name in roots_possible:
self.found_any_target_function = True
self.found_qualified_name = self._class_method_to_target[(class_name, node_attr)]
return

# Check for dynamic import match
if self.has_dynamic_imports and node_attr in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = node.attr
self.found_qualified_name = node_attr
return

self.generic_visit(node)
Expand Down
3 changes: 1 addition & 2 deletions tests/scripts/end_to_end_test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
file_path="main.py",
expected_unit_tests=0,
min_improvement_x=0.1,
enable_async=True,
coverage_expectations=[
Expand All @@ -25,4 +24,4 @@ def run_test(expected_improvement_pct: int) -> bool:


if __name__ == "__main__":
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))
2 changes: 1 addition & 1 deletion tests/scripts/end_to_end_test_futurehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
file_path="src/aviary/common_tags.py",
expected_unit_tests=2,
expected_unit_tests=1,
min_improvement_x=0.1,
coverage_expectations=[
CoverageExpectation(
Expand Down
1 change: 1 addition & 0 deletions tests/scripts/end_to_end_test_topological_sort_worktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def run_test(expected_improvement_pct: int) -> bool:
expected_lines=[25, 26, 27, 28, 29, 30, 31],
)
],
expected_unit_tests=1,
)
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
return_var = run_codeflash_command(cwd, config, expected_improvement_pct)
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/end_to_end_test_tracer_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
trace_mode=True,
min_improvement_x=0.1,
expected_unit_tests=8,
expected_unit_tests=1,
coverage_expectations=[
CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[6, 7, 8, 9, 11, 14])
],
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/end_to_end_test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
return False

if config.expected_unit_tests is not None:
unit_test_match = re.search(r"Discovered (\d+) existing unit tests", stdout)
unit_test_match = re.search(r"Discovered (\d+) existing unit test file", stdout)
if not unit_test_match:
logging.error("Could not find unit test count")
return False
Expand Down
Loading
Loading