Skip to content

Commit 62d6d00

Browse files
committed
add database lock
1 parent df88db4 commit 62d6d00

File tree

1 file changed

+100
-62
lines changed

1 file changed

+100
-62
lines changed

codeflash/tracer.py

Lines changed: 100 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#
1212
from __future__ import annotations
1313

14+
import contextlib
1415
import importlib.machinery
1516
import io
1617
import json
@@ -42,7 +43,6 @@
4243
from codeflash.tracing.replay_test import create_trace_replay_test
4344
from codeflash.tracing.tracing_utils import FunctionModules
4445
from codeflash.verification.verification_utils import get_test_file_path
45-
import contextlib
4646

4747
if TYPE_CHECKING:
4848
from types import FrameType, TracebackType
@@ -100,6 +100,7 @@ def __init__(
100100
)
101101
disable = True
102102
self.disable = disable
103+
self._db_lock: threading.Lock | None = None
103104
if self.disable:
104105
return
105106
if sys.getprofile() is not None or sys.gettrace() is not None:
@@ -109,6 +110,9 @@ def __init__(
109110
)
110111
self.disable = True
111112
return
113+
114+
self._db_lock = threading.Lock()
115+
112116
self.con = None
113117
self.output_file = Path(output).resolve()
114118
self.functions = functions
@@ -180,34 +184,55 @@ def __enter__(self) -> None:
180184
def __exit__(
181185
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
182186
) -> None:
183-
if self.disable:
187+
if self.disable or self._db_lock is None:
184188
return
185189
sys.setprofile(None)
186-
self.con.commit()
187-
console.rule("Codeflash: Traced Program Output End", style="bold blue")
188-
self.create_stats()
190+
threading.setprofile(None)
189191

190-
cur = self.con.cursor()
191-
cur.execute(
192-
"CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
193-
"call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
194-
"cumulative_time_ns INTEGER, callers BLOB)"
195-
)
196-
for func, (cc, nc, tt, ct, callers) in self.stats.items():
197-
remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
192+
with self._db_lock:
193+
if self.con is None:
194+
return
195+
196+
self.con.commit() # Commit any pending from tracer_logic
197+
console.rule("Codeflash: Traced Program Output End", style="bold blue")
198+
self.create_stats() # This calls snapshot_stats which uses self.timings
199+
200+
cur = self.con.cursor()
198201
cur.execute(
199-
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
200-
(str(Path(func[0]).resolve()), func[1], func[2], func[3], cc, nc, tt, ct, json.dumps(remapped_callers)),
202+
"CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
203+
"call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
204+
"cumulative_time_ns INTEGER, callers BLOB)"
201205
)
202-
self.con.commit()
206+
# self.stats is populated by snapshot_stats() called within create_stats()
207+
# Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data
208+
# For now, assuming self.stats is primarily in-memory after create_stats()
209+
for func, (cc, nc, tt, ct, callers) in self.stats.items():
210+
remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
211+
cur.execute(
212+
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
213+
(
214+
str(Path(func[0]).resolve()),
215+
func[1],
216+
func[2],
217+
func[3],
218+
cc,
219+
nc,
220+
tt,
221+
ct,
222+
json.dumps(remapped_callers),
223+
),
224+
)
225+
self.con.commit()
203226

204-
self.make_pstats_compatible()
205-
self.print_stats("tottime")
206-
cur = self.con.cursor()
207-
cur.execute("CREATE TABLE total_time (time_ns INTEGER)")
208-
cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,))
209-
self.con.commit()
210-
self.con.close()
227+
self.make_pstats_compatible() # Modifies self.stats and self.timings in-memory
228+
self.print_stats("tottime") # Uses self.stats, prints to console
229+
230+
cur = self.con.cursor() # New cursor
231+
cur.execute("CREATE TABLE total_time (time_ns INTEGER)")
232+
cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,))
233+
self.con.commit()
234+
self.con.close()
235+
self.con = None # Mark connection as closed
211236

