22from __future__ import annotations
33
44import hashlib
5+ import multiprocessing
56import os
67import pickle
78import re
1516
1617import pytest
1718from pydantic .dataclasses import dataclass
18- from rich .panel import Panel
19- from rich .text import Text
2019
2120from codeflash .cli_cmds .console import console , logger , test_files_progress_bar
22- from codeflash .code_utils .code_utils import (
23- ImportErrorPattern ,
24- custom_addopts ,
25- get_run_tmp_file ,
26- module_name_from_file_path ,
27- )
21+ from codeflash .code_utils .code_utils import custom_addopts , get_run_tmp_file , module_name_from_file_path
2822from codeflash .code_utils .compat import SAFE_SYS_EXECUTABLE , codeflash_cache_db
2923from codeflash .models .models import CodePosition , FunctionCalledInTest , TestsInFile , TestType
3024
@@ -139,7 +133,7 @@ def close(self) -> None:
139133
140134def discover_unit_tests (
141135 cfg : TestConfig , discover_only_these_tests : list [Path ] | None = None
142- ) -> dict [str , list [FunctionCalledInTest ]]:
136+ ) -> tuple [ dict [str , list [FunctionCalledInTest ]], int ]:
143137 framework_strategies : dict [str , Callable ] = {"pytest" : discover_tests_pytest , "unittest" : discover_tests_unittest }
144138 strategy = framework_strategies .get (cfg .test_framework , None )
145139 if not strategy :
@@ -151,7 +145,7 @@ def discover_unit_tests(
151145
152146def discover_tests_pytest (
153147 cfg : TestConfig , discover_only_these_tests : list [Path ] | None = None
154- ) -> dict [Path , list [FunctionCalledInTest ]]:
148+ ) -> tuple [ dict [str , list [FunctionCalledInTest ]], int ]:
155149 tests_root = cfg .tests_root
156150 project_root = cfg .project_root_path
157151
@@ -187,10 +181,6 @@ def discover_tests_pytest(
187181 logger .warning (
188182 f"Failed to collect tests. Pytest Exit code: { exitcode } ={ pytest .ExitCode (exitcode ).name } \n { error_section } "
189183 )
190- if "ModuleNotFoundError" in result .stdout :
191- match = ImportErrorPattern .search (result .stdout ).group ()
192- panel = Panel (Text .from_markup (f"⚠️ { match } " , style = "bold red" ), expand = False )
193- console .print (panel )
194184
195185 elif 0 <= exitcode <= 5 :
196186 logger .warning (f"Failed to collect tests. Pytest Exit code: { exitcode } ={ pytest .ExitCode (exitcode ).name } " )
@@ -225,7 +215,7 @@ def discover_tests_pytest(
225215
226216def discover_tests_unittest (
227217 cfg : TestConfig , discover_only_these_tests : list [str ] | None = None
228- ) -> dict [Path , list [FunctionCalledInTest ]]:
218+ ) -> tuple [ dict [str , list [FunctionCalledInTest ]], int ]:
229219 tests_root : Path = cfg .tests_root
230220 loader : unittest .TestLoader = unittest .TestLoader ()
231221 tests : unittest .TestSuite = loader .discover (str (tests_root ))
@@ -290,27 +280,39 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
290280
291281def _process_single_test_file (
292282 test_file : Path , functions : list [TestsInFile ], project_root_path : Path , test_framework : str
293- ) -> tuple [str , list [tuple [str , FunctionCalledInTest ]]]:
283+ ) -> tuple [str , list [tuple [str , FunctionCalledInTest ]], int , list [ dict ] ]:
294284 import jedi
295285
296286 jedi_project = jedi .Project (path = project_root_path )
297287 goto_cache = {}
298288 results = []
289+ cache_entries = []
299290
300291 try :
301292 script = jedi .Script (path = test_file , project = jedi_project )
302293 test_functions = set ()
303294
304295 all_names = script .get_names (all_scopes = True , references = True )
305- all_defs = script .get_names (all_scopes = True , definitions = True )
306- all_names_top = script .get_names (all_scopes = True )
307-
308- top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
309- top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
296+ top_level_functions = {}
297+ top_level_classes = {}
298+ all_defs = []
299+ reference_names = []
300+
301+ for name in all_names :
302+ if name .type == "function" :
303+ top_level_functions [name .name ] = name
304+ if hasattr (name , "full_name" ) and name .full_name :
305+ all_defs .append (name )
306+ elif name .type == "class" :
307+ top_level_classes [name .name ] = name
308+
309+ if name .full_name is not None :
310+ m = FUNCTION_NAME_REGEX .search (name .full_name )
311+ if m :
312+ reference_names .append ((name , m .group (1 )))
310313 except Exception as e :
311314 logger .debug (f"Failed to get jedi script for { test_file } : { e } " )
312- # tests_cache.close()
313- return str (test_file ), results
315+ return str (test_file ), results , len (results ), cache_entries
314316
315317 if test_framework == "pytest" :
316318 for function in functions :
@@ -340,11 +342,8 @@ def _process_single_test_file(
340342 matching_names = test_suites & top_level_classes .keys ()
341343 for matched_name in matching_names :
342344 for def_name in all_defs :
343- if (
344- def_name .type == "function"
345- and def_name .full_name is not None
346- and f".{ matched_name } ." in def_name .full_name
347- ):
345+ # all_defs already contains only functions, no need to check type
346+ if def_name .full_name is not None and f".{ matched_name } ." in def_name .full_name :
348347 for function in functions_to_search :
349348 (is_parameterized , new_function , parameters ) = discover_parameters_unittest (function )
350349
@@ -374,14 +373,7 @@ def _process_single_test_file(
374373 for i , func_name in enumerate (test_functions_raw ):
375374 test_functions_by_name [func_name ].append (i )
376375
377- for name in all_names :
378- if name .full_name is None :
379- continue
380- m = FUNCTION_NAME_REGEX .search (name .full_name )
381- if not m :
382- continue
383-
384- scope = m .group (1 )
376+ for name , scope in reference_names :
385377 if scope not in test_functions_by_name :
386378 continue
387379
@@ -432,28 +424,73 @@ def _process_single_test_file(
432424 )
433425 results .append ((qualified_name_with_modules_from_root , function_called_in_test ))
434426
435- return str (test_file ), results
427+ cache_entries .append (
428+ {
429+ "qualified_name_with_modules_from_root" : qualified_name_with_modules_from_root ,
430+ "function_name" : scope ,
431+ "test_class" : scope_test_class ,
432+ "test_function" : scope_test_function ,
433+ "test_type" : test_type ,
434+ "line_number" : name .line ,
435+ "col_number" : name .column ,
436+ }
437+ )
438+
439+ return str (test_file ), results , len (results ), cache_entries
436440
437441
438442def process_test_files (
439443 file_to_test_map : dict [Path , list [TestsInFile ]], cfg : TestConfig
440- ) -> dict [str , list [FunctionCalledInTest ]]:
444+ ) -> tuple [ dict [str , list [FunctionCalledInTest ]], int ]:
441445 project_root_path = cfg .project_root_path
442446 test_framework = cfg .test_framework
443447 function_to_test_map = defaultdict (set )
448+ total_count = 0
444449
445- import multiprocessing
450+ tests_cache = TestsCache ()
446451
447- max_workers = min (len (file_to_test_map ), multiprocessing .cpu_count ())
448- max_workers = max (1 , max_workers )
452+ max_workers = min (len (file_to_test_map ) or 1 , multiprocessing .cpu_count ())
449453
450454 with test_files_progress_bar (total = len (file_to_test_map ), description = "Processing test files" ) as (
451455 progress ,
452456 task_id ,
453457 ):
454- if len (file_to_test_map ) == 1 or max_workers == 1 :
455- for test_file , functions in file_to_test_map .items ():
456- _ , results = _process_single_test_file (test_file , functions , project_root_path , test_framework )
458+ cached_files = {}
459+ uncached_files = {}
460+
461+ for test_file , functions in file_to_test_map .items ():
462+ file_hash = TestsCache .compute_file_hash (str (test_file ))
463+ cached_tests = tests_cache .get_tests_for_file (str (test_file ), file_hash )
464+
465+ if cached_tests :
466+ cached_files [test_file ] = (functions , cached_tests , file_hash )
467+ else :
468+ uncached_files [test_file ] = functions
469+
470+ # Process cached files first
471+ for test_file , (_functions , cached_tests , file_hash ) in cached_files .items ():
472+ cur = tests_cache .cur
473+ cur .execute (
474+ "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?" ,
475+ (str (test_file ), file_hash ),
476+ )
477+ qualified_names = [row [0 ] for row in cur .fetchall ()]
478+ for cached_test , qualified_name in zip (cached_tests , qualified_names ):
479+ function_to_test_map [qualified_name ].add (cached_test )
480+ total_count += len (cached_tests )
481+ progress .advance (task_id )
482+
483+ if len (uncached_files ) == 1 or max_workers == 1 :
484+ for test_file , functions in uncached_files .items ():
485+ _ , results , count , cache_entries = _process_single_test_file (
486+ test_file , functions , project_root_path , test_framework
487+ )
488+ total_count += count
489+
490+ file_hash = TestsCache .compute_file_hash (str (test_file ))
491+ for cache_entry in cache_entries :
492+ tests_cache .insert_test (file_path = str (test_file ), file_hash = file_hash , ** cache_entry )
493+
457494 for qualified_name , function_called in results :
458495 function_to_test_map [qualified_name ].add (function_called )
459496 progress .advance (task_id )
@@ -463,12 +500,19 @@ def process_test_files(
463500 executor .submit (
464501 _process_single_test_file , test_file , functions , project_root_path , test_framework
465502 ): test_file
466- for test_file , functions in file_to_test_map .items ()
503+ for test_file , functions in uncached_files .items ()
467504 }
468505
469506 for future in as_completed (future_to_file ):
470507 try :
471- _ , results = future .result ()
508+ _ , results , count , cache_entries = future .result ()
509+ total_count += count
510+
511+ test_file = future_to_file [future ]
512+ file_hash = TestsCache .compute_file_hash (str (test_file ))
513+ for cache_entry in cache_entries :
514+ tests_cache .insert_test (file_path = str (test_file ), file_hash = file_hash , ** cache_entry )
515+
472516 for qualified_name , function_called in results :
473517 function_to_test_map [qualified_name ].add (function_called )
474518 progress .advance (task_id )
@@ -477,4 +521,6 @@ def process_test_files(
477521 logger .error (f"Error processing test file { test_file } : { e } " )
478522 progress .advance (task_id )
479523
480- return {function : list (tests ) for function , tests in function_to_test_map .items ()}
524+ tests_cache .close ()
525+ function_to_tests_dict = {function : list (tests ) for function , tests in function_to_test_map .items ()}
526+ return function_to_tests_dict , total_count
0 commit comments