diff --git a/codeflash/benchmarking/codeflash_trace.py b/codeflash/benchmarking/codeflash_trace.py index a2d080283..dbfedb5ca 100644 --- a/codeflash/benchmarking/codeflash_trace.py +++ b/codeflash/benchmarking/codeflash_trace.py @@ -12,12 +12,17 @@ class CodeflashTrace: """Decorator class that traces and profiles function execution.""" - def __init__(self) -> None: + def __init__(self, trace_path: str = None) -> None: self.function_calls_data = [] self.function_call_count = 0 self.pickle_count_limit = 1000 self._connection = None - self._trace_path = None + self._trace_path = trace_path + if self._trace_path: + self._initialize_db_connection() + self.cur = None + if self._connection: + self.cur = self._connection.cursor() def setup(self, trace_path: str) -> None: """Set up the database connection for direct writing. @@ -47,35 +52,11 @@ def setup(self, trace_path: str) -> None: raise def write_function_timings(self) -> None: - """Write function call data directly to the database. - - Args: - data: List of function call data tuples to write - - """ - if not self.function_calls_data: - return # No data to write - - if self._connection is None and self._trace_path is not None: - self._connection = sqlite3.connect(self._trace_path) + """Write function call data directly to the database.""" + if self._connection is None or self.cur is None: + return # No connection to write data - try: - cur = self._connection.cursor() - # Insert data into the benchmark_function_timings table - cur.executemany( - "INSERT INTO benchmark_function_timings" - "(function_name, class_name, module_name, file_path, benchmark_function_name, " - "benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - self.function_calls_data - ) - self._connection.commit() - self.function_calls_data = [] - except Exception as e: - print(f"Error writing to function timings database: {e}") - if self._connection: - self._connection.rollback() - raise + self._write_batch_and_clear() def open(self) -> None: """Open the database connection.""" @@ -98,6 +79,7 @@ def __call__(self, func: Callable) -> Callable: The wrapped function """ + @functools.wraps(func) def wrapper(*args, **kwargs): # Measure execution time @@ -152,12 +134,60 @@ def wrapper(*args, **kwargs): overhead_time = time.thread_time_ns() - end_time self.function_calls_data.append( - (func.__name__, class_name, func.__module__, func.__code__.co_filename, - benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time, - overhead_time, pickled_args, pickled_kwargs) + ( + func.__name__, + class_name, + func.__module__, + func.__code__.co_filename, + benchmark_function_name, + benchmark_module_path, + benchmark_line_number, + execution_time, + overhead_time, + pickled_args, + pickled_kwargs, + ) ) return result + return wrapper + def _initialize_db_connection(self): + if self._connection is None and self._trace_path is not None: + self._connection = sqlite3.connect(self._trace_path) + + def _pickle_args_kwargs(self, args, kwargs): + try: + pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL) + pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL) + except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError): + try: + pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL) + pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL) + except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e: + print(f"Error pickling arguments: {e}") + return None, None + return pickled_args, pickled_kwargs + + def _write_batch_and_clear(self): + if not self.function_calls_data: + return # No data to write + try: + self.cur.executemany( + "INSERT INTO benchmark_function_timings" + "(function_name, class_name, module_name, file_path, benchmark_function_name, " + "benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + self.function_calls_data, + ) + self._connection.commit() + self.function_calls_data = [] + except Exception as e: + print(f"Error writing to function timings database: {e}") + if self._connection: + self._connection.rollback() + raise + + # Create a singleton instance codeflash_trace = CodeflashTrace()