11from __future__ import annotations
22
3+ import hashlib
34import os
45import pickle
56import re
7+ import sqlite3
68import subprocess
79import unittest
810from collections import defaultdict
1517
1618from codeflash .cli_cmds .console import console , logger , test_files_progress_bar
1719from codeflash .code_utils .code_utils import get_run_tmp_file , module_name_from_file_path
18- from codeflash .code_utils .compat import SAFE_SYS_EXECUTABLE
20+ from codeflash .code_utils .compat import SAFE_SYS_EXECUTABLE , codeflash_cache_db
1921from codeflash .models .models import CodePosition , FunctionCalledInTest , TestsInFile , TestType
2022
2123if TYPE_CHECKING :
@@ -37,13 +39,101 @@ class TestFunction:
3739FUNCTION_NAME_REGEX = re .compile (r"([^.]+)\.([a-zA-Z0-9_]+)$" )
3840
3941
42+ class TestsCache :
43+ def __init__ (self ) -> None :
44+ self .connection = sqlite3 .connect (codeflash_cache_db )
45+ self .cur = self .connection .cursor ()
46+
47+ self .cur .execute (
48+ """
49+ CREATE TABLE IF NOT EXISTS discovered_tests(
50+ file_path TEXT,
51+ file_hash TEXT,
52+ qualified_name_with_modules_from_root TEXT,
53+ function_name TEXT,
54+ test_class TEXT,
55+ test_function TEXT,
56+ test_type TEXT,
57+ line_number INTEGER,
58+ col_number INTEGER
59+ )
60+ """
61+ )
62+ self .cur .execute (
63+ """
64+ CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
65+ ON discovered_tests (file_path, file_hash)
66+ """
67+ )
68+ self ._memory_cache = {}
69+
70+ def insert_test (
71+ self ,
72+ file_path : str ,
73+ file_hash : str ,
74+ qualified_name_with_modules_from_root : str ,
75+ function_name : str ,
76+ test_class : str ,
77+ test_function : str ,
78+ test_type : TestType ,
79+ line_number : int ,
80+ col_number : int ,
81+ ) -> None :
82+ self .cur .execute ("DELETE FROM discovered_tests WHERE file_path = ?" , (file_path ,))
83+ test_type_value = test_type .value if hasattr (test_type , "value" ) else test_type
84+ self .cur .execute (
85+ "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
86+ (
87+ file_path ,
88+ file_hash ,
89+ qualified_name_with_modules_from_root ,
90+ function_name ,
91+ test_class ,
92+ test_function ,
93+ test_type_value ,
94+ line_number ,
95+ col_number ,
96+ ),
97+ )
98+ self .connection .commit ()
99+
100+ def get_tests_for_file (self , file_path : str , file_hash : str ) -> list [FunctionCalledInTest ]:
101+ cache_key = (file_path , file_hash )
102+ if cache_key in self ._memory_cache :
103+ return self ._memory_cache [cache_key ]
104+ self .cur .execute ("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?" , (file_path , file_hash ))
105+ result = [
106+ FunctionCalledInTest (
107+ tests_in_file = TestsInFile (
108+ test_file = Path (row [0 ]), test_class = row [4 ], test_function = row [5 ], test_type = TestType (int (row [6 ]))
109+ ),
110+ position = CodePosition (line_no = row [7 ], col_no = row [8 ]),
111+ )
112+ for row in self .cur .fetchall ()
113+ ]
114+ self ._memory_cache [cache_key ] = result
115+ return result
116+
117+ @staticmethod
118+ def compute_file_hash (path : str ) -> str :
119+ h = hashlib .sha256 (usedforsecurity = False )
120+ with Path (path ).open ("rb" ) as f :
121+ while True :
122+ chunk = f .read (8192 )
123+ if not chunk :
124+ break
125+ h .update (chunk )
126+ return h .hexdigest ()
127+
128+ def close (self ) -> None :
129+ self .cur .close ()
130+ self .connection .close ()
131+
132+
40133def discover_unit_tests (
41134 cfg : TestConfig , discover_only_these_tests : list [Path ] | None = None
42135) -> dict [str , list [FunctionCalledInTest ]]:
43- framework_strategies : dict [str , Callable ] = {
44- "pytest" : discover_tests_pytest ,
45- "unittest" : discover_tests_unittest ,
46- }
136+ framework_strategies : dict [str , Callable ] = {"pytest" : discover_tests_pytest , "unittest" : discover_tests_unittest }
47137 strategy = framework_strategies .get (cfg .test_framework , None )
48138 if not strategy :
49139 error_message = f"Unsupported test framework: { cfg .test_framework } "
@@ -54,7 +144,7 @@ def discover_unit_tests(
54144
55145def discover_tests_pytest (
56146 cfg : TestConfig , discover_only_these_tests : list [Path ] | None = None
57- ) -> dict [str , list [FunctionCalledInTest ]]:
147+ ) -> dict [Path , list [FunctionCalledInTest ]]:
58148 tests_root = cfg .tests_root
59149 project_root = cfg .project_root_path
60150
@@ -91,17 +181,15 @@ def discover_tests_pytest(
91181 )
92182
93183 elif 0 <= exitcode <= 5 :
94- logger .warning (
95- f"Failed to collect tests. Pytest Exit code: { exitcode } ={ ExitCode (exitcode ).name } "
96- )
184+ logger .warning (f"Failed to collect tests. Pytest Exit code: { exitcode } ={ ExitCode (exitcode ).name } " )
97185 else :
98186 logger .warning (f"Failed to collect tests. Pytest Exit code: { exitcode } " )
99187 console .rule ()
100188 else :
101189 logger .debug (f"Pytest collection exit code: { exitcode } " )
102190 if pytest_rootdir is not None :
103191 cfg .tests_project_rootdir = Path (pytest_rootdir )
104- file_to_test_map = defaultdict (list )
192+ file_to_test_map : dict [ Path , list [ FunctionCalledInTest ]] = defaultdict (list )
105193 for test in tests :
106194 if "__replay_test" in test ["test_file" ]:
107195 test_type = TestType .REPLAY_TEST
@@ -116,10 +204,7 @@ def discover_tests_pytest(
116204 test_function = test ["test_function" ],
117205 test_type = test_type ,
118206 )
119- if (
120- discover_only_these_tests
121- and test_obj .test_file not in discover_only_these_tests
122- ):
207+ if discover_only_these_tests and test_obj .test_file not in discover_only_these_tests :
123208 continue
124209 file_to_test_map [test_obj .test_file ].append (test_obj )
125210 # Within these test files, find the project functions they are referring to and return their names/locations
@@ -128,7 +213,7 @@ def discover_tests_pytest(
128213
129214def discover_tests_unittest (
130215 cfg : TestConfig , discover_only_these_tests : list [str ] | None = None
131- ) -> dict [str , list [FunctionCalledInTest ]]:
216+ ) -> dict [Path , list [FunctionCalledInTest ]]:
132217 tests_root : Path = cfg .tests_root
133218 loader : unittest .TestLoader = unittest .TestLoader ()
134219 tests : unittest .TestSuite = loader .discover (str (tests_root ))
@@ -144,8 +229,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
144229 _test_module_path = Path (_test_module .replace ("." , os .sep )).with_suffix (".py" )
145230 _test_module_path = tests_root / _test_module_path
146231 if not _test_module_path .exists () or (
147- discover_only_these_tests
148- and str (_test_module_path ) not in discover_only_these_tests
232+ discover_only_these_tests and str (_test_module_path ) not in discover_only_these_tests
149233 ):
150234 return None
151235 if "__replay_test" in str (_test_module_path ):
@@ -172,9 +256,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
172256 if not hasattr (test , "_testMethodName" ) and hasattr (test , "_tests" ):
173257 for test_2 in test ._tests :
174258 if not hasattr (test_2 , "_testMethodName" ):
175- logger .warning (
176- f"Didn't find tests for { test_2 } "
177- ) # it goes deeper?
259+ logger .warning (f"Didn't find tests for { test_2 } " ) # it goes deeper?
178260 continue
179261 details = get_test_details (test_2 )
180262 if details is not None :
@@ -195,19 +277,35 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
195277
196278
197279def process_test_files (
198- file_to_test_map : dict [str , list [TestsInFile ]], cfg : TestConfig
280+ file_to_test_map : dict [Path , list [TestsInFile ]], cfg : TestConfig
199281) -> dict [str , list [FunctionCalledInTest ]]:
200282 project_root_path = cfg .project_root_path
201283 test_framework = cfg .test_framework
284+
202285 function_to_test_map = defaultdict (set )
203286 jedi_project = jedi .Project (path = project_root_path )
204287 goto_cache = {}
288+ tests_cache = TestsCache ()
205289
206- with test_files_progress_bar (
207- total = len ( file_to_test_map ), description = "Processing test files"
208- ) as ( progress , task_id ):
209-
290+ with test_files_progress_bar (total = len ( file_to_test_map ), description = "Processing test files" ) as (
291+ progress ,
292+ task_id ,
293+ ):
210294 for test_file , functions in file_to_test_map .items ():
295+ file_hash = TestsCache .compute_file_hash (test_file )
296+ cached_tests = tests_cache .get_tests_for_file (str (test_file ), file_hash )
297+ if cached_tests :
298+ self_cur = tests_cache .cur
299+ self_cur .execute (
300+ "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?" ,
301+ (str (test_file ), file_hash ),
302+ )
303+ qualified_names = [row [0 ] for row in self_cur .fetchall ()]
304+ for cached , qualified_name in zip (cached_tests , qualified_names ):
305+ function_to_test_map [qualified_name ].add (cached )
306+ progress .advance (task_id )
307+ continue
308+
211309 try :
212310 script = jedi .Script (path = test_file , project = jedi_project )
213311 test_functions = set ()
@@ -216,12 +314,8 @@ def process_test_files(
216314 all_defs = script .get_names (all_scopes = True , definitions = True )
217315 all_names_top = script .get_names (all_scopes = True )
218316
219- top_level_functions = {
220- name .name : name for name in all_names_top if name .type == "function"
221- }
222- top_level_classes = {
223- name .name : name for name in all_names_top if name .type == "class"
224- }
317+ top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
318+ top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
225319 except Exception as e :
226320 logger .debug (f"Failed to get jedi script for { test_file } : { e } " )
227321 progress .advance (task_id )
@@ -230,36 +324,18 @@ def process_test_files(
230324 if test_framework == "pytest" :
231325 for function in functions :
232326 if "[" in function .test_function :
233- function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (
234- function .test_function
235- )[0 ]
236- parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (
237- function .test_function
238- )[1 ]
327+ function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (function .test_function )[0 ]
328+ parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (function .test_function )[1 ]
239329 if function_name in top_level_functions :
240330 test_functions .add (
241- TestFunction (
242- function_name ,
243- function .test_class ,
244- parameters ,
245- function .test_type ,
246- )
331+ TestFunction (function_name , function .test_class , parameters , function .test_type )
247332 )
248333 elif function .test_function in top_level_functions :
249334 test_functions .add (
250- TestFunction (
251- function .test_function ,
252- function .test_class ,
253- None ,
254- function .test_type ,
255- )
256- )
257- elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX .match (
258- function .test_function
259- ):
260- base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX .sub (
261- "" , function .test_function
335+ TestFunction (function .test_function , function .test_class , None , function .test_type )
262336 )
337+ elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX .match (function .test_function ):
338+ base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX .sub ("" , function .test_function )
263339 if base_name in top_level_functions :
264340 test_functions .add (
265341 TestFunction (
@@ -283,9 +359,7 @@ def process_test_files(
283359 and f".{ matched_name } ." in def_name .full_name
284360 ):
285361 for function in functions_to_search :
286- (is_parameterized , new_function , parameters ) = (
287- discover_parameters_unittest (function )
288- )
362+ (is_parameterized , new_function , parameters ) = discover_parameters_unittest (function )
289363
290364 if is_parameterized and new_function == def_name .name :
291365 test_functions .add (
@@ -329,9 +403,7 @@ def process_test_files(
329403 if cache_key in goto_cache :
330404 definition = goto_cache [cache_key ]
331405 else :
332- definition = name .goto (
333- follow_imports = True , follow_builtin_imports = False
334- )
406+ definition = name .goto (follow_imports = True , follow_builtin_imports = False )
335407 goto_cache [cache_key ] = definition
336408 except Exception as e :
337409 logger .debug (str (e ))
@@ -358,11 +430,23 @@ def process_test_files(
358430 if test_framework == "unittest" :
359431 scope_test_function += "_" + scope_parameters
360432
361- full_name_without_module_prefix = definition [
362- 0
363- ]. full_name . replace ( definition [ 0 ]. module_name + "." , "" , 1 )
433+ full_name_without_module_prefix = definition [0 ]. full_name . replace (
434+ definition [ 0 ]. module_name + "." , "" , 1
435+ )
364436 qualified_name_with_modules_from_root = f"{ module_name_from_file_path (definition [0 ].module_path , project_root_path )} .{ full_name_without_module_prefix } "
365437
438+ tests_cache .insert_test (
439+ file_path = str (test_file ),
440+ file_hash = file_hash ,
441+ qualified_name_with_modules_from_root = qualified_name_with_modules_from_root ,
442+ function_name = scope ,
443+ test_class = scope_test_class ,
444+ test_function = scope_test_function ,
445+ test_type = test_type ,
446+ line_number = name .line ,
447+ col_number = name .column ,
448+ )
449+
366450 function_to_test_map [qualified_name_with_modules_from_root ].add (
367451 FunctionCalledInTest (
368452 tests_in_file = TestsInFile (
@@ -371,12 +455,11 @@ def process_test_files(
371455 test_function = scope_test_function ,
372456 test_type = test_type ,
373457 ),
374- position = CodePosition (
375- line_no = name .line , col_no = name .column
376- ),
458+ position = CodePosition (line_no = name .line , col_no = name .column ),
377459 )
378460 )
379461
380462 progress .advance (task_id )
381463
464+ tests_cache .close ()
382465 return {function : list (tests ) for function , tests in function_to_test_map .items ()}
0 commit comments