11from __future__ import annotations
22
3+ import hashlib
34import os
45import pickle
56import re
7+ import sqlite3
68import subprocess
79import unittest
810from collections import defaultdict
1517
1618from codeflash .cli_cmds .console import console , logger , test_files_progress_bar
1719from codeflash .code_utils .code_utils import get_run_tmp_file , module_name_from_file_path
18- from codeflash .code_utils .compat import SAFE_SYS_EXECUTABLE
20+ from codeflash .code_utils .compat import SAFE_SYS_EXECUTABLE , codeflash_cache_dir
1921from codeflash .models .models import CodePosition , FunctionCalledInTest , TestsInFile , TestType
2022
2123if TYPE_CHECKING :
@@ -37,13 +39,100 @@ class TestFunction:
3739FUNCTION_NAME_REGEX = re .compile (r"([^.]+)\.([a-zA-Z0-9_]+)$" )
3840
3941
42+ class TestsCache :
43+ def __init__ (self ) -> None :
44+ self .connection = sqlite3 .connect (codeflash_cache_dir / "tests_cache.db" )
45+ self .cur = self .connection .cursor ()
46+
47+ self .cur .execute (
48+ """
49+ CREATE TABLE IF NOT EXISTS discovered_tests(
50+ file_path TEXT,
51+ file_hash TEXT,
52+ qualified_name_with_modules_from_root TEXT,
53+ function_name TEXT,
54+ test_class TEXT,
55+ test_function TEXT,
56+ test_type TEXT,
57+ line_number INTEGER,
58+ col_number INTEGER
59+ )
60+ """
61+ )
62+ self .cur .execute (
63+ """
64+ CREATE INDEX IF NOT EXISTS idx_discovered_tests_file_path_hash
65+ ON discovered_tests (file_path, file_hash)
66+ """
67+ )
68+ self ._memory_cache = {}
69+
70+ def insert_test (
71+ self ,
72+ file_path : str ,
73+ file_hash : str ,
74+ qualified_name_with_modules_from_root : str ,
75+ function_name : str ,
76+ test_class : str ,
77+ test_function : str ,
78+ test_type : TestType ,
79+ line_number : int ,
80+ col_number : int ,
81+ ) -> None :
82+ test_type_value = test_type .value if hasattr (test_type , "value" ) else test_type
83+ self .cur .execute (
84+ "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
85+ (
86+ file_path ,
87+ file_hash ,
88+ qualified_name_with_modules_from_root ,
89+ function_name ,
90+ test_class ,
91+ test_function ,
92+ test_type_value ,
93+ line_number ,
94+ col_number ,
95+ ),
96+ )
97+ self .connection .commit ()
98+
99+ def get_tests_for_file (self , file_path : str , file_hash : str ) -> list [FunctionCalledInTest ]:
100+ cache_key = (file_path , file_hash )
101+ if cache_key in self ._memory_cache :
102+ return self ._memory_cache [cache_key ]
103+ self .cur .execute ("SELECT * FROM discovered_tests WHERE file_path = ? AND file_hash = ?" , (file_path , file_hash ))
104+ result = [
105+ FunctionCalledInTest (
106+ tests_in_file = TestsInFile (
107+ test_file = Path (row [0 ]), test_class = row [4 ], test_function = row [5 ], test_type = TestType (int (row [6 ]))
108+ ),
109+ position = CodePosition (line_no = row [7 ], col_no = row [8 ]),
110+ )
111+ for row in self .cur .fetchall ()
112+ ]
113+ self ._memory_cache [cache_key ] = result
114+ return result
115+
116+ @staticmethod
117+ def compute_file_hash (path : str ) -> str :
118+ h = hashlib .md5 (usedforsecurity = False )
119+ with Path (path ).open ("rb" ) as f :
120+ while True :
121+ chunk = f .read (8192 )
122+ if not chunk :
123+ break
124+ h .update (chunk )
125+ return h .hexdigest ()
126+
127+ def close (self ) -> None :
128+ self .cur .close ()
129+ self .connection .close ()
130+
131+
40132def discover_unit_tests (
41133 cfg : TestConfig , discover_only_these_tests : list [Path ] | None = None
42134) -> dict [str , list [FunctionCalledInTest ]]:
43- framework_strategies : dict [str , Callable ] = {
44- "pytest" : discover_tests_pytest ,
45- "unittest" : discover_tests_unittest ,
46- }
135+ framework_strategies : dict [str , Callable ] = {"pytest" : discover_tests_pytest , "unittest" : discover_tests_unittest }
47136 strategy = framework_strategies .get (cfg .test_framework , None )
48137 if not strategy :
49138 error_message = f"Unsupported test framework: { cfg .test_framework } "
@@ -54,7 +143,7 @@ def discover_unit_tests(
54143
55144def discover_tests_pytest (
56145 cfg : TestConfig , discover_only_these_tests : list [Path ] | None = None
57- ) -> dict [str , list [FunctionCalledInTest ]]:
146+ ) -> dict [Path , list [FunctionCalledInTest ]]:
58147 tests_root = cfg .tests_root
59148 project_root = cfg .project_root_path
60149
@@ -91,17 +180,15 @@ def discover_tests_pytest(
91180 )
92181
93182 elif 0 <= exitcode <= 5 :
94- logger .warning (
95- f"Failed to collect tests. Pytest Exit code: { exitcode } ={ ExitCode (exitcode ).name } "
96- )
183+ logger .warning (f"Failed to collect tests. Pytest Exit code: { exitcode } ={ ExitCode (exitcode ).name } " )
97184 else :
98185 logger .warning (f"Failed to collect tests. Pytest Exit code: { exitcode } " )
99186 console .rule ()
100187 else :
101188 logger .debug (f"Pytest collection exit code: { exitcode } " )
102189 if pytest_rootdir is not None :
103190 cfg .tests_project_rootdir = Path (pytest_rootdir )
104- file_to_test_map = defaultdict (list )
191+ file_to_test_map : dict [ Path , list [ FunctionCalledInTest ]] = defaultdict (list )
105192 for test in tests :
106193 if "__replay_test" in test ["test_file" ]:
107194 test_type = TestType .REPLAY_TEST
@@ -116,10 +203,7 @@ def discover_tests_pytest(
116203 test_function = test ["test_function" ],
117204 test_type = test_type ,
118205 )
119- if (
120- discover_only_these_tests
121- and test_obj .test_file not in discover_only_these_tests
122- ):
206+ if discover_only_these_tests and test_obj .test_file not in discover_only_these_tests :
123207 continue
124208 file_to_test_map [test_obj .test_file ].append (test_obj )
125209 # Within these test files, find the project functions they are referring to and return their names/locations
@@ -128,7 +212,7 @@ def discover_tests_pytest(
128212
129213def discover_tests_unittest (
130214 cfg : TestConfig , discover_only_these_tests : list [str ] | None = None
131- ) -> dict [str , list [FunctionCalledInTest ]]:
215+ ) -> dict [Path , list [FunctionCalledInTest ]]:
132216 tests_root : Path = cfg .tests_root
133217 loader : unittest .TestLoader = unittest .TestLoader ()
134218 tests : unittest .TestSuite = loader .discover (str (tests_root ))
@@ -144,8 +228,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
144228 _test_module_path = Path (_test_module .replace ("." , os .sep )).with_suffix (".py" )
145229 _test_module_path = tests_root / _test_module_path
146230 if not _test_module_path .exists () or (
147- discover_only_these_tests
148- and str (_test_module_path ) not in discover_only_these_tests
231+ discover_only_these_tests and str (_test_module_path ) not in discover_only_these_tests
149232 ):
150233 return None
151234 if "__replay_test" in str (_test_module_path ):
@@ -172,9 +255,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
172255 if not hasattr (test , "_testMethodName" ) and hasattr (test , "_tests" ):
173256 for test_2 in test ._tests :
174257 if not hasattr (test_2 , "_testMethodName" ):
175- logger .warning (
176- f"Didn't find tests for { test_2 } "
177- ) # it goes deeper?
258+ logger .warning (f"Didn't find tests for { test_2 } " ) # it goes deeper?
178259 continue
179260 details = get_test_details (test_2 )
180261 if details is not None :
@@ -195,19 +276,35 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
195276
196277
197278def process_test_files (
198- file_to_test_map : dict [str , list [TestsInFile ]], cfg : TestConfig
279+ file_to_test_map : dict [Path , list [TestsInFile ]], cfg : TestConfig
199280) -> dict [str , list [FunctionCalledInTest ]]:
200281 project_root_path = cfg .project_root_path
201282 test_framework = cfg .test_framework
283+
202284 function_to_test_map = defaultdict (set )
203285 jedi_project = jedi .Project (path = project_root_path )
204286 goto_cache = {}
287+ tests_cache = TestsCache ()
205288
206- with test_files_progress_bar (
207- total = len ( file_to_test_map ), description = "Processing test files"
208- ) as ( progress , task_id ):
209-
289+ with test_files_progress_bar (total = len ( file_to_test_map ), description = "Processing test files" ) as (
290+ progress ,
291+ task_id ,
292+ ):
210293 for test_file , functions in file_to_test_map .items ():
294+ file_hash = TestsCache .compute_file_hash (test_file )
295+ cached_tests = tests_cache .get_tests_for_file (str (test_file ), file_hash )
296+ if cached_tests :
297+ self_cur = tests_cache .cur
298+ self_cur .execute (
299+ "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? AND file_hash = ?" ,
300+ (str (test_file ), file_hash ),
301+ )
302+ qualified_names = [row [0 ] for row in self_cur .fetchall ()]
303+ for cached , qualified_name in zip (cached_tests , qualified_names ):
304+ function_to_test_map [qualified_name ].add (cached )
305+ progress .advance (task_id )
306+ continue
307+
211308 try :
212309 script = jedi .Script (path = test_file , project = jedi_project )
213310 test_functions = set ()
@@ -216,12 +313,8 @@ def process_test_files(
216313 all_defs = script .get_names (all_scopes = True , definitions = True )
217314 all_names_top = script .get_names (all_scopes = True )
218315
219- top_level_functions = {
220- name .name : name for name in all_names_top if name .type == "function"
221- }
222- top_level_classes = {
223- name .name : name for name in all_names_top if name .type == "class"
224- }
316+ top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
317+ top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
225318 except Exception as e :
226319 logger .debug (f"Failed to get jedi script for { test_file } : { e } " )
227320 progress .advance (task_id )
@@ -230,36 +323,18 @@ def process_test_files(
230323 if test_framework == "pytest" :
231324 for function in functions :
232325 if "[" in function .test_function :
233- function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (
234- function .test_function
235- )[0 ]
236- parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (
237- function .test_function
238- )[1 ]
326+ function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (function .test_function )[0 ]
327+ parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX .split (function .test_function )[1 ]
239328 if function_name in top_level_functions :
240329 test_functions .add (
241- TestFunction (
242- function_name ,
243- function .test_class ,
244- parameters ,
245- function .test_type ,
246- )
330+ TestFunction (function_name , function .test_class , parameters , function .test_type )
247331 )
248332 elif function .test_function in top_level_functions :
249333 test_functions .add (
250- TestFunction (
251- function .test_function ,
252- function .test_class ,
253- None ,
254- function .test_type ,
255- )
256- )
257- elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX .match (
258- function .test_function
259- ):
260- base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX .sub (
261- "" , function .test_function
334+ TestFunction (function .test_function , function .test_class , None , function .test_type )
262335 )
336+ elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX .match (function .test_function ):
337+ base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX .sub ("" , function .test_function )
263338 if base_name in top_level_functions :
264339 test_functions .add (
265340 TestFunction (
@@ -283,9 +358,7 @@ def process_test_files(
283358 and f".{ matched_name } ." in def_name .full_name
284359 ):
285360 for function in functions_to_search :
286- (is_parameterized , new_function , parameters ) = (
287- discover_parameters_unittest (function )
288- )
361+ (is_parameterized , new_function , parameters ) = discover_parameters_unittest (function )
289362
290363 if is_parameterized and new_function == def_name .name :
291364 test_functions .add (
@@ -329,9 +402,7 @@ def process_test_files(
329402 if cache_key in goto_cache :
330403 definition = goto_cache [cache_key ]
331404 else :
332- definition = name .goto (
333- follow_imports = True , follow_builtin_imports = False
334- )
405+ definition = name .goto (follow_imports = True , follow_builtin_imports = False )
335406 goto_cache [cache_key ] = definition
336407 except Exception as e :
337408 logger .debug (str (e ))
@@ -358,11 +429,23 @@ def process_test_files(
358429 if test_framework == "unittest" :
359430 scope_test_function += "_" + scope_parameters
360431
361- full_name_without_module_prefix = definition [
362- 0
363- ]. full_name . replace ( definition [ 0 ]. module_name + "." , "" , 1 )
432+ full_name_without_module_prefix = definition [0 ]. full_name . replace (
433+ definition [ 0 ]. module_name + "." , "" , 1
434+ )
364435 qualified_name_with_modules_from_root = f"{ module_name_from_file_path (definition [0 ].module_path , project_root_path )} .{ full_name_without_module_prefix } "
365436
437+ tests_cache .insert_test (
438+ file_path = str (test_file ),
439+ file_hash = file_hash ,
440+ qualified_name_with_modules_from_root = qualified_name_with_modules_from_root ,
441+ function_name = scope ,
442+ test_class = scope_test_class ,
443+ test_function = scope_test_function ,
444+ test_type = test_type ,
445+ line_number = name .line ,
446+ col_number = name .column ,
447+ )
448+
366449 function_to_test_map [qualified_name_with_modules_from_root ].add (
367450 FunctionCalledInTest (
368451 tests_in_file = TestsInFile (
@@ -371,12 +454,11 @@ def process_test_files(
371454 test_function = scope_test_function ,
372455 test_type = test_type ,
373456 ),
374- position = CodePosition (
375- line_no = name .line , col_no = name .column
376- ),
457+ position = CodePosition (line_no = name .line , col_no = name .column ),
377458 )
378459 )
379460
380461 progress .advance (task_id )
381462
463+ tests_cache .close ()
382464 return {function : list (tests ) for function , tests in function_to_test_map .items ()}
0 commit comments