Skip to content

Commit 638ac96

Browse files
committed
preview
1 parent f7631ef commit 638ac96

File tree

2 files changed

+10
-17
lines changed

2 files changed

+10
-17
lines changed

codeflash/benchmarking/codeflash_trace.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,17 @@ def __call__(self, func: Callable) -> Callable:
102102

103103
@functools.wraps(func)
104104
def wrapper(*args: tuple, **kwargs: dict) -> object:
105-
# Initialize thread-local active functions set if it doesn't exist
106105
if not hasattr(self._thread_local, "active_functions"):
107106
self._thread_local.active_functions = set()
108107
# If it's in a recursive function, just return the result
109108
if func_id in self._thread_local.active_functions:
110109
return func(*args, **kwargs)
111110
# Track active functions so we can detect recursive functions
112111
self._thread_local.active_functions.add(func_id)
113-
# Measure execution time
114112
start_time = time.thread_time_ns()
115113
result = func(*args, **kwargs)
116114
end_time = time.thread_time_ns()
117-
# Calculate execution time
115+
logger.info(f"CodeflashTrace: Function {func.__name__} executed in {end_time - start_time} ns")
118116
execution_time = end_time - start_time
119117
self.function_call_count += 1
120118

codeflash/benchmarking/plugin/plugin.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
7878
A nested dictionary where:
7979
- Outer keys are module_name.qualified_name (module.class.function)
8080
- Inner keys are of type BenchmarkKey
81-
- Values are function timing in milliseconds
81+
- Values are function timing in nanoseconds
8282
8383
"""
8484
result = {}
@@ -120,16 +120,16 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
120120
return result
121121

122122
@staticmethod
123-
def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
123+
def get_benchmark_timings(trace_path: Path) -> dict[tuple[str, str, int], int]:
124124
"""Extract total benchmark timings from trace files.
125125
126126
Args:
127127
trace_path: Path to the trace file
128128
129129
Returns:
130130
A dictionary mapping where:
131-
- Keys are of type BenchmarkKey
132-
- Values are total benchmark timing in milliseconds (with overhead subtracted)
131+
- Keys are (module_path, function_name, line_number)
132+
- Values are total benchmark timing in nanoseconds (with overhead subtracted)
133133
134134
"""
135135
result = {}
@@ -148,8 +148,8 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
148148

149149
for row in cursor.fetchall():
150150
benchmark_file, benchmark_func, benchmark_line, total_overhead_ns = row
151-
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
152-
overhead_by_benchmark[benchmark_key] = total_overhead_ns or 0 # Handle NULL sum case
151+
key = (benchmark_file, benchmark_func, benchmark_line)
152+
overhead_by_benchmark[key] = total_overhead_ns or 0 # Handle NULL sum case
153153

154154
# Query the benchmark_timings table for total times
155155
cursor.execute(
@@ -159,13 +159,9 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
159159

160160
for row in cursor.fetchall():
161161
benchmark_file, benchmark_func, benchmark_line, time_ns = row
162-
163-
benchmark_key = BenchmarkKey(
164-
module_path=benchmark_file, function_name=benchmark_func
165-
) # (file::function::line)
166-
# Subtract overhead from total time
167-
overhead = overhead_by_benchmark.get(benchmark_key, 0)
168-
result[benchmark_key] = time_ns - overhead
162+
key = (benchmark_file, benchmark_func, benchmark_line)
163+
overhead = overhead_by_benchmark.get(key, 0)
164+
result[key] = time_ns - overhead
169165

170166
finally:
171167
connection.close()
@@ -245,7 +241,6 @@ def _run_benchmark(
245241
else:
246242
call_identifier = f"{benchmark_function_name}::call_{self._call_count}"
247243

248-
os.environ["CODEFLASH_BENCHMARKING"] = "True"
249244
os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = call_identifier
250245
os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path
251246
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)

0 commit comments

Comments
 (0)