1111#
1212from __future__ import annotations
1313
14+ import contextlib
1415import importlib .machinery
1516import io
1617import json
4243from codeflash .tracing .replay_test import create_trace_replay_test
4344from codeflash .tracing .tracing_utils import FunctionModules
4445from codeflash .verification .verification_utils import get_test_file_path
45- import contextlib
4646
4747if TYPE_CHECKING :
4848 from types import FrameType , TracebackType
@@ -100,6 +100,7 @@ def __init__(
100100 )
101101 disable = True
102102 self .disable = disable
103+ self ._db_lock : threading .Lock | None = None
103104 if self .disable :
104105 return
105106 if sys .getprofile () is not None or sys .gettrace () is not None :
@@ -109,6 +110,9 @@ def __init__(
109110 )
110111 self .disable = True
111112 return
113+
114+ self ._db_lock = threading .Lock ()
115+
112116 self .con = None
113117 self .output_file = Path (output ).resolve ()
114118 self .functions = functions
@@ -180,34 +184,55 @@ def __enter__(self) -> None:
180184 def __exit__ (
181185 self , exc_type : type [BaseException ] | None , exc_val : BaseException | None , exc_tb : TracebackType | None
182186 ) -> None :
183- if self .disable :
187+ if self .disable or self . _db_lock is None :
184188 return
185189 sys .setprofile (None )
186- self .con .commit ()
187- console .rule ("Codeflash: Traced Program Output End" , style = "bold blue" )
188- self .create_stats ()
190+ threading .setprofile (None )
189191
190- cur = self .con .cursor ()
191- cur .execute (
192- "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
193- "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
194- "cumulative_time_ns INTEGER, callers BLOB)"
195- )
196- for func , (cc , nc , tt , ct , callers ) in self .stats .items ():
197- remapped_callers = [{"key" : k , "value" : v } for k , v in callers .items ()]
192+ with self ._db_lock :
193+ if self .con is None :
194+ return
195+
196+ self .con .commit () # Commit any pending from tracer_logic
197+ console .rule ("Codeflash: Traced Program Output End" , style = "bold blue" )
198+ self .create_stats () # This calls snapshot_stats which uses self.timings
199+
200+ cur = self .con .cursor ()
198201 cur .execute (
199- "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
200- (str (Path (func [0 ]).resolve ()), func [1 ], func [2 ], func [3 ], cc , nc , tt , ct , json .dumps (remapped_callers )),
202+ "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
203+ "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
204+ "cumulative_time_ns INTEGER, callers BLOB)"
201205 )
202- self .con .commit ()
206+ # self.stats is populated by snapshot_stats() called within create_stats()
207+ # Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data
208+ # For now, assuming self.stats is primarily in-memory after create_stats()
209+ for func , (cc , nc , tt , ct , callers ) in self .stats .items ():
210+ remapped_callers = [{"key" : k , "value" : v } for k , v in callers .items ()]
211+ cur .execute (
212+ "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
213+ (
214+ str (Path (func [0 ]).resolve ()),
215+ func [1 ],
216+ func [2 ],
217+ func [3 ],
218+ cc ,
219+ nc ,
220+ tt ,
221+ ct ,
222+ json .dumps (remapped_callers ),
223+ ),
224+ )
225+ self .con .commit ()
203226
204- self .make_pstats_compatible ()
205- self .print_stats ("tottime" )
206- cur = self .con .cursor ()
207- cur .execute ("CREATE TABLE total_time (time_ns INTEGER)" )
208- cur .execute ("INSERT INTO total_time VALUES (?)" , (self .total_tt ,))
209- self .con .commit ()
210- self .con .close ()
227+ self .make_pstats_compatible () # Modifies self.stats and self.timings in-memory
228+ self .print_stats ("tottime" ) # Uses self.stats, prints to console
229+
230+ cur = self .con .cursor () # New cursor
231+ cur .execute ("CREATE TABLE total_time (time_ns INTEGER)" )
232+ cur .execute ("INSERT INTO total_time VALUES (?)" , (self .total_tt ,))
233+ self .con .commit ()
234+ self .con .close ()
235+ self .con = None # Mark connection as closed
211236
212237 # filter any functions where we did not capture the return
213238 self .function_modules = [
@@ -252,6 +277,9 @@ def tracer_logic(self, frame: FrameType, event: str) -> None:
252277 threading .setprofile (None )
253278 console .print (f"Codeflash: Timeout reached! Stopping tracing at { self .timeout } seconds." )
254279 return
280+ if self .disable or self ._db_lock is None or self .con is None :
281+ return
282+
255283 code = frame .f_code
256284
257285 # Check function name first before resolving path
@@ -331,49 +359,59 @@ def tracer_logic(self, frame: FrameType, event: str) -> None:
331359
332360 # TODO: Also check if this function arguments are unique from the values logged earlier
333361
334- cur = self .con .cursor ()
362+ with self ._db_lock :
363+ # Check connection again inside lock, in case __exit__ closed it.
364+ if self .con is None :
365+ return
335366
336- t_ns = time .perf_counter_ns ()
337- original_recursion_limit = sys .getrecursionlimit ()
338- try :
339- # pickling can be a recursive operator, so we need to increase the recursion limit
340- sys .setrecursionlimit (10000 )
341- # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
342- # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
343- # leaks, bad references or side effects when unpickling.
344- arguments = dict (arguments .items ())
345- if class_name and code .co_name == "__init__" :
346- del arguments ["self" ]
347- local_vars = pickle .dumps (arguments , protocol = pickle .HIGHEST_PROTOCOL )
348- sys .setrecursionlimit (original_recursion_limit )
349- except (TypeError , pickle .PicklingError , AttributeError , RecursionError , OSError ):
350- # we retry with dill if pickle fails. It's slower but more comprehensive
367+ cur = self .con .cursor ()
368+
369+ t_ns = time .perf_counter_ns ()
370+ original_recursion_limit = sys .getrecursionlimit ()
351371 try :
352- local_vars = dill .dumps (arguments , protocol = dill .HIGHEST_PROTOCOL )
372+ # pickling can be a recursive operator, so we need to increase the recursion limit
373+ sys .setrecursionlimit (10000 )
374+ # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
375+ # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
376+ # leaks, bad references or side effects when unpickling.
377+ arguments_copy = dict (arguments .items ()) # Use the local 'arguments' from frame.f_locals
378+ if class_name and code .co_name == "__init__" and "self" in arguments_copy :
379+ del arguments_copy ["self" ]
380+ local_vars = pickle .dumps (arguments_copy , protocol = pickle .HIGHEST_PROTOCOL )
353381 sys .setrecursionlimit (original_recursion_limit )
382+ except (TypeError , pickle .PicklingError , AttributeError , RecursionError , OSError ):
383+ # we retry with dill if pickle fails. It's slower but more comprehensive
384+ try :
385+ sys .setrecursionlimit (10000 ) # Ensure limit is high for dill too
386+ # arguments_copy should be used here as well if defined above
387+ local_vars = dill .dumps (
388+ arguments_copy if "arguments_copy" in locals () else dict (arguments .items ()),
389+ protocol = dill .HIGHEST_PROTOCOL ,
390+ )
391+ sys .setrecursionlimit (original_recursion_limit )
392+
393+ except (TypeError , dill .PicklingError , AttributeError , RecursionError , OSError ):
394+ self .function_count [function_qualified_name ] -= 1
395+ return
354396
355- except (TypeError , dill .PicklingError , AttributeError , RecursionError , OSError ):
356- # give up
357- self .function_count [function_qualified_name ] -= 1
358- return
359- cur .execute (
360- "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)" ,
361- (
362- event ,
363- code .co_name ,
364- class_name ,
365- str (file_name ),
366- frame .f_lineno ,
367- frame .f_back .__hash__ (),
368- t_ns ,
369- local_vars ,
370- ),
371- )
372- self .trace_count += 1
373- self .next_insert -= 1
374- if self .next_insert == 0 :
375- self .next_insert = 1000
376- self .con .commit ()
397+ cur .execute (
398+ "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)" ,
399+ (
400+ event ,
401+ code .co_name ,
402+ class_name ,
403+ str (file_name ),
404+ frame .f_lineno ,
405+ frame .f_back .__hash__ (),
406+ t_ns ,
407+ local_vars ,
408+ ),
409+ )
410+ self .trace_count += 1
411+ self .next_insert -= 1
412+ if self .next_insert == 0 :
413+ self .next_insert = 1000
414+ self .con .commit ()
377415
378416 def trace_callback (self , frame : FrameType , event : str , arg : str | None ) -> None :
379417 # profiler section
0 commit comments