33import pickle
44import sqlite3
55import sys
6+ import threading
67import time
78from typing import Callable
89
@@ -18,6 +19,8 @@ def __init__(self) -> None:
1819 self .pickle_count_limit = 1000
1920 self ._connection = None
2021 self ._trace_path = None
22+ self ._thread_local = threading .local ()
23+ self ._thread_local .active_functions = set ()
2124
2225 def setup (self , trace_path : str ) -> None :
2326 """Set up the database connection for direct writing.
@@ -98,23 +101,29 @@ def __call__(self, func: Callable) -> Callable:
98101 The wrapped function
99102
100103 """
104+ func_id = (func .__module__ ,func .__name__ )
101105 @functools .wraps (func )
102106 def wrapper (* args , ** kwargs ):
107+ # Initialize thread-local active functions set if it doesn't exist
108+ if not hasattr (self ._thread_local , "active_functions" ):
109+ self ._thread_local .active_functions = set ()
110+ # Re-entrant call (this function is already active on this thread): run it untraced and return its result
111+ if func_id in self ._thread_local .active_functions :
112+ return func (* args , ** kwargs )
113+ # Track active functions so we can detect recursive functions
114+ self ._thread_local .active_functions .add (func_id )
103115 # Measure execution time
104116 start_time = time .thread_time_ns ()
105117 result = func (* args , ** kwargs )
106118 end_time = time .thread_time_ns ()
107119 # Calculate execution time
108120 execution_time = end_time - start_time
109-
110121 self .function_call_count += 1
111122
112- # Measure overhead
113- original_recursion_limit = sys .getrecursionlimit ()
114123 # Check if currently in pytest benchmark fixture
115124 if os .environ .get ("CODEFLASH_BENCHMARKING" , "False" ) == "False" :
125+ self ._thread_local .active_functions .remove (func_id )
116126 return result
117-
118127 # Get benchmark info from environment
119128 benchmark_function_name = os .environ .get ("CODEFLASH_BENCHMARK_FUNCTION_NAME" , "" )
120129 benchmark_module_path = os .environ .get ("CODEFLASH_BENCHMARK_MODULE_PATH" , "" )
@@ -125,32 +134,54 @@ def wrapper(*args, **kwargs):
125134 if "." in qualname :
126135 class_name = qualname .split ("." )[0 ]
127136
128- if self .function_call_count <= self .pickle_count_limit :
137+ # Limit pickle count so memory does not explode
138+ if self .function_call_count > self .pickle_count_limit :
139+ print ("Pickle limit reached" )
140+ self ._thread_local .active_functions .remove (func_id )
141+ overhead_time = time .thread_time_ns () - end_time
142+ self .function_calls_data .append (
143+ (func .__name__ , class_name , func .__module__ , func .__code__ .co_filename ,
144+ benchmark_function_name , benchmark_module_path , benchmark_line_number , execution_time ,
145+ overhead_time , None , None )
146+ )
147+ return result
148+
149+ try :
150+ original_recursion_limit = sys .getrecursionlimit ()
151+ sys .setrecursionlimit (10000 )
152+ # args = dict(args.items())
153+ # if class_name and func.__name__ == "__init__" and "self" in args:
154+ # del args["self"]
155+ # Pickle the arguments
156+ pickled_args = pickle .dumps (args , protocol = pickle .HIGHEST_PROTOCOL )
157+ pickled_kwargs = pickle .dumps (kwargs , protocol = pickle .HIGHEST_PROTOCOL )
158+ sys .setrecursionlimit (original_recursion_limit )
159+ except (TypeError , pickle .PicklingError , AttributeError , RecursionError , OSError ):
160+ # Retry with dill if pickle fails. It's slower but more comprehensive
129161 try :
130- sys .setrecursionlimit (1000000 )
131- args = dict (args .items ())
132- if class_name and func .__name__ == "__init__" and "self" in args :
133- del args ["self" ]
134- # Pickle the arguments
135- pickled_args = pickle .dumps (args , protocol = pickle .HIGHEST_PROTOCOL )
136- pickled_kwargs = pickle .dumps (kwargs , protocol = pickle .HIGHEST_PROTOCOL )
162+ pickled_args = dill .dumps (args , protocol = pickle .HIGHEST_PROTOCOL )
163+ pickled_kwargs = dill .dumps (kwargs , protocol = pickle .HIGHEST_PROTOCOL )
137164 sys .setrecursionlimit (original_recursion_limit )
138- except (TypeError , pickle .PicklingError , AttributeError , RecursionError , OSError ):
139- # we retry with dill if pickle fails. It's slower but more comprehensive
140- try :
141- pickled_args = dill .dumps (args , protocol = pickle .HIGHEST_PROTOCOL )
142- pickled_kwargs = dill .dumps (kwargs , protocol = pickle .HIGHEST_PROTOCOL )
143- sys .setrecursionlimit (original_recursion_limit )
144-
145- except (TypeError , dill .PicklingError , AttributeError , RecursionError , OSError ) as e :
146- print (f"Error pickling arguments for function { func .__name__ } : { e } " )
147- return result
148165
166+ except (TypeError , dill .PicklingError , AttributeError , RecursionError , OSError ) as e :
167+ print (f"Error pickling arguments for function { func .__name__ } : { e } " )
168+ # Add to the list of function calls without pickled args. Used for timing info only
169+ self ._thread_local .active_functions .remove (func_id )
170+ overhead_time = time .thread_time_ns () - end_time
171+ self .function_calls_data .append (
172+ (func .__name__ , class_name , func .__module__ , func .__code__ .co_filename ,
173+ benchmark_function_name , benchmark_module_path , benchmark_line_number , execution_time ,
174+ overhead_time , None , None )
175+ )
176+ return result
177+
178+ # Flush to database once more than 1000 calls are buffered
149179 if len (self .function_calls_data ) > 1000 :
150180 self .write_function_timings ()
151- # Calculate overhead time
152- overhead_time = time .thread_time_ns () - end_time
153181
182+ # Add to the list of function calls with pickled args, to be used for replay tests
183+ self ._thread_local .active_functions .remove (func_id )
184+ overhead_time = time .thread_time_ns () - end_time
154185 self .function_calls_data .append (
155186 (func .__name__ , class_name , func .__module__ , func .__code__ .co_filename ,
156187 benchmark_function_name , benchmark_module_path , benchmark_line_number , execution_time ,
0 commit comments