Skip to content

Commit 841f55b

Browse files
committed
normalize for trace and replay tests too
1 parent ceec0ed commit 841f55b

File tree

3 files changed

+24
-8
lines changed

3 files changed

+24
-8
lines changed

codeflash/benchmarking/codeflash_trace.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sqlite3
55
import threading
66
import time
7+
from pathlib import Path
78
from typing import Any, Callable
89

910
from codeflash.picklepatch.pickle_patcher import PicklePatcher
@@ -143,12 +144,13 @@ def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
143144
print("Pickle limit reached")
144145
self._thread_local.active_functions.remove(func_id)
145146
overhead_time = time.thread_time_ns() - end_time
147+
normalized_file_path = Path(func.__code__.co_filename).as_posix()
146148
self.function_calls_data.append(
147149
(
148150
func.__name__,
149151
class_name,
150152
func.__module__,
151-
func.__code__.co_filename,
153+
normalized_file_path,
152154
benchmark_function_name,
153155
benchmark_module_path,
154156
benchmark_line_number,
@@ -169,12 +171,13 @@ def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
169171
# Add to the list of function calls without pickled args. Used for timing info only
170172
self._thread_local.active_functions.remove(func_id)
171173
overhead_time = time.thread_time_ns() - end_time
174+
normalized_file_path = Path(func.__code__.co_filename).as_posix()
172175
self.function_calls_data.append(
173176
(
174177
func.__name__,
175178
class_name,
176179
func.__module__,
177-
func.__code__.co_filename,
180+
normalized_file_path,
178181
benchmark_function_name,
179182
benchmark_module_path,
180183
benchmark_line_number,
@@ -192,12 +195,13 @@ def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
192195
# Add to the list of function calls with pickled args, to be used for replay tests
193196
self._thread_local.active_functions.remove(func_id)
194197
overhead_time = time.thread_time_ns() - end_time
198+
normalized_file_path = Path(func.__code__.co_filename).as_posix()
195199
self.function_calls_data.append(
196200
(
197201
func.__name__,
198202
class_name,
199203
func.__module__,
200-
func.__code__.co_filename,
204+
normalized_file_path,
201205
benchmark_function_name,
202206
benchmark_module_path,
203207
benchmark_line_number,

codeflash/benchmarking/replay_test.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,25 @@ def get_next_arg_and_return(
2929
db = sqlite3.connect(trace_file)
3030
cur = db.cursor()
3131
limit = num_to_get
32+
33+
normalized_file_path = Path(file_path).as_posix()
3234

3335
if class_name is not None:
3436
cursor = cur.execute(
3537
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
36-
(benchmark_function_name, function_name, file_path, class_name, limit),
38+
(benchmark_function_name, function_name, normalized_file_path, class_name, limit),
3739
)
3840
else:
3941
cursor = cur.execute(
4042
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
41-
(benchmark_function_name, function_name, file_path, limit),
43+
(benchmark_function_name, function_name, normalized_file_path, limit),
4244
)
4345

44-
while (val := cursor.fetchone()) is not None:
45-
yield val[9], val[10] # pickled_args, pickled_kwargs
46+
try:
47+
while (val := cursor.fetchone()) is not None:
48+
yield val[9], val[10] # pickled_args, pickled_kwargs
49+
finally:
50+
db.close()
4651

4752

4853
def get_function_alias(module: str, function_name: str) -> str:

tests/test_pickle_patcher.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType
1919
from codeflash.optimization.optimizer import Optimizer
2020
from codeflash.verification.equivalence import compare_test_results
21+
import time
2122

2223
try:
2324
import sqlalchemy
@@ -156,6 +157,9 @@ def test_picklepatch_with_database_connection():
156157
with pytest.raises(PicklePlaceholderAccessError):
157158
reloaded["connection"].execute("SELECT 1")
158159

160+
cursor.close()
161+
conn.close()
162+
159163

160164
def test_picklepatch_with_generator():
161165
"""Test that a data structure containing a generator is replaced by
@@ -290,6 +294,7 @@ def test_run_and_parse_picklepatch() -> None:
290294

291295
# Close the connection to allow file cleanup on Windows
292296
conn.close()
297+
time.sleep(1)
293298

294299
# Handle the case where function runs too fast to be measured
295300
unused_socket_results = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"]
@@ -326,7 +331,9 @@ def test_run_and_parse_picklepatch() -> None:
326331
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
327332
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
328333
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
329-
conn.close()
334+
conn.close()
335+
336+
time.sleep(1)
330337

331338
# Generate replay test
332339
generate_replay_test(output_file, replay_tests_dir)

0 commit comments

Comments
 (0)