Skip to content

Commit 0086800

Browse files
committed
handle qualified names
1 parent 337b0ee commit 0086800

File tree

2 files changed

+39
-67
lines changed

2 files changed

+39
-67
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 16 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(self, function_names_to_find: set[str]) -> None:
148148
self.imported_names: set[str] = set()
149149
self.imported_modules: set[str] = set()
150150
self.found_target_functions: set[str] = set()
151+
self.qualified_names_called: set[str] = set()
151152

152153
def visit_Import(self, node: ast.Import) -> None:
153154
"""Handle 'import module' statements."""
@@ -164,13 +165,16 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
164165

165166
for alias in node.names:
166167
if alias.name == "*":
167-
# Star imports - we can't know what's imported, so be conservative
168-
self.imported_names.add("*")
169-
else:
170-
imported_name = alias.asname if alias.asname else alias.name
171-
self.imported_names.add(imported_name)
172-
if alias.name in self.function_names_to_find:
173-
self.found_target_functions.add(alias.name)
168+
continue
169+
imported_name = alias.asname if alias.asname else alias.name
170+
self.imported_names.add(imported_name)
171+
if alias.name in self.function_names_to_find:
172+
self.found_target_functions.add(alias.name)
173+
# Check for qualified name matches
174+
if node.module:
175+
qualified_name = f"{node.module}.{alias.name}"
176+
if qualified_name in self.function_names_to_find:
177+
self.found_target_functions.add(qualified_name)
174178
self.generic_visit(node)
175179

176180
def visit_Call(self, node: ast.Call) -> None:
@@ -207,6 +211,9 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
207211
"""Handle module.function_name patterns."""
208212
if node.attr in self.function_names_to_find:
209213
self.found_target_functions.add(node.attr)
214+
if isinstance(node.value, ast.Name):
215+
qualified_name = f"{node.value.id}.{node.attr}"
216+
self.qualified_names_called.add(qualified_name)
210217
self.generic_visit(node)
211218

212219

@@ -232,37 +239,10 @@ def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: s
232239
analyzer = ImportAnalyzer(target_functions)
233240
analyzer.visit(tree)
234241

235-
# If we found direct function matches, definitely process
236242
if analyzer.found_target_functions:
237243
return True, analyzer.found_target_functions
238244

239-
# If there are star imports, we need to be conservative
240-
if "*" in analyzer.imported_names:
241-
return True, set()
242-
243-
# Check for direct name matches first (higher priority)
244-
name_matches = analyzer.imported_names & target_functions
245-
if name_matches:
246-
return True, name_matches
247-
248-
# If no direct matches, check if any imported modules could contain our target functions
249-
# This is a heuristic - we look for common patterns
250-
potential_matches = set()
251-
for module in analyzer.imported_modules:
252-
# Check if module name suggests it could contain target functions
253-
for func_name in target_functions:
254-
# Only match if the module name is a prefix of the function qualified name
255-
func_parts = func_name.split(".")
256-
if len(func_parts) > 1 and module == func_parts[0]:
257-
# Module matches the first part of qualified name (e.g., mycode in mycode.target_function)
258-
# But only if we don't have specific import information suggesting otherwise
259-
potential_matches.add(func_name)
260-
elif any(part in module for part in func_name.split("_")) and len(func_name.split("_")) > 1:
261-
# Function name parts match module name (for underscore-separated names)
262-
potential_matches.add(func_name)
263-
264-
# Only use heuristic matches if we haven't found specific function imports that contradict them
265-
return bool(potential_matches), potential_matches
245+
return False, set() # noqa: TRY300
266246

267247
except (SyntaxError, UnicodeDecodeError, OSError) as e:
268248
logger.debug(f"Failed to analyze imports in {test_file_path}: {e}")
@@ -283,7 +263,6 @@ def filter_test_files_by_imports(
283263
284264
"""
285265
if not target_functions:
286-
# If no target functions specified, process all files
287266
return file_to_test_map, {}
288267

289268
filtered_map = {}
@@ -479,7 +458,7 @@ def process_test_files(
479458
target_function_names = set()
480459
for func in functions_to_optimize:
481460
target_function_names.add(func.qualified_name)
482-
logger.debug(f"Target functions for import filtering: {target_function_names}")
461+
logger.info(f"Target functions for import filtering: {target_function_names}")
483462
file_to_test_map, import_results = filter_test_files_by_imports(file_to_test_map, target_function_names)
484463
logger.debug(f"Import analysis results: {len(import_results)} files analyzed")
485464

tests/test_unit_test_discovery.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010
from codeflash.models.models import TestsInFile, TestType
1111
from codeflash.verification.verification_utils import TestConfig
12+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1213

1314

1415
def test_unit_test_discovery_pytest():
@@ -832,8 +833,8 @@ def test_something():
832833
target_functions = {"target_function"}
833834
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
834835

835-
assert should_process is True # Conservative approach with star imports
836-
assert found_functions == set() # No specific functions identified
836+
assert should_process is False
837+
assert found_functions == set()
837838

838839

839840
def test_analyze_imports_module_import():
@@ -907,13 +908,11 @@ def test_unrelated():
907908

908909
target_functions = {"target_function", "another_function"}
909910
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
910-
911911
assert should_process is False
912912
assert found_functions == set()
913913

914914

915-
def test_analyze_imports_heuristic_matching():
916-
"""Test heuristic module name matching."""
915+
def test_analyze_qualified_names():
917916
with tempfile.TemporaryDirectory() as tmpdirname:
918917
test_file = Path(tmpdirname) / "test_example.py"
919918
test_content = """
@@ -924,11 +923,11 @@ def test_target():
924923
"""
925924
test_file.write_text(test_content)
926925

927-
target_functions = {"target_function"} # Function name partially matches module name
926+
target_functions = {"target_module.some_function"}
928927
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
929-
930928
assert should_process is True
931-
assert "target_function" in found_functions
929+
assert "target_module.some_function" in found_functions
930+
932931

933932

934933
def test_analyze_imports_syntax_error():
@@ -952,7 +951,6 @@ def test_target(
952951

953952

954953
def test_filter_test_files_by_imports():
955-
"""Test the complete filtering functionality."""
956954
with tempfile.TemporaryDirectory() as tmpdirname:
957955
tmpdir = Path(tmpdirname)
958956

@@ -974,7 +972,7 @@ def test_other():
974972
assert other_function() is True
975973
""")
976974