212237
# filter any functions where we did not capture the return
213238
self.function_modules = [
@@ -252,6 +277,9 @@ def tracer_logic(self, frame: FrameType, event: str) -> None:
252277
threading.setprofile(None)
253278
console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
254279
return
280+
if self.disable or self._db_lock is None or self.con is None:
281+
return
282+
255283
code = frame.f_code
256284

257285
# Check function name first before resolving path
@@ -331,49 +359,59 @@ def tracer_logic(self, frame: FrameType, event: str) -> None:
331359

332360
# TODO: Also check if this function arguments are unique from the values logged earlier
333361

334-
cur = self.con.cursor()
362+
with self._db_lock:
363+
# Check connection again inside lock, in case __exit__ closed it.
364+
if self.con is None:
365+
return
335366

336-
t_ns = time.perf_counter_ns()
337-
original_recursion_limit = sys.getrecursionlimit()
338-
try:
339-
# pickling can be a recursive operator, so we need to increase the recursion limit
340-
sys.setrecursionlimit(10000)
341-
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
342-
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
343-
# leaks, bad references or side effects when unpickling.
344-
arguments = dict(arguments.items())
345-
if class_name and code.co_name == "__init__":
346-
del arguments["self"]
347-
local_vars = pickle.dumps(arguments, protocol=pickle.HIGHEST_PROTOCOL)
348-
sys.setrecursionlimit(original_recursion_limit)
349-
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
350-
# we retry with dill if pickle fails. It's slower but more comprehensive
367+
cur = self.con.cursor()
368+
369+
t_ns = time.perf_counter_ns()
370+
original_recursion_limit = sys.getrecursionlimit()
351371
try:
352-
local_vars = dill.dumps(arguments, protocol=dill.HIGHEST_PROTOCOL)
372+
# pickling can be a recursive operator, so we need to increase the recursion limit
373+
sys.setrecursionlimit(10000)
374+
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
375+
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
376+
# leaks, bad references or side effects when unpickling.
377+
arguments_copy = dict(arguments.items()) # Use the local 'arguments' from frame.f_locals
378+
if class_name and code.co_name == "__init__" and "self" in arguments_copy:
379+
del arguments_copy["self"]
380+
local_vars = pickle.dumps(arguments_copy, protocol=pickle.HIGHEST_PROTOCOL)
353381
sys.setrecursionlimit(original_recursion_limit)
382+
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
383+
# we retry with dill if pickle fails. It's slower but more comprehensive
384+
try:
385+
sys.setrecursionlimit(10000) # Ensure limit is high for dill too
386+
# arguments_copy should be used here as well if defined above
387+
local_vars = dill.dumps(
388+
arguments_copy if "arguments_copy" in locals() else dict(arguments.items()),
389+
protocol=dill.HIGHEST_PROTOCOL,
390+
)
391+
sys.setrecursionlimit(original_recursion_limit)
392+
393+
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
394+
self.function_count[function_qualified_name] -= 1
395+
return
354396

355-
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
356-
# give up
357-
self.function_count[function_qualified_name] -= 1
358-
return
359-
cur.execute(
360-
"INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
361-
(
362-
event,
363-
code.co_name,
364-
class_name,
365-
str(file_name),
366-
frame.f_lineno,
367-
frame.f_back.__hash__(),
368-
t_ns,
369-
local_vars,
370-
),
371-
)
372-
self.trace_count += 1
373-
self.next_insert -= 1
374-
if self.next_insert == 0:
375-
self.next_insert = 1000
376-
self.con.commit()
397+
cur.execute(
398+
"INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
399+
(
400+
event,
401+
code.co_name,
402+
class_name,
403+
str(file_name),
404+
frame.f_lineno,
405+
frame.f_back.__hash__(),
406+
t_ns,
407+
local_vars,
408+
),
409+
)
410+
self.trace_count += 1
411+
self.next_insert -= 1
412+
if self.next_insert == 0:
413+
self.next_insert = 1000
414+
self.con.commit()
377415

378416
def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None:
379417
# profiler section

0 commit comments

Comments
 (0)