1111#
1212from __future__ import annotations
1313
14+ import contextlib
1415import importlib .machinery
1516import io
1617import json
4243from codeflash .tracing .replay_test import create_trace_replay_test
4344from codeflash .tracing .tracing_utils import FunctionModules
4445from codeflash .verification .verification_utils import get_test_file_path
45- import contextlib
4646
4747if TYPE_CHECKING :
4848 from types import FrameType , TracebackType
@@ -77,7 +77,7 @@ def __init__(
7777 self ,
7878 output : str = "codeflash.trace" ,
7979 functions : list [str ] | None = None ,
80- disable : bool = False , # noqa: FBT001, FBT002
80+ disable : bool = False ,
8181 config_file_path : Path | None = None ,
8282 max_function_count : int = 256 ,
8383 timeout : int | None = None , # seconds
@@ -100,6 +100,7 @@ def __init__(
100100 )
101101 disable = True
102102 self .disable = disable
103+ self ._db_lock : threading .Lock | None = None
103104 if self .disable :
104105 return
105106 if sys .getprofile () is not None or sys .gettrace () is not None :
@@ -109,6 +110,9 @@ def __init__(
109110 )
110111 self .disable = True
111112 return
113+
114+ self ._db_lock = threading .Lock ()
115+
112116 self .con = None
113117 self .output_file = Path (output ).resolve ()
114118 self .functions = functions
@@ -180,34 +184,55 @@ def __enter__(self) -> None:
180184 def __exit__ (
181185 self , exc_type : type [BaseException ] | None , exc_val : BaseException | None , exc_tb : TracebackType | None
182186 ) -> None :
183- if self .disable :
187+ if self .disable or self . _db_lock is None :
184188 return
185189 sys .setprofile (None )
186- self .con .commit ()
187- console .rule ("Codeflash: Traced Program Output End" , style = "bold blue" )
188- self .create_stats ()
190+ threading .setprofile (None )
189191
190- cur = self .con .cursor ()
191- cur .execute (
192- "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
193- "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
194- "cumulative_time_ns INTEGER, callers BLOB)"
195- )
196- for func , (cc , nc , tt , ct , callers ) in self .stats .items ():
197- remapped_callers = [{"key" : k , "value" : v } for k , v in callers .items ()]
192+ with self ._db_lock :
193+ if self .con is None :
194+ return
195+
196+ self .con .commit () # Commit any pending from tracer_logic
197+ console .rule ("Codeflash: Traced Program Output End" , style = "bold blue" )
198+ self .create_stats () # This calls snapshot_stats which uses self.timings
199+
200+ cur = self .con .cursor ()
198201 cur .execute (
199- "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
200- (str (Path (func [0 ]).resolve ()), func [1 ], func [2 ], func [3 ], cc , nc , tt , ct , json .dumps (remapped_callers )),
202+ "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
203+ "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
204+ "cumulative_time_ns INTEGER, callers BLOB)"
201205 )
202- self .con .commit ()
206+ # self.stats is populated by snapshot_stats() called within create_stats()
207+ # Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data
208+ # For now, assuming self.stats is primarily in-memory after create_stats()
209+ for func , (cc , nc , tt , ct , callers ) in self .stats .items ():
210+ remapped_callers = [{"key" : k , "value" : v } for k , v in callers .items ()]
211+ cur .execute (
212+ "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
213+ (
214+ str (Path (func [0 ]).resolve ()),
215+ func [1 ],
216+ func [2 ],
217+ func [3 ],
218+ cc ,
219+ nc ,
220+ tt ,
221+ ct ,
222+ json .dumps (remapped_callers ),
223+ ),
224+ )
225+ self .con .commit ()
203226
204- self .make_pstats_compatible ()
205- self .print_stats ("tottime" )
206- cur = self .con .cursor ()
207- cur .execute ("CREATE TABLE total_time (time_ns INTEGER)" )
208- cur .execute ("INSERT INTO total_time VALUES (?)" , (self .total_tt ,))
209- self .con .commit ()
210- self .con .close ()
227+ self .make_pstats_compatible () # Modifies self.stats and self.timings in-memory
228+ self .print_stats ("tottime" ) # Uses self.stats, prints to console
229+
230+ cur = self .con .cursor () # New cursor
231+ cur .execute ("CREATE TABLE total_time (time_ns INTEGER)" )
232+ cur .execute ("INSERT INTO total_time VALUES (?)" , (self .total_tt ,))
233+ self .con .commit ()
234+ self .con .close ()
235+ self .con = None # Mark connection as closed
211236
212237 # filter any functions where we did not capture the return
213238 self .function_modules = [
@@ -244,16 +269,24 @@ def __exit__(
244269 overflow = "ignore" ,
245270 )
246271
247- def tracer_logic (self , frame : FrameType , event : str ) -> None : # noqa: PLR0911
272+ def tracer_logic (self , frame : FrameType , event : str ) -> None :
248273 if event != "call" :
249274 return
250275 if None is not self .timeout and (time .time () - self .start_time ) > self .timeout :
251276 sys .setprofile (None )
252277 threading .setprofile (None )
253278 console .print (f"Codeflash: Timeout reached! Stopping tracing at { self .timeout } seconds." )
254279 return
280+ if self .disable or self ._db_lock is None or self .con is None :
281+ return
282+
255283 code = frame .f_code
256284
285+ # Check function name first before resolving path
286+ if code .co_name in self .ignored_functions :
287+ return
288+
289+ # Now resolve file path only if we need it
257290 co_filename = code .co_filename
258291 if co_filename in self .path_cache :
259292 file_name = self .path_cache [co_filename ]
@@ -262,8 +295,6 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
262295 self .path_cache [co_filename ] = file_name
263296 # TODO : It currently doesn't log the last return call from the first function
264297
265- if code .co_name in self .ignored_functions :
266- return
267298 if not file_name .is_relative_to (self .project_root ):
268299 return
269300 if not file_name .exists ():
@@ -290,7 +321,12 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
290321 except : # noqa: E722
291322 # someone can override the getattr method and raise an exception. I'm looking at you wrapt
292323 return
293- function_qualified_name = f"{ file_name } :{ (class_name + ':' if class_name else '' )} { code .co_name } "
324+
325+ try :
326+ function_qualified_name = f"{ file_name } :{ code .co_qualname } "
327+ except AttributeError :
328+ function_qualified_name = f"{ file_name } :{ (class_name + ':' if class_name else '' )} { code .co_name } "
329+
294330 if function_qualified_name in self .ignored_qualified_functions :
295331 return
296332 if function_qualified_name not in self .function_count :
@@ -323,49 +359,59 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
323359
324360 # TODO: Also check if this function arguments are unique from the values logged earlier
325361
326- cur = self .con .cursor ()
362+ with self ._db_lock :
363+ # Check connection again inside lock, in case __exit__ closed it.
364+ if self .con is None :
365+ return
327366
328- t_ns = time .perf_counter_ns ()
329- original_recursion_limit = sys .getrecursionlimit ()
330- try :
331- # pickling can be a recursive operator, so we need to increase the recursion limit
332- sys .setrecursionlimit (10000 )
333- # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
334- # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
335- # leaks, bad references or side effects when unpickling.
336- arguments = dict (arguments .items ())
337- if class_name and code .co_name == "__init__" :
338- del arguments ["self" ]
339- local_vars = pickle .dumps (arguments , protocol = pickle .HIGHEST_PROTOCOL )
340- sys .setrecursionlimit (original_recursion_limit )
341- except (TypeError , pickle .PicklingError , AttributeError , RecursionError , OSError ):
342- # we retry with dill if pickle fails. It's slower but more comprehensive
367+ cur = self .con .cursor ()
368+
369+ t_ns = time .perf_counter_ns ()
370+ original_recursion_limit = sys .getrecursionlimit ()
343371 try :
344- local_vars = dill .dumps (arguments , protocol = dill .HIGHEST_PROTOCOL )
372+ # pickling can be a recursive operator, so we need to increase the recursion limit
373+ sys .setrecursionlimit (10000 )
374+ # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
375+ # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
376+ # leaks, bad references or side effects when unpickling.
377+ arguments_copy = dict (arguments .items ()) # Use the local 'arguments' from frame.f_locals
378+ if class_name and code .co_name == "__init__" and "self" in arguments_copy :
379+ del arguments_copy ["self" ]
380+ local_vars = pickle .dumps (arguments_copy , protocol = pickle .HIGHEST_PROTOCOL )
345381 sys .setrecursionlimit (original_recursion_limit )
382+ except (TypeError , pickle .PicklingError , AttributeError , RecursionError , OSError ):
383+ # we retry with dill if pickle fails. It's slower but more comprehensive
384+ try :
385+ sys .setrecursionlimit (10000 ) # Ensure limit is high for dill too
386+ # arguments_copy should be used here as well if defined above
387+ local_vars = dill .dumps (
388+ arguments_copy if "arguments_copy" in locals () else dict (arguments .items ()),
389+ protocol = dill .HIGHEST_PROTOCOL ,
390+ )
391+ sys .setrecursionlimit (original_recursion_limit )
392+
393+ except (TypeError , dill .PicklingError , AttributeError , RecursionError , OSError ):
394+ self .function_count [function_qualified_name ] -= 1
395+ return
346396
347- except (TypeError , dill .PicklingError , AttributeError , RecursionError , OSError ):
348- # give up
349- self .function_count [function_qualified_name ] -= 1
350- return
351- cur .execute (
352- "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)" ,
353- (
354- event ,
355- code .co_name ,
356- class_name ,
357- str (file_name ),
358- frame .f_lineno ,
359- frame .f_back .__hash__ (),
360- t_ns ,
361- local_vars ,
362- ),
363- )
364- self .trace_count += 1
365- self .next_insert -= 1
366- if self .next_insert == 0 :
367- self .next_insert = 1000
368- self .con .commit ()
397+ cur .execute (
398+ "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)" ,
399+ (
400+ event ,
401+ code .co_name ,
402+ class_name ,
403+ str (file_name ),
404+ frame .f_lineno ,
405+ frame .f_back .__hash__ (),
406+ t_ns ,
407+ local_vars ,
408+ ),
409+ )
410+ self .trace_count += 1
411+ self .next_insert -= 1
412+ if self .next_insert == 0 :
413+ self .next_insert = 1000
414+ self .con .commit ()
369415
370416 def trace_callback (self , frame : FrameType , event : str , arg : str | None ) -> None :
371417 # profiler section
@@ -413,7 +459,7 @@ def trace_dispatch_call(self, frame: FrameType, t: int) -> int:
413459 class_name = arguments ["self" ].__class__ .__name__
414460 elif "cls" in arguments and hasattr (arguments ["cls" ], "__name__" ):
415461 class_name = arguments ["cls" ].__name__
416- except Exception : # noqa: S110
462+ except Exception : # noqa: BLE001, S110
417463 pass
418464
419465 fn = (fcode .co_filename , fcode .co_firstlineno , fcode .co_name , class_name )
@@ -425,7 +471,7 @@ def trace_dispatch_call(self, frame: FrameType, t: int) -> int:
425471 else :
426472 timings [fn ] = 0 , 0 , 0 , 0 , {}
427473 return 1 # noqa: TRY300
428- except Exception :
474+ except Exception : # noqa: BLE001
429475 # Handle any errors gracefully
430476 return 0
431477
@@ -488,8 +534,9 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
488534 cc = cc + 1
489535
490536 if pfn in callers :
491- callers [pfn ] = callers [pfn ] + 1 # TODO: gather more
492- # stats such as the amount of time added to ct courtesy
537+ # Increment call count between these functions
538+ callers [pfn ] = callers [pfn ] + 1
539+ # TODO: gather more stats, such as the amount of time added to ct
493540 # by this specific call, and the contribution to cc
494541 # courtesy of this call.
495542 else :
@@ -579,7 +626,7 @@ def print_stats(self, sort: str | int | tuple = -1) -> None:
579626
580627 # Store with new format
581628 new_stats [new_func ] = (cc , nc , tt , ct , new_callers )
582- except Exception as e :
629+ except Exception as e : # noqa: BLE001
583630 console .print (f"Error converting stats for { func } : { e } " )
584631 continue
585632
@@ -616,7 +663,7 @@ def print_stats(self, sort: str | int | tuple = -1) -> None:
616663 new_callers [new_caller_func ] = count
617664
618665 new_timings [new_func ] = (cc , ns , tt , ct , new_callers )
619- except Exception as e :
666+ except Exception as e : # noqa: BLE001
620667 console .print (f"Error converting timings for { func } : { e } " )
621668 continue
622669
@@ -686,7 +733,7 @@ def print_stats(self, sort: str | int | tuple = -1) -> None:
686733
687734 console .print (Align .center (table ))
688735
689- except Exception as e :
736+ except Exception as e : # noqa: BLE001
690737 console .print (f"[bold red]Error in stats processing:[/bold red] { e } " )
691738 console .print (f"Traced { self .trace_count :,} function calls" )
692739 self .total_tt = 0
@@ -716,7 +763,7 @@ def create_stats(self) -> None:
716763
717764 def snapshot_stats (self ) -> None :
718765 self .stats = {}
719- for func , (cc , _ns , tt , ct , caller_dict ) in self .timings .items ():
766+ for func , (cc , _ns , tt , ct , caller_dict ) in list ( self .timings .items () ):
720767 callers = caller_dict .copy ()
721768 nc = 0
722769 for callcnt in callers .values ():
0 commit comments