@@ -304,22 +304,29 @@ def filter_test_files_by_imports(
304304def discover_unit_tests (
305305 cfg : TestConfig ,
306306 discover_only_these_tests : list [Path ] | None = None ,
307- functions_to_optimize : list [FunctionToOptimize ] | None = None ,
307+ file_to_funcs_to_optimize : dict [ Path , list [FunctionToOptimize ] ] | None = None ,
308308) -> dict [str , list [FunctionCalledInTest ]]:
309309 framework_strategies : dict [str , Callable ] = {"pytest" : discover_tests_pytest , "unittest" : discover_tests_unittest }
310310 strategy = framework_strategies .get (cfg .test_framework , None )
311311 if not strategy :
312312 error_message = f"Unsupported test framework: { cfg .test_framework } "
313313 raise ValueError (error_message )
314314
315+ # Extract all functions to optimize for import filtering
316+ functions_to_optimize = None
317+ if file_to_funcs_to_optimize :
318+ functions_to_optimize = [
319+ func for funcs_list in file_to_funcs_to_optimize .values () for func in funcs_list
320+ ]
321+
315322 return strategy (cfg , discover_only_these_tests , functions_to_optimize )
316323
317324
318325def discover_tests_pytest (
319326 cfg : TestConfig ,
320327 discover_only_these_tests : list [Path ] | None = None ,
321328 functions_to_optimize : list [FunctionToOptimize ] | None = None ,
322- ) -> dict [Path , list [FunctionCalledInTest ]]:
329+ ) -> dict [str , list [FunctionCalledInTest ]]:
323330 tests_root = cfg .tests_root
324331 project_root = cfg .project_root_path
325332
@@ -395,7 +402,7 @@ def discover_tests_unittest(
395402 cfg : TestConfig ,
396403 discover_only_these_tests : list [str ] | None = None ,
397404 functions_to_optimize : list [FunctionToOptimize ] | None = None ,
398- ) -> dict [Path , list [FunctionCalledInTest ]]:
405+ ) -> dict [str , list [FunctionCalledInTest ]]:
399406 tests_root : Path = cfg .tests_root
400407 loader : unittest .TestLoader = unittest .TestLoader ()
401408 tests : unittest .TestSuite = loader .discover (str (tests_root ))
0 commit comments