66import time
77from typing import Callable
88
9+ from codeflash .cli_cmds .cli import logger
910from codeflash .picklepatch .pickle_patcher import PicklePatcher
1011
1112
@@ -42,10 +43,8 @@ def setup(self, trace_path: str) -> None:
4243 )
4344 self ._connection .commit ()
4445 except Exception as e :
45- print (f"Database setup error: { e } " )
46- if self ._connection :
47- self ._connection .close ()
48- self ._connection = None
46+ logger .error (f"Database setup error: { e } " )
47+ self .close ()
4948 raise
5049
5150 def write_function_timings (self ) -> None :
@@ -63,18 +62,17 @@ def write_function_timings(self) -> None:
6362
6463 try :
6564 cur = self ._connection .cursor ()
66- # Insert data into the benchmark_function_timings table
6765 cur .executemany (
6866 "INSERT INTO benchmark_function_timings"
6967 "(function_name, class_name, module_name, file_path, benchmark_function_name, "
7068 "benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
7169 "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" ,
72- self .function_calls_data
70+ self .function_calls_data ,
7371 )
7472 self ._connection .commit ()
7573 self .function_calls_data = []
7674 except Exception as e :
77- print (f"Error writing to function timings database: { e } " )
75+ logger . error (f"Error writing to function timings database: { e } " )
7876 if self ._connection :
7977 self ._connection .rollback ()
8078 raise
@@ -100,9 +98,10 @@ def __call__(self, func: Callable) -> Callable:
10098 The wrapped function
10199
102100 """
103- func_id = (func .__module__ ,func .__name__ )
101+ func_id = (func .__module__ , func .__name__ )
102+
104103 @functools .wraps (func )
105- def wrapper (* args , ** kwargs ) :
104+ def wrapper (* args : tuple , ** kwargs : dict ) -> object :
106105 # Initialize thread-local active functions set if it doesn't exist
107106 if not hasattr (self ._thread_local , "active_functions" ):
108107 self ._thread_local .active_functions = set ()
@@ -123,25 +122,33 @@ def wrapper(*args, **kwargs):
123122 if os .environ .get ("CODEFLASH_BENCHMARKING" , "False" ) == "False" :
124123 self ._thread_local .active_functions .remove (func_id )
125124 return result
126- # Get benchmark info from environment
125+
127126 benchmark_function_name = os .environ .get ("CODEFLASH_BENCHMARK_FUNCTION_NAME" , "" )
128127 benchmark_module_path = os .environ .get ("CODEFLASH_BENCHMARK_MODULE_PATH" , "" )
129128 benchmark_line_number = os .environ .get ("CODEFLASH_BENCHMARK_LINE_NUMBER" , "" )
130- # Get class name
131129 class_name = ""
132130 qualname = func .__qualname__
133131 if "." in qualname :
134132 class_name = qualname .split ("." )[0 ]
135133
136- # Limit pickle count so memory does not explode
137134 if self .function_call_count > self .pickle_count_limit :
138- print ( " Pickle limit reached" )
135+ logger . debug ( "CodeflashTrace: Pickle limit reached" )
139136 self ._thread_local .active_functions .remove (func_id )
140137 overhead_time = time .thread_time_ns () - end_time
141138 self .function_calls_data .append (
142- (func .__name__ , class_name , func .__module__ , func .__code__ .co_filename ,
143- benchmark_function_name , benchmark_module_path , benchmark_line_number , execution_time ,
144- overhead_time , None , None )
139+ (
140+ func .__name__ ,
141+ class_name ,
142+ func .__module__ ,
143+ func .__code__ .co_filename ,
144+ benchmark_function_name ,
145+ benchmark_module_path ,
146+ benchmark_line_number ,
147+ execution_time ,
148+ overhead_time ,
149+ None ,
150+ None ,
151+ )
145152 )
146153 return result
147154
@@ -150,30 +157,50 @@ def wrapper(*args, **kwargs):
150157 pickled_args = PicklePatcher .dumps (args , protocol = pickle .HIGHEST_PROTOCOL )
151158 pickled_kwargs = PicklePatcher .dumps (kwargs , protocol = pickle .HIGHEST_PROTOCOL )
152159 except Exception as e :
153- print (f"Error pickling arguments for function { func .__name__ } : { e } " )
160+ logger . debug (f"CodeflashTrace: Error pickling arguments for function { func .__name__ } : { e } " )
154161 # Add to the list of function calls without pickled args. Used for timing info only
155162 self ._thread_local .active_functions .remove (func_id )
156163 overhead_time = time .thread_time_ns () - end_time
157164 self .function_calls_data .append (
158- (func .__name__ , class_name , func .__module__ , func .__code__ .co_filename ,
159- benchmark_function_name , benchmark_module_path , benchmark_line_number , execution_time ,
160- overhead_time , None , None )
165+ (
166+ func .__name__ ,
167+ class_name ,
168+ func .__module__ ,
169+ func .__code__ .co_filename ,
170+ benchmark_function_name ,
171+ benchmark_module_path ,
172+ benchmark_line_number ,
173+ execution_time ,
174+ overhead_time ,
175+ None ,
176+ None ,
177+ )
161178 )
162179 return result
163- # Flush to database every 100 calls
164180 if len (self .function_calls_data ) > 100 :
165181 self .write_function_timings ()
166182
167183 # Add to the list of function calls with pickled args, to be used for replay tests
168184 self ._thread_local .active_functions .remove (func_id )
169185 overhead_time = time .thread_time_ns () - end_time
170186 self .function_calls_data .append (
171- (func .__name__ , class_name , func .__module__ , func .__code__ .co_filename ,
172- benchmark_function_name , benchmark_module_path , benchmark_line_number , execution_time ,
173- overhead_time , pickled_args , pickled_kwargs )
187+ (
188+ func .__name__ ,
189+ class_name ,
190+ func .__module__ ,
191+ func .__code__ .co_filename ,
192+ benchmark_function_name ,
193+ benchmark_module_path ,
194+ benchmark_line_number ,
195+ execution_time ,
196+ overhead_time ,
197+ pickled_args ,
198+ pickled_kwargs ,
199+ )
174200 )
175201 return result
202+
176203 return wrapper
177204
178- # Create a singleton instance
205+
179206codeflash_trace = CodeflashTrace ()
0 commit comments