Skip to content

Commit f305633

Browse files
author
Codeflash Bot
committed
potential fix
1 parent cb6df90 commit f305633

File tree

1 file changed

+56
-22
lines changed

1 file changed

+56
-22
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -278,19 +278,8 @@ def visit_Assign(self, node: ast.Assign) -> None:
278278
if isinstance(target, ast.Name):
279279
instance_mapping[target.id] = original_class
280280

281-
# Replace self.generic_visit(node) with an optimized, inlined version that
282-
# stops traversal when self.found_any_target_function is set.
283-
# This eliminates interpretive overhead of super() and function call.
284-
stack = [node]
285-
append = stack.append
286-
pop = stack.pop
287-
# found_flag = self.found_any_target_function
288-
while stack:
289-
current_node = pop()
290-
if self.found_any_target_function:
291-
break
292-
for child in ast.iter_child_nodes(current_node):
293-
append(child)
281+
# Continue visiting child nodes
282+
self.generic_visit(node)
294283

295284
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
296285
"""Handle 'from module import name' statements."""
@@ -338,11 +327,11 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
338327
if "." in target_func:
339328
class_name, method_name = target_func.split(".", 1)
340329
if aname == class_name and not alias.asname:
341-
# If an alias is used, don't match conservatively
342-
# The actual method usage should be detected in visit_Attribute
343330
self.found_any_target_function = True
344331
self.found_qualified_name = target_func
345332
return
333+
# If an alias is used, track it for later method access detection
334+
# The actual method usage will be detected in visit_Attribute
346335

347336
prefix = qname + "."
348337
# Only bother if one of the targets startswith the prefix-root
@@ -383,6 +372,14 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
383372
self.found_any_target_function = True
384373
self.found_qualified_name = self._class_method_to_target[(original_name, node_attr)]
385374
return
375+
# Also check if the imported name itself (without resolving alias) matches
376+
# This handles cases where the class itself is the target
377+
if imported_name in roots_possible:
378+
self.found_any_target_function = True
379+
self.found_qualified_name = self._class_method_to_target.get(
380+
(imported_name, node_attr), f"{imported_name}.{node_attr}"
381+
)
382+
return
386383

387384
# Check if this is accessing a method on an instance variable
388385
if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping:
@@ -401,6 +398,19 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
401398

402399
self.generic_visit(node)
403400

401+
def visit_Call(self, node: ast.Call) -> None:
402+
"""Handle function calls, particularly __import__."""
403+
if self.found_any_target_function:
404+
return
405+
406+
# Check if this is a __import__ call
407+
if isinstance(node.func, ast.Name) and node.func.id == "__import__":
408+
self.has_dynamic_imports = True
409+
# When __import__ is used, any target function could potentially be imported
410+
# Be conservative and assume it might import target functions
411+
412+
self.generic_visit(node)
413+
404414
def visit_Name(self, node: ast.Name) -> None:
405415
"""Handle direct name usage like target_function()."""
406416
if self.found_any_target_function:
@@ -410,6 +420,8 @@ def visit_Name(self, node: ast.Name) -> None:
410420
if node.id == "__import__":
411421
self.has_dynamic_imports = True
412422

423+
# Check if this is a direct usage of a target function name
424+
# This catches cases like: result = target_function()
413425
if node.id in self.function_names_to_find:
414426
self.found_any_target_function = True
415427
self.found_qualified_name = node.id
@@ -444,12 +456,22 @@ def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: s
444456
except (SyntaxError, FileNotFoundError) as e:
445457
logger.debug(f"Failed to analyze imports in {test_file_path}: {e}")
446458
return True
447-
else:
448-
if analyzer.found_any_target_function:
449-
logger.debug(f"Test file {test_file_path} imports target function: {analyzer.found_qualified_name}")
450-
return True
451-
logger.debug(f"Test file {test_file_path} does not import any target functions.")
452-
return False
459+
460+
if analyzer.found_any_target_function:
461+
logger.debug(f"Test file {test_file_path} imports target function: {analyzer.found_qualified_name}")
462+
return True
463+
464+
# Be conservative with dynamic imports - if __import__ is used and a target function
465+
# is referenced, we should process the file
466+
if analyzer.has_dynamic_imports:
467+
# Check if any target function name appears as a string literal or direct usage
468+
for target_func in target_functions:
469+
if target_func in source_code:
470+
logger.debug(f"Test file {test_file_path} has dynamic imports and references {target_func}")
471+
return True
472+
473+
logger.debug(f"Test file {test_file_path} does not import any target functions.")
474+
return False
453475

454476

455477
def filter_test_files_by_imports(
@@ -663,7 +685,19 @@ def process_test_files(
663685
function_to_test_map = defaultdict(set)
664686
num_discovered_tests = 0
665687
num_discovered_replay_tests = 0
666-
jedi_project = jedi.Project(path=project_root_path)
688+
689+
# Set up sys_path for Jedi to resolve imports correctly
690+
import sys
691+
692+
jedi_sys_path = list(sys.path)
693+
# Add project root and its parent to sys_path so modules can be imported
694+
if str(project_root_path) not in jedi_sys_path:
695+
jedi_sys_path.insert(0, str(project_root_path))
696+
parent_path = project_root_path.parent
697+
if str(parent_path) not in jedi_sys_path:
698+
jedi_sys_path.insert(0, str(parent_path))
699+
700+
jedi_project = jedi.Project(path=project_root_path, sys_path=jedi_sys_path)
667701

668702
tests_cache = TestsCache(project_root_path)
669703
logger.info("!lsp|Discovering tests and processing unit tests")

0 commit comments

Comments
 (0)