Skip to content

Commit a936077

Browse files
committed
move things around
1 parent 2d7fd10 commit a936077

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,22 +304,29 @@ def filter_test_files_by_imports(
304304
def discover_unit_tests(
305305
cfg: TestConfig,
306306
discover_only_these_tests: list[Path] | None = None,
307-
functions_to_optimize: list[FunctionToOptimize] | None = None,
307+
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None,
308308
) -> dict[str, list[FunctionCalledInTest]]:
309309
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
310310
strategy = framework_strategies.get(cfg.test_framework, None)
311311
if not strategy:
312312
error_message = f"Unsupported test framework: {cfg.test_framework}"
313313
raise ValueError(error_message)
314314

315+
# Extract all functions to optimize for import filtering
316+
functions_to_optimize = None
317+
if file_to_funcs_to_optimize:
318+
functions_to_optimize = [
319+
func for funcs_list in file_to_funcs_to_optimize.values() for func in funcs_list
320+
]
321+
315322
return strategy(cfg, discover_only_these_tests, functions_to_optimize)
316323

317324

318325
def discover_tests_pytest(
319326
cfg: TestConfig,
320327
discover_only_these_tests: list[Path] | None = None,
321328
functions_to_optimize: list[FunctionToOptimize] | None = None,
322-
) -> dict[Path, list[FunctionCalledInTest]]:
329+
) -> dict[str, list[FunctionCalledInTest]]:
323330
tests_root = cfg.tests_root
324331
project_root = cfg.project_root_path
325332

@@ -395,7 +402,7 @@ def discover_tests_unittest(
395402
cfg: TestConfig,
396403
discover_only_these_tests: list[str] | None = None,
397404
functions_to_optimize: list[FunctionToOptimize] | None = None,
398-
) -> dict[Path, list[FunctionCalledInTest]]:
405+
) -> dict[str, list[FunctionCalledInTest]]:
399406
tests_root: Path = cfg.tests_root
400407
loader: unittest.TestLoader = unittest.TestLoader()
401408
tests: unittest.TestSuite = loader.discover(str(tests_root))

codeflash/optimization/optimizer.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,8 @@ def run(self) -> None:
162162

163163
console.rule()
164164
start_time = time.time()
165-
# Extract all functions to optimize for import filtering
166-
all_functions_to_optimize = [
167-
func for funcs_list in file_to_funcs_to_optimize.values() for func in funcs_list
168-
]
169165
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(
170-
self.test_cfg, functions_to_optimize=all_functions_to_optimize
166+
self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
171167
)
172168
num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
173169
console.rule()

0 commit comments

Comments
 (0)