Skip to content

Commit b0bcfb2

Browse files
author
Codeflash Bot
committed
new way
1 parent 553341a commit b0bcfb2

File tree

2 files changed

+55
-11
lines changed

2 files changed

+55
-11
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ def __init__(self, function_names_to_find: set[str]) -> None:
212212
self.wildcard_modules: set[str] = set()
213213
# Track aliases: alias_name -> original_name
214214
self.alias_mapping: dict[str, str] = {}
215+
# Track instances: variable_name -> class_name
216+
self.instance_mapping: dict[str, str] = {}
215217

216218
# Precompute function_names for prefix search
217219
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
@@ -247,6 +249,24 @@ def visit_Import(self, node: ast.Import) -> None:
247249
self.found_qualified_name = target_func
248250
return
249251

252+
def visit_Assign(self, node: ast.Assign) -> None:
253+
"""Track variable assignments, especially class instantiations."""
254+
if self.found_any_target_function:
255+
return
256+
257+
# Check if the assignment is a class instantiation
258+
if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name):
259+
class_name = node.value.func.id
260+
if class_name in self.imported_modules:
261+
# Track all target variables as instances of the imported class
262+
for target in node.targets:
263+
if isinstance(target, ast.Name):
264+
# Map the variable to the actual class name (handling aliases)
265+
original_class = self.alias_mapping.get(class_name, class_name)
266+
self.instance_mapping[target.id] = original_class
267+
268+
self.generic_visit(node)
269+
250270
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
251271
"""Handle 'from module import name' statements."""
252272
if self.found_any_target_function:
@@ -296,17 +316,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
296316
self.found_qualified_name = target_func
297317
return
298318

299-
# Check if any target function is a method of the imported class/module
300-
# e.g., importing Graph and looking for Graph.topologicalSort
301-
# TODO will pick up all tests which have the same class name (could be coming from a different file)
302-
for target_func in fnames:
303-
if "." in target_func:
304-
class_name, method_name = target_func.split(".", 1)
305-
if aname == class_name:
306-
self.found_any_target_function = True
307-
self.found_qualified_name = target_func
308-
return
309-
310319
def visit_Attribute(self, node: ast.Attribute) -> None:
311320
"""Handle attribute access like module.function_name."""
312321
if self.found_any_target_function:
@@ -334,6 +343,17 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
334343
self.found_qualified_name = target_func
335344
return
336345

346+
# Check if this is accessing a method on an instance variable
347+
if isinstance(node.value, ast.Name) and node.value.id in self.instance_mapping:
348+
class_name = self.instance_mapping[node.value.id]
349+
for target_func in self.function_names_to_find:
350+
if "." in target_func:
351+
target_class, method_name = target_func.rsplit(".", 1)
352+
if node.attr == method_name and class_name == target_class:
353+
self.found_any_target_function = True
354+
self.found_qualified_name = target_func
355+
return
356+
337357
# Check if this is accessing a target function through a dynamically imported module
338358
# Only if we've detected dynamic imports are being used
339359
if self.has_dynamic_imports and node.attr in self.function_names_to_find:

tests/test_unit_test_discovery.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,30 @@ def test_target():
13101310

13111311
assert should_process is True
13121312

1313+
def test_analyze_imports_method():
1314+
with tempfile.TemporaryDirectory() as tmpdirname:
1315+
test_file = Path(tmpdirname) / "test_example.py"
1316+
test_content = """
1317+
from code_to_optimize.topological_sort import Graph
1318+
1319+
1320+
def test_topological_sort():
1321+
g = Graph(6)
1322+
g.addEdge(5, 2)
1323+
g.addEdge(5, 0)
1324+
g.addEdge(4, 0)
1325+
g.addEdge(4, 1)
1326+
g.addEdge(2, 3)
1327+
g.addEdge(3, 1)
1328+
1329+
assert g.topologicalSort()[0] == [5, 4, 2, 3, 1, 0]
1330+
"""
1331+
test_file.write_text(test_content)
1332+
1333+
target_functions = {"Graph.topologicalSort"}
1334+
should_process = analyze_imports_in_test_file(test_file, target_functions)
1335+
1336+
assert should_process is True
13131337

13141338
def test_analyze_imports_aliased_class_method_negative():
13151339
with tempfile.TemporaryDirectory() as tmpdirname:

0 commit comments

Comments
 (0)