1111#
1212from __future__ import annotations
1313
14+ import contextlib
1415import importlib .machinery
1516import io
1617import json
@@ -99,6 +100,7 @@ def __init__(
99100 )
100101 disable = True
101102 self .disable = disable
103+ self ._db_lock : threading .Lock | None = None
102104 if self .disable :
103105 return
104106 if sys .getprofile () is not None or sys .gettrace () is not None :
@@ -108,6 +110,9 @@ def __init__(
108110 )
109111 self .disable = True
110112 return
113+
114+ self ._db_lock = threading .Lock ()
115+
111116 self .con = None
112117 self .output_file = Path (output ).resolve ()
113118 self .functions = functions
@@ -130,6 +135,7 @@ def __init__(
130135 self .timeout = timeout
131136 self .next_insert = 1000
132137 self .trace_count = 0
138+ self .path_cache = {} # Cache for resolved file paths
133139
134140 # Profiler variables
135141 self .bias = 0 # calibration constant
@@ -178,34 +184,55 @@ def __enter__(self) -> None:
178184 def __exit__ (
179185 self , exc_type : type [BaseException ] | None , exc_val : BaseException | None , exc_tb : TracebackType | None
180186 ) -> None :
181- if self .disable :
187+ if self .disable or self . _db_lock is None :
182188 return
183189 sys .setprofile (None )
184- self .con .commit ()
185- console .rule ("Codeflash: Traced Program Output End" , style = "bold blue" )
186- self .create_stats ()
190+ threading .setprofile (None )
187191
188- cur = self .con .cursor ()
189- cur .execute (
190- "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
191- "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
192- "cumulative_time_ns INTEGER, callers BLOB)"
193- )
194- for func , (cc , nc , tt , ct , callers ) in self .stats .items ():
195- remapped_callers = [{"key" : k , "value" : v } for k , v in callers .items ()]
192+ with self ._db_lock :
193+ if self .con is None :
194+ return
195+
196+ self .con .commit () # Commit any pending from tracer_logic
197+ console .rule ("Codeflash: Traced Program Output End" , style = "bold blue" )
198+ self .create_stats () # This calls snapshot_stats which uses self.timings
199+
200+ cur = self .con .cursor ()
196201 cur .execute (
197- "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
198- (str (Path (func [0 ]).resolve ()), func [1 ], func [2 ], func [3 ], cc , nc , tt , ct , json .dumps (remapped_callers )),
202+ "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
203+ "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
204+ "cumulative_time_ns INTEGER, callers BLOB)"
199205 )
200- self .con .commit ()
206+ # self.stats is populated by snapshot_stats() called within create_stats()
207+ # Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data
208+ # For now, assuming self.stats is primarily in-memory after create_stats()
209+ for func , (cc , nc , tt , ct , callers ) in self .stats .items ():
210+ remapped_callers = [{"key" : k , "value" : v } for k , v in callers .items ()]
211+ cur .execute (
212+ "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
213+ (
214+ str (Path (func [0 ]).resolve ()),
215+ func [1 ],
216+ func [2 ],
217+ func [3 ],
218+ cc ,
219+ nc ,
220+ tt ,
221+ ct ,
222+ json .dumps (remapped_callers ),
223+ ),
224+ )
225+ self .con .commit ()
201226
202- self .make_pstats_compatible ()
203- self .print_stats ("tottime" )
204- cur = self .con .cursor ()
205- cur .execute ("CREATE TABLE total_time (time_ns INTEGER)" )
206- cur .execute ("INSERT INTO total_time VALUES (?)" , (self .total_tt ,))
207- self .con .commit ()
208- self .con .close ()
227+ self .make_pstats_compatible () # Modifies self.stats and self.timings in-memory
228+ self .print_stats ("tottime" ) # Uses self.stats, prints to console
229+
230+ cur = self .con .cursor () # New cursor
231+ cur .execute ("CREATE TABLE total_time (time_ns INTEGER)" )
232+ cur .execute ("INSERT INTO total_time VALUES (?)" , (self .total_tt ,))
233+ self .con .commit ()
234+ self .con .close ()
235+ self .con = None # Mark connection as closed
209236
210237 # filter any functions where we did not capture the return
211238 self .function_modules = [
@@ -245,18 +272,29 @@ def __exit__(
245272 def tracer_logic (self , frame : FrameType , event : str ) -> None : # noqa: PLR0911
246273 if event != "call" :
247274 return
248- if self . timeout is not None and (time .time () - self .start_time ) > self .timeout :
275+ if None is not self . timeout and (time .time () - self .start_time ) > self .timeout :
249276 sys .setprofile (None )
250277 threading .setprofile (None )
251278 console .print (f"Codeflash: Timeout reached! Stopping tracing at { self .timeout } seconds." )
252279 return
253- code = frame .f_code
280+ if self .disable or self ._db_lock is None or self .con is None :
281+ return
254282
255- file_name = Path (code .co_filename ).resolve ()
256- # TODO : It currently doesn't log the last return call from the first function
283+ code = frame .f_code
257284
285+ # Check function name first before resolving path
258286 if code .co_name in self .ignored_functions :
259287 return
288+
289+ # Now resolve file path only if we need it
290+ co_filename = code .co_filename
291+ if co_filename in self .path_cache :
292+ file_name = self .path_cache [co_filename ]
293+ else :
294+ file_name = Path (co_filename ).resolve ()
295+ self .path_cache [co_filename ] = file_name
296+ # TODO : It currently doesn't log the last return call from the first function
297+
260298 if not file_name .is_relative_to (self .project_root ):
261299 return
262300 if not file_name .exists ():
@@ -266,18 +304,29 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
266304 class_name = None
267305 arguments = frame .f_locals
268306 try :
269- if (
270- "self" in arguments
271- and hasattr (arguments ["self" ], "__class__" )
272- and hasattr (arguments ["self" ].__class__ , "__name__" )
273- ):
274- class_name = arguments ["self" ].__class__ .__name__
275- elif "cls" in arguments and hasattr (arguments ["cls" ], "__name__" ):
276- class_name = arguments ["cls" ].__name__
307+ self_arg = arguments .get ("self" )
308+ if self_arg is not None :
309+ try :
310+ class_name = self_arg .__class__ .__name__
311+ except AttributeError :
312+ cls_arg = arguments .get ("cls" )
313+ if cls_arg is not None :
314+ with contextlib .suppress (AttributeError ):
315+ class_name = cls_arg .__name__
316+ else :
317+ cls_arg = arguments .get ("cls" )
318+ if cls_arg is not None :
319+ with contextlib .suppress (AttributeError ):
320+ class_name = cls_arg .__name__
277321 except : # noqa: E722
278322 # someone can override the getattr method and raise an exception. I'm looking at you wrapt
279323 return
280- function_qualified_name = f"{ file_name } :{ (class_name + ':' if class_name else '' )} { code .co_name } "
324+
325+ try :
326+ function_qualified_name = f"{ file_name } :{ code .co_qualname } "
327+ except AttributeError :
328+ function_qualified_name = f"{ file_name } :{ (class_name + ':' if class_name else '' )} { code .co_name } "
329+
281330 if function_qualified_name in self .ignored_qualified_functions :
282331 return
283332 if function_qualified_name not in self .function_count :
@@ -310,49 +359,59 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
310359
311360 # TODO: Also check if this function arguments are unique from the values logged earlier
312361
313- cur = self .con .cursor ()
362+ with self ._db_lock :
363+ # Check connection again inside lock, in case __exit__ closed it.
364+ if self .con is None :
365+ return
314366
315- t_ns = time .perf_counter_ns ()
316- original_recursion_limit = sys .getrecursionlimit ()
317- try :
318- # pickling can be a recursive operator, so we need to increase the recursion limit
319- sys .setrecursionlimit (10000 )
320- # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
321- # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
322- # leaks, bad references or side effects when unpickling.
323- arguments = dict (arguments .items ())
324- if class_name and code .co_name == "__init__" :
325- del arguments ["self" ]
326- local_vars = pickle .dumps (arguments , protocol = pickle .HIGHEST_PROTOCOL )
327- sys .setrecursionlimit (original_recursion_limit )
328- except (TypeError , pickle .PicklingError , AttributeError , RecursionError , OSError ):
329- # we retry with dill if pickle fails. It's slower but more comprehensive
367+ cur = self .con .cursor ()
368+
369+ t_ns = time .perf_counter_ns ()
370+ original_recursion_limit = sys .getrecursionlimit ()
330371 try :
331- local_vars = dill .dumps (arguments , protocol = dill .HIGHEST_PROTOCOL )
372+ # pickling can be a recursive operator, so we need to increase the recursion limit
373+ sys .setrecursionlimit (10000 )
374+ # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
375+ # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
376+ # leaks, bad references or side effects when unpickling.
377+ arguments_copy = dict (arguments .items ()) # Use the local 'arguments' from frame.f_locals
378+ if class_name and code .co_name == "__init__" and "self" in arguments_copy :
379+ del arguments_copy ["self" ]
380+ local_vars = pickle .dumps (arguments_copy , protocol = pickle .HIGHEST_PROTOCOL )
332381 sys .setrecursionlimit (original_recursion_limit )
382+ except (TypeError , pickle .PicklingError , AttributeError , RecursionError , OSError ):
383+ # we retry with dill if pickle fails. It's slower but more comprehensive
384+ try :
385+ sys .setrecursionlimit (10000 ) # Ensure limit is high for dill too
386+ # arguments_copy should be used here as well if defined above
387+ local_vars = dill .dumps (
388+ arguments_copy if "arguments_copy" in locals () else dict (arguments .items ()),
389+ protocol = dill .HIGHEST_PROTOCOL ,
390+ )
391+ sys .setrecursionlimit (original_recursion_limit )
392+
393+ except (TypeError , dill .PicklingError , AttributeError , RecursionError , OSError ):
394+ self .function_count [function_qualified_name ] -= 1
395+ return
333396
334- except (TypeError , dill .PicklingError , AttributeError , RecursionError , OSError ):
335- # give up
336- self .function_count [function_qualified_name ] -= 1
337- return
338- cur .execute (
339- "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)" ,
340- (
341- event ,
342- code .co_name ,
343- class_name ,
344- str (file_name ),
345- frame .f_lineno ,
346- frame .f_back .__hash__ (),
347- t_ns ,
348- local_vars ,
349- ),
350- )
351- self .trace_count += 1
352- self .next_insert -= 1
353- if self .next_insert == 0 :
354- self .next_insert = 1000
355- self .con .commit ()
397+ cur .execute (
398+ "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)" ,
399+ (
400+ event ,
401+ code .co_name ,
402+ class_name ,
403+ str (file_name ),
404+ frame .f_lineno ,
405+ frame .f_back .__hash__ (),
406+ t_ns ,
407+ local_vars ,
408+ ),
409+ )
410+ self .trace_count += 1
411+ self .next_insert -= 1
412+ if self .next_insert == 0 :
413+ self .next_insert = 1000
414+ self .con .commit ()
356415
357416 def trace_callback (self , frame : FrameType , event : str , arg : str | None ) -> None :
358417 # profiler section
@@ -475,8 +534,9 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
475534 cc = cc + 1
476535
477536 if pfn in callers :
478- callers [pfn ] = callers [pfn ] + 1 # TODO: gather more
479- # stats such as the amount of time added to ct courtesy
537+ # Increment call count between these functions
538+ callers [pfn ] = callers [pfn ] + 1
539+ # TODO: gather more stats such as the amount of time added to ct courtesy
480540 # of this specific call, and the contribution to cc
481541 # courtesy of this call.
482542 else :
@@ -703,7 +763,7 @@ def create_stats(self) -> None:
703763
704764 def snapshot_stats (self ) -> None :
705765 self .stats = {}
706- for func , (cc , _ns , tt , ct , caller_dict ) in self .timings .items ():
766+ for func , (cc , _ns , tt , ct , caller_dict ) in list ( self .timings .items () ):
707767 callers = caller_dict .copy ()
708768 nc = 0
709769 for callcnt in callers .values ():
0 commit comments