977-
# Create test file with star import (should be processed)
975+
# Create test file with star import (should not be processed)
978976
star_test = tmpdir / "test_star.py"
979977
star_test.write_text("""
980978
from mymodule import *
@@ -983,7 +981,6 @@ def test_star():
983981
assert something() is True
984982
""")
985983

986-
# Build file_to_test_map
987984
file_to_test_map = {
988985
relevant_test: [TestsInFile(test_file=relevant_test, test_function="test_target", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)],
989986
irrelevant_test: [TestsInFile(test_file=irrelevant_test, test_function="test_other", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)],
@@ -993,16 +990,15 @@ def test_star():
993990
target_functions = {"target_function"}
994991
filtered_map, import_results = filter_test_files_by_imports(file_to_test_map, target_functions)
995992

996-
# Should filter out irrelevant_test but keep relevant_test and star_test
997-
assert len(filtered_map) == 2
993+
# Should filter out irrelevant_test
994+
assert len(filtered_map) == 1
998995
assert relevant_test in filtered_map
999-
assert star_test in filtered_map
1000996
assert irrelevant_test not in filtered_map
1001997

1002998
# Check import analysis results
1003999
assert "target_function" in import_results[relevant_test]
10041000
assert len(import_results[irrelevant_test]) == 0
1005-
assert len(import_results[star_test]) == 0 # Star import doesn't identify specific functions
1001+
assert len(import_results[star_test]) == 0
10061002

10071003

10081004
def test_filter_test_files_no_target_functions():
@@ -1066,18 +1062,17 @@ def test_other():
10661062
tests_project_rootdir=tmpdir.parent,
10671063
)
10681064

1069-
# Test without filtering
10701065
all_tests, _ = discover_unit_tests(test_config)
1071-
assert len(all_tests) == 2 # Should find both functions
1072-
1073-
# Test with filtering - create mock FunctionToOptimize objects
1074-
from unittest.mock import Mock
1075-
mock_function = Mock()
1076-
mock_function.qualified_name_with_modules_from_root.return_value = "mycode.target_function"
1077-
mock_function.function_name = "target_function"
1078-
mock_function.parents = [] # No parent classes
1066+
assert len(all_tests) == 2
10791067

1080-
filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={code_file: [mock_function]})
1068+
1069+
fto = FunctionToOptimize(
1070+
function_name="target_function",
1071+
file_path=code_file,
1072+
parents=[],
1073+
)
1074+
1075+
filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={code_file: [fto]})
10811076
assert len(filtered_tests) >= 1
10821077
assert "mycode.target_function" in filtered_tests
10831078

@@ -1146,7 +1141,6 @@ def test_aliased():
11461141

11471142

11481143
def test_analyze_imports_underscore_function_names():
1149-
"""Test handling of function names with underscores in heuristic matching."""
11501144
with tempfile.TemporaryDirectory() as tmpdirname:
11511145
test_file = Path(tmpdirname) / "test_example.py"
11521146
test_content = """
@@ -1157,12 +1151,11 @@ def test_bubble():
11571151
"""
11581152
test_file.write_text(test_content)
11591153

1160-
target_functions = {"bubble_sort"} # Function name parts match module
1154+
target_functions = {"bubble_sort"}
11611155
should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions)
11621156

1163-
assert should_process is True
1164-
assert "bubble_sort" in found_functions
1165-
1157+
assert should_process is False
1158+
assert "bubble_sort" not in found_functions
11661159

11671160
def test_discover_unit_tests_filtering_different_modules():
11681161
"""Test import filtering with test files from completely different modules."""

0 commit comments

Comments
 (0)