@@ -136,6 +136,8 @@ def __init__(self, project_root_path: Path) -> None:
136136 )
137137
138138 self .memory_cache = {}
139+ self .pending_rows : list [tuple [str , str , str , str , str , str , int | TestType , int , int ]] = []
140+ self .writes_enabled = True
139141
140142 def insert_test (
141143 self ,
@@ -150,10 +152,8 @@ def insert_test(
150152 col_number : int ,
151153 ) -> None :
152154 test_type_value = test_type .value if hasattr (test_type , "value" ) else test_type
153- self .cur .execute (
154- "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
155+ self .pending_rows .append (
155156 (
156- self .project_root_path ,
157157 file_path ,
158158 file_hash ,
159159 qualified_name_with_modules_from_root ,
@@ -163,9 +163,26 @@ def insert_test(
163163 test_type_value ,
164164 line_number ,
165165 col_number ,
166- ),
166+ )
167167 )
168- self .connection .commit ()
168+
169+ def flush (self ) -> None :
170+ if not self .pending_rows :
171+ return
172+ if not self .writes_enabled :
173+ self .pending_rows .clear ()
174+ return
175+ try :
176+ self .cur .executemany (
177+ "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
178+ [(self .project_root_path , * row ) for row in self .pending_rows ],
179+ )
180+ self .connection .commit ()
181+ except sqlite3 .OperationalError as e :
182+ logger .debug (f"Failed to persist discovered test cache, disabling cache writes: { e } " )
183+ self .writes_enabled = False
184+ finally :
185+ self .pending_rows .clear ()
169186
170187 def get_function_to_test_map_for_file (
171188 self , file_path : str , file_hash : str
@@ -212,6 +229,7 @@ def compute_file_hash(path: Path) -> str:
212229 return h .hexdigest ()
213230
214231 def close (self ) -> None :
232+ self .flush ()
215233 self .cur .close ()
216234 self .connection .close ()
217235
@@ -849,6 +867,10 @@ def process_test_files(
849867 function_to_test_map = defaultdict (set )
850868 num_discovered_tests = 0
851869 num_discovered_replay_tests = 0
870+ functions_to_optimize_by_name : dict [str , list [FunctionToOptimize ]] = defaultdict (list )
871+ if functions_to_optimize :
872+ for function_to_optimize in functions_to_optimize :
873+ functions_to_optimize_by_name [function_to_optimize .function_name ].append (function_to_optimize )
852874
853875 # Set up sys_path for Jedi to resolve imports correctly
854876 import sys
@@ -891,8 +913,8 @@ def process_test_files(
891913 test_functions = set ()
892914
893915 all_names = script .get_names (all_scopes = True , references = True )
894- all_defs = script .get_names (all_scopes = True , definitions = True )
895916 all_names_top = script .get_names (all_scopes = True )
917+ all_defs = [name for name in all_names if name .is_definition ()]
896918
897919 top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
898920 top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
@@ -967,10 +989,9 @@ def process_test_files(
967989
968990 test_function_names_set = set (test_functions_by_name .keys ())
969991 relevant_names = []
970-
971- names_with_full_name = [name for name in all_names if name .full_name is not None ]
972-
973- for name in names_with_full_name :
992+ for name in all_names :
993+ if name .full_name is None :
994+ continue
974995 match = FUNCTION_NAME_REGEX .search (name .full_name )
975996 if match and match .group (1 ) in test_function_names_set :
976997 relevant_names .append ((name , match .group (1 )))
@@ -985,56 +1006,49 @@ def process_test_files(
9851006 if not definition or definition [0 ].type != "function" :
9861007 # Fallback: Try to match against functions_to_optimize when Jedi can't resolve
9871008 # This handles cases where Jedi fails with pytest fixtures
988- if functions_to_optimize and name .name :
989- for func_to_opt in functions_to_optimize :
990- # Check if this unresolved name matches a function we're looking for
991- if func_to_opt .function_name == name .name :
992- # Check if the test file imports the class/module containing this function
993- qualified_name_with_modules = func_to_opt .qualified_name_with_modules_from_root (
994- project_root_path
995- )
1009+ if functions_to_optimize_by_name and name .name :
1010+ for func_to_opt in functions_to_optimize_by_name .get (name .name , []):
1011+ qualified_name_with_modules = func_to_opt .qualified_name_with_modules_from_root (
1012+ project_root_path
1013+ )
9961014
997- # Only add if this test actually tests the function we're optimizing
998- for test_func in test_functions_by_name [scope ]:
999- if test_func .parameters is not None :
1000- if test_framework == "pytest" :
1001- scope_test_function = (
1002- f"{ test_func .function_name } [{ test_func .parameters } ]"
1003- )
1004- else : # unittest
1005- scope_test_function = (
1006- f"{ test_func .function_name } _{ test_func .parameters } "
1007- )
1008- else :
1009- scope_test_function = test_func .function_name
1010-
1011- function_to_test_map [qualified_name_with_modules ].add (
1012- FunctionCalledInTest (
1013- tests_in_file = TestsInFile (
1014- test_file = test_file ,
1015- test_class = test_func .test_class ,
1016- test_function = scope_test_function ,
1017- test_type = test_func .test_type ,
1018- ),
1019- position = CodePosition (line_no = name .line , col_no = name .column ),
1020- )
1021- )
1022- tests_cache .insert_test (
1023- file_path = str (test_file ),
1024- file_hash = file_hash ,
1025- qualified_name_with_modules_from_root = qualified_name_with_modules ,
1026- function_name = scope ,
1027- test_class = test_func .test_class or "" ,
1028- test_function = scope_test_function ,
1029- test_type = test_func .test_type ,
1030- line_number = name .line ,
1031- col_number = name .column ,
1015+ # Only add if this test actually tests the function we're optimizing
1016+ for test_func in test_functions_by_name [scope ]:
1017+ if test_func .parameters is not None :
1018+ if test_framework == "pytest" :
1019+ scope_test_function = f"{ test_func .function_name } [{ test_func .parameters } ]"
1020+ else : # unittest
1021+ scope_test_function = f"{ test_func .function_name } _{ test_func .parameters } "
1022+ else :
1023+ scope_test_function = test_func .function_name
1024+
1025+ function_to_test_map [qualified_name_with_modules ].add (
1026+ FunctionCalledInTest (
1027+ tests_in_file = TestsInFile (
1028+ test_file = test_file ,
1029+ test_class = test_func .test_class ,
1030+ test_function = scope_test_function ,
1031+ test_type = test_func .test_type ,
1032+ ),
1033+ position = CodePosition (line_no = name .line , col_no = name .column ),
10321034 )
1035+ )
1036+ tests_cache .insert_test (
1037+ file_path = str (test_file ),
1038+ file_hash = file_hash ,
1039+ qualified_name_with_modules_from_root = qualified_name_with_modules ,
1040+ function_name = scope ,
1041+ test_class = test_func .test_class or "" ,
1042+ test_function = scope_test_function ,
1043+ test_type = test_func .test_type ,
1044+ line_number = name .line ,
1045+ col_number = name .column ,
1046+ )
10331047
1034- if test_func .test_type == TestType .REPLAY_TEST :
1035- num_discovered_replay_tests += 1
1048+ if test_func .test_type == TestType .REPLAY_TEST :
1049+ num_discovered_replay_tests += 1
10361050
1037- num_discovered_tests += 1
1051+ num_discovered_tests += 1
10381052 continue
10391053 definition_obj = definition [0 ]
10401054 definition_path = str (definition_obj .module_path )
@@ -1090,6 +1104,7 @@ def process_test_files(
10901104 logger .debug (str (e ))
10911105 continue
10921106
1107+ tests_cache .flush ()
10931108 progress .advance (task_id )
10941109
10951110 tests_cache .close ()
0 commit comments