33import  pickle 
44import  sqlite3 
55import  sys 
6+ import  threading 
67import  time 
78from  typing  import  Callable 
89
@@ -18,6 +19,8 @@ def __init__(self) -> None:
1819        self .pickle_count_limit  =  1000 
1920        self ._connection  =  None 
2021        self ._trace_path  =  None 
22+         self ._thread_local  =  threading .local ()
23+         self ._thread_local .active_functions  =  set ()
2124
2225    def  setup (self , trace_path : str ) ->  None :
2326        """Set up the database connection for direct writing. 
@@ -98,23 +101,29 @@ def __call__(self, func: Callable) -> Callable:
98101            The wrapped function 
99102
100103        """ 
104+         func_id  =  (func .__module__ ,func .__name__ )
101105        @functools .wraps (func ) 
102106        def  wrapper (* args , ** kwargs ):
107+             # Initialize thread-local active functions set if it doesn't exist 
108+             if  not  hasattr (self ._thread_local , "active_functions" ):
109+                 self ._thread_local .active_functions  =  set ()
110+             # If it's in a recursive function, just return the result 
111+             if  func_id  in  self ._thread_local .active_functions :
112+                 return  func (* args , ** kwargs )
113+             # Track active functions so we can detect recursive functions 
114+             self ._thread_local .active_functions .add (func_id )
103115            # Measure execution time 
104116            start_time  =  time .thread_time_ns ()
105117            result  =  func (* args , ** kwargs )
106118            end_time  =  time .thread_time_ns ()
107119            # Calculate execution time 
108120            execution_time  =  end_time  -  start_time 
109- 
110121            self .function_call_count  +=  1 
111122
112-             # Measure overhead 
113-             original_recursion_limit  =  sys .getrecursionlimit ()
114123            # Check if currently in pytest benchmark fixture 
115124            if  os .environ .get ("CODEFLASH_BENCHMARKING" , "False" ) ==  "False" :
125+                 self ._thread_local .active_functions .remove (func_id )
116126                return  result 
117- 
118127            # Get benchmark info from environment 
119128            benchmark_function_name  =  os .environ .get ("CODEFLASH_BENCHMARK_FUNCTION_NAME" , "" )
120129            benchmark_module_path  =  os .environ .get ("CODEFLASH_BENCHMARK_MODULE_PATH" , "" )
@@ -125,32 +134,54 @@ def wrapper(*args, **kwargs):
125134            if  "."  in  qualname :
126135                class_name  =  qualname .split ("." )[0 ]
127136
128-             if  self .function_call_count  <=  self .pickle_count_limit :
137+             # Limit pickle count so memory does not explode 
138+             if  self .function_call_count  >  self .pickle_count_limit :
139+                 print ("Pickle limit reached" )
140+                 self ._thread_local .active_functions .remove (func_id )
141+                 overhead_time  =  time .thread_time_ns () -  end_time 
142+                 self .function_calls_data .append (
143+                     (func .__name__ , class_name , func .__module__ , func .__code__ .co_filename ,
144+                      benchmark_function_name , benchmark_module_path , benchmark_line_number , execution_time ,
145+                      overhead_time , None , None )
146+                 )
147+                 return  result 
148+ 
149+             try :
150+                 original_recursion_limit  =  sys .getrecursionlimit ()
151+                 sys .setrecursionlimit (10000 )
152+                 # args = dict(args.items()) 
153+                 # if class_name and func.__name__ == "__init__" and "self" in args: 
154+                 #     del args["self"] 
155+                 # Pickle the arguments 
156+                 pickled_args  =  pickle .dumps (args , protocol = pickle .HIGHEST_PROTOCOL )
157+                 pickled_kwargs  =  pickle .dumps (kwargs , protocol = pickle .HIGHEST_PROTOCOL )
158+                 sys .setrecursionlimit (original_recursion_limit )
159+             except  (TypeError , pickle .PicklingError , AttributeError , RecursionError , OSError ):
160+                 # Retry with dill if pickle fails. It's slower but more comprehensive 
129161                try :
130-                     sys .setrecursionlimit (1000000 )
131-                     args  =  dict (args .items ())
132-                     if  class_name  and  func .__name__  ==  "__init__"  and  "self"  in  args :
133-                         del  args ["self" ]
134-                     # Pickle the arguments 
135-                     pickled_args  =  pickle .dumps (args , protocol = pickle .HIGHEST_PROTOCOL )
136-                     pickled_kwargs  =  pickle .dumps (kwargs , protocol = pickle .HIGHEST_PROTOCOL )
162+                     pickled_args  =  dill .dumps (args , protocol = pickle .HIGHEST_PROTOCOL )
163+                     pickled_kwargs  =  dill .dumps (kwargs , protocol = pickle .HIGHEST_PROTOCOL )
137164                    sys .setrecursionlimit (original_recursion_limit )
138-                 except  (TypeError , pickle .PicklingError , AttributeError , RecursionError , OSError ):
139-                     # we retry with dill if pickle fails. It's slower but more comprehensive 
140-                     try :
141-                         pickled_args  =  dill .dumps (args , protocol = pickle .HIGHEST_PROTOCOL )
142-                         pickled_kwargs  =  dill .dumps (kwargs , protocol = pickle .HIGHEST_PROTOCOL )
143-                         sys .setrecursionlimit (original_recursion_limit )
144- 
145-                     except  (TypeError , dill .PicklingError , AttributeError , RecursionError , OSError ) as  e :
146-                         print (f"Error pickling arguments for function { func .__name__ }  : { e }  " )
147-                         return  result 
148165
166+                 except  (TypeError , dill .PicklingError , AttributeError , RecursionError , OSError ) as  e :
167+                     print (f"Error pickling arguments for function { func .__name__ }  : { e }  " )
168+                     # Add to the list of function calls without pickled args. Used for timing info only 
169+                     self ._thread_local .active_functions .remove (func_id )
170+                     overhead_time  =  time .thread_time_ns () -  end_time 
171+                     self .function_calls_data .append (
172+                         (func .__name__ , class_name , func .__module__ , func .__code__ .co_filename ,
173+                          benchmark_function_name , benchmark_module_path , benchmark_line_number , execution_time ,
174+                          overhead_time , None , None )
175+                     )
176+                     return  result 
177+ 
178+             # Flush to database every 1000 calls 
149179            if  len (self .function_calls_data ) >  1000 :
150180                self .write_function_timings ()
151-             # Calculate overhead time 
152-             overhead_time  =  time .thread_time_ns () -  end_time 
153181
182+             # Add to the list of function calls with pickled args, to be used for replay tests 
183+             self ._thread_local .active_functions .remove (func_id )
184+             overhead_time  =  time .thread_time_ns () -  end_time 
154185            self .function_calls_data .append (
155186                (func .__name__ , class_name , func .__module__ , func .__code__ .co_filename ,
156187                 benchmark_function_name , benchmark_module_path , benchmark_line_number , execution_time ,
0 commit comments