Skip to content

Commit 0e0f916

Browse files
committed
Update discover_unit_tests.py
1 parent 236f14c commit 0e0f916

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,13 @@ def visit_Call(self, node: ast.Call) -> None:
182182
# __import__("module_name")
183183
if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str):
184184
self.imported_modules.add(node.args[0].value)
185-
elif (isinstance(node.func, ast.Attribute)
186-
and isinstance(node.func.value, ast.Name)
187-
and node.func.value.id == "importlib"
188-
and node.func.attr == "import_module"
189-
and node.args):
185+
elif (
186+
isinstance(node.func, ast.Attribute)
187+
and isinstance(node.func.value, ast.Name)
188+
and node.func.value.id == "importlib"
189+
and node.func.attr == "import_module"
190+
and node.args
191+
):
190192
# importlib.import_module("module_name")
191193
if isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str):
192194
self.imported_modules.add(node.args[0].value)
@@ -263,8 +265,7 @@ def analyze_imports_in_test_file(test_file_path: Path, target_functions: set[str
263265

264266

265267
def filter_test_files_by_imports(
266-
file_to_test_map: dict[Path, list[TestsInFile]],
267-
target_functions: set[str]
268+
file_to_test_map: dict[Path, list[TestsInFile]], target_functions: set[str]
268269
) -> tuple[dict[Path, list[TestsInFile]], dict[Path, set[str]]]:
269270
"""Filter test files based on import analysis to reduce Jedi processing.
270271
@@ -297,7 +298,9 @@ def filter_test_files_by_imports(
297298

298299

299300
def discover_unit_tests(
300-
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None
301+
cfg: TestConfig,
302+
discover_only_these_tests: list[Path] | None = None,
303+
functions_to_optimize: list[FunctionToOptimize] | None = None,
301304
) -> dict[str, list[FunctionCalledInTest]]:
302305
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
303306
strategy = framework_strategies.get(cfg.test_framework, None)
@@ -309,7 +312,9 @@ def discover_unit_tests(
309312

310313

311314
def discover_tests_pytest(
312-
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None
315+
cfg: TestConfig,
316+
discover_only_these_tests: list[Path] | None = None,
317+
functions_to_optimize: list[FunctionToOptimize] | None = None,
313318
) -> dict[Path, list[FunctionCalledInTest]]:
314319
tests_root = cfg.tests_root
315320
project_root = cfg.project_root_path
@@ -383,7 +388,9 @@ def discover_tests_pytest(
383388

384389

385390
def discover_tests_unittest(
386-
cfg: TestConfig, discover_only_these_tests: list[str] | None = None, functions_to_optimize: list[FunctionToOptimize] | None = None
391+
cfg: TestConfig,
392+
discover_only_these_tests: list[str] | None = None,
393+
functions_to_optimize: list[FunctionToOptimize] | None = None,
387394
) -> dict[Path, list[FunctionCalledInTest]]:
388395
tests_root: Path = cfg.tests_root
389396
loader: unittest.TestLoader = unittest.TestLoader()
@@ -448,7 +455,9 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
448455

449456

450457
def process_test_files(
451-
file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig, functions_to_optimize: list[FunctionToOptimize] | None = None
458+
file_to_test_map: dict[Path, list[TestsInFile]],
459+
cfg: TestConfig,
460+
functions_to_optimize: list[FunctionToOptimize] | None = None,
452461
) -> dict[str, list[FunctionCalledInTest]]:
453462
import jedi
454463

@@ -466,7 +475,7 @@ def process_test_files(
466475
# Also add qualified name without module
467476
if func.parents:
468477
target_function_names.add(f"{func.parents[0].name}.{func.function_name}")
469-
478+
470479
logger.debug(f"Target functions for import filtering: {target_function_names}")
471480
file_to_test_map, import_results = filter_test_files_by_imports(file_to_test_map, target_function_names)
472481
logger.debug(f"Import analysis results: {len(import_results)} files analyzed")

0 commit comments

Comments
 (0)