Skip to content

Commit d64462a

Browse files
committed
Update tracer.py
1 parent 7c56c1a commit d64462a

File tree

1 file changed

+122
-75
lines changed

1 file changed

+122
-75
lines changed

codeflash/tracer.py

Lines changed: 122 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
#
1212
from __future__ import annotations
1313

14+
import contextlib
1415
import importlib.machinery
1516
import io
1617
import json
@@ -42,7 +43,6 @@
4243
from codeflash.tracing.replay_test import create_trace_replay_test
4344
from codeflash.tracing.tracing_utils import FunctionModules
4445
from codeflash.verification.verification_utils import get_test_file_path
45-
import contextlib
4646

4747
if TYPE_CHECKING:
4848
from types import FrameType, TracebackType
@@ -77,7 +77,7 @@ def __init__(
7777
self,
7878
output: str = "codeflash.trace",
7979
functions: list[str] | None = None,
80-
disable: bool = False, # noqa: FBT001, FBT002
80+
disable: bool = False,
8181
config_file_path: Path | None = None,
8282
max_function_count: int = 256,
8383
timeout: int | None = None, # seconds
@@ -100,6 +100,7 @@ def __init__(
100100
)
101101
disable = True
102102
self.disable = disable
103+
self._db_lock: threading.Lock | None = None
103104
if self.disable:
104105
return
105106
if sys.getprofile() is not None or sys.gettrace() is not None:
@@ -109,6 +110,9 @@ def __init__(
109110
)
110111
self.disable = True
111112
return
113+
114+
self._db_lock = threading.Lock()
115+
112116
self.con = None
113117
self.output_file = Path(output).resolve()
114118
self.functions = functions
@@ -180,34 +184,55 @@ def __enter__(self) -> None:
180184
def __exit__(
181185
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
182186
) -> None:
183-
if self.disable:
187+
if self.disable or self._db_lock is None:
184188
return
185189
sys.setprofile(None)
186-
self.con.commit()
187-
console.rule("Codeflash: Traced Program Output End", style="bold blue")
188-
self.create_stats()
190+
threading.setprofile(None)
189191

190-
cur = self.con.cursor()
191-
cur.execute(
192-
"CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
193-
"call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
194-
"cumulative_time_ns INTEGER, callers BLOB)"
195-
)
196-
for func, (cc, nc, tt, ct, callers) in self.stats.items():
197-
remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
192+
with self._db_lock:
193+
if self.con is None:
194+
return
195+
196+
self.con.commit() # Commit any pending from tracer_logic
197+
console.rule("Codeflash: Traced Program Output End", style="bold blue")
198+
self.create_stats() # This calls snapshot_stats which uses self.timings
199+
200+
cur = self.con.cursor()
198201
cur.execute(
199-
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
200-
(str(Path(func[0]).resolve()), func[1], func[2], func[3], cc, nc, tt, ct, json.dumps(remapped_callers)),
202+
"CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, "
203+
"call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
204+
"cumulative_time_ns INTEGER, callers BLOB)"
201205
)
202-
self.con.commit()
206+
# self.stats is populated by snapshot_stats() called within create_stats()
207+
# Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data
208+
# For now, assuming self.stats is primarily in-memory after create_stats()
209+
for func, (cc, nc, tt, ct, callers) in self.stats.items():
210+
remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
211+
cur.execute(
212+
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
213+
(
214+
str(Path(func[0]).resolve()),
215+
func[1],
216+
func[2],
217+
func[3],
218+
cc,
219+
nc,
220+
tt,
221+
ct,
222+
json.dumps(remapped_callers),
223+
),
224+
)
225+
self.con.commit()
203226

204-
self.make_pstats_compatible()
205-
self.print_stats("tottime")
206-
cur = self.con.cursor()
207-
cur.execute("CREATE TABLE total_time (time_ns INTEGER)")
208-
cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,))
209-
self.con.commit()
210-
self.con.close()
227+
self.make_pstats_compatible() # Modifies self.stats and self.timings in-memory
228+
self.print_stats("tottime") # Uses self.stats, prints to console
229+
230+
cur = self.con.cursor() # New cursor
231+
cur.execute("CREATE TABLE total_time (time_ns INTEGER)")
232+
cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,))
233+
self.con.commit()
234+
self.con.close()
235+
self.con = None # Mark connection as closed
211236

212237
# filter any functions where we did not capture the return
213238
self.function_modules = [
@@ -244,16 +269,24 @@ def __exit__(
244269
overflow="ignore",
245270
)
246271

247-
def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
272+
def tracer_logic(self, frame: FrameType, event: str) -> None:
248273
if event != "call":
249274
return
250275
if None is not self.timeout and (time.time() - self.start_time) > self.timeout:
251276
sys.setprofile(None)
252277
threading.setprofile(None)
253278
console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
254279
return
280+
if self.disable or self._db_lock is None or self.con is None:
281+
return
282+
255283
code = frame.f_code
256284

285+
# Check function name first before resolving path
286+
if code.co_name in self.ignored_functions:
287+
return
288+
289+
# Now resolve file path only if we need it
257290
co_filename = code.co_filename
258291
if co_filename in self.path_cache:
259292
file_name = self.path_cache[co_filename]
@@ -262,8 +295,6 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
262295
self.path_cache[co_filename] = file_name
263296
# TODO: It currently doesn't log the last return call from the first function
264297

265-
if code.co_name in self.ignored_functions:
266-
return
267298
if not file_name.is_relative_to(self.project_root):
268299
return
269300
if not file_name.exists():
@@ -290,7 +321,12 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
290321
except: # noqa: E722
291322
# someone can override the getattr method and raise an exception. I'm looking at you wrapt
292323
return
293-
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
324+
325+
try:
326+
function_qualified_name = f"{file_name}:{code.co_qualname}"
327+
except AttributeError:
328+
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
329+
294330
if function_qualified_name in self.ignored_qualified_functions:
295331
return
296332
if function_qualified_name not in self.function_count:
@@ -323,49 +359,59 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
323359

324360
# TODO: Also check if this function arguments are unique from the values logged earlier
325361

326-
cur = self.con.cursor()
362+
with self._db_lock:
363+
# Check connection again inside lock, in case __exit__ closed it.
364+
if self.con is None:
365+
return
327366

328-
t_ns = time.perf_counter_ns()
329-
original_recursion_limit = sys.getrecursionlimit()
330-
try:
331-
# pickling can be a recursive operator, so we need to increase the recursion limit
332-
sys.setrecursionlimit(10000)
333-
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
334-
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
335-
# leaks, bad references or side effects when unpickling.
336-
arguments = dict(arguments.items())
337-
if class_name and code.co_name == "__init__":
338-
del arguments["self"]
339-
local_vars = pickle.dumps(arguments, protocol=pickle.HIGHEST_PROTOCOL)
340-
sys.setrecursionlimit(original_recursion_limit)
341-
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
342-
# we retry with dill if pickle fails. It's slower but more comprehensive
367+
cur = self.con.cursor()
368+
369+
t_ns = time.perf_counter_ns()
370+
original_recursion_limit = sys.getrecursionlimit()
343371
try:
344-
local_vars = dill.dumps(arguments, protocol=dill.HIGHEST_PROTOCOL)
372+
# pickling can be a recursive operator, so we need to increase the recursion limit
373+
sys.setrecursionlimit(10000)
374+
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
375+
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
376+
# leaks, bad references or side effects when unpickling.
377+
arguments_copy = dict(arguments.items()) # Use the local 'arguments' from frame.f_locals
378+
if class_name and code.co_name == "__init__" and "self" in arguments_copy:
379+
del arguments_copy["self"]
380+
local_vars = pickle.dumps(arguments_copy, protocol=pickle.HIGHEST_PROTOCOL)
345381
sys.setrecursionlimit(original_recursion_limit)
382+
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
383+
# we retry with dill if pickle fails. It's slower but more comprehensive
384+
try:
385+
sys.setrecursionlimit(10000) # Ensure limit is high for dill too
386+
# arguments_copy should be used here as well if defined above
387+
local_vars = dill.dumps(
388+
arguments_copy if "arguments_copy" in locals() else dict(arguments.items()),
389+
protocol=dill.HIGHEST_PROTOCOL,
390+
)
391+
sys.setrecursionlimit(original_recursion_limit)
392+
393+
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
394+
self.function_count[function_qualified_name] -= 1
395+
return
346396

347-
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
348-
# give up
349-
self.function_count[function_qualified_name] -= 1
350-
return
351-
cur.execute(
352-
"INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
353-
(
354-
event,
355-
code.co_name,
356-
class_name,
357-
str(file_name),
358-
frame.f_lineno,
359-
frame.f_back.__hash__(),
360-
t_ns,
361-
local_vars,
362-
),
363-
)
364-
self.trace_count += 1
365-
self.next_insert -= 1
366-
if self.next_insert == 0:
367-
self.next_insert = 1000
368-
self.con.commit()
397+
cur.execute(
398+
"INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
399+
(
400+
event,
401+
code.co_name,
402+
class_name,
403+
str(file_name),
404+
frame.f_lineno,
405+
frame.f_back.__hash__(),
406+
t_ns,
407+
local_vars,
408+
),
409+
)
410+
self.trace_count += 1
411+
self.next_insert -= 1
412+
if self.next_insert == 0:
413+
self.next_insert = 1000
414+
self.con.commit()
369415

370416
def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None:
371417
# profiler section
@@ -413,7 +459,7 @@ def trace_dispatch_call(self, frame: FrameType, t: int) -> int:
413459
class_name = arguments["self"].__class__.__name__
414460
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
415461
class_name = arguments["cls"].__name__
416-
except Exception: # noqa: S110
462+
except Exception: # noqa: BLE001, S110
417463
pass
418464

419465
fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name)
@@ -425,7 +471,7 @@ def trace_dispatch_call(self, frame: FrameType, t: int) -> int:
425471
else:
426472
timings[fn] = 0, 0, 0, 0, {}
427473
return 1 # noqa: TRY300
428-
except Exception:
474+
except Exception: # noqa: BLE001
429475
# Handle any errors gracefully
430476
return 0
431477

@@ -488,8 +534,9 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
488534
cc = cc + 1
489535

490536
if pfn in callers:
491-
callers[pfn] = callers[pfn] + 1 # TODO: gather more
492-
# stats such as the amount of time added to ct courtesy
537+
# Increment call count between these functions
538+
callers[pfn] = callers[pfn] + 1
539+
# TODO: gather more stats, such as the amount of time added to ct courtesy
493540
# of this specific call, and the contribution to cc
494541
# courtesy of this call.
495542
else:
@@ -579,7 +626,7 @@ def print_stats(self, sort: str | int | tuple = -1) -> None:
579626

580627
# Store with new format
581628
new_stats[new_func] = (cc, nc, tt, ct, new_callers)
582-
except Exception as e:
629+
except Exception as e: # noqa: BLE001
583630
console.print(f"Error converting stats for {func}: {e}")
584631
continue
585632

@@ -616,7 +663,7 @@ def print_stats(self, sort: str | int | tuple = -1) -> None:
616663
new_callers[new_caller_func] = count
617664

618665
new_timings[new_func] = (cc, ns, tt, ct, new_callers)
619-
except Exception as e:
666+
except Exception as e: # noqa: BLE001
620667
console.print(f"Error converting timings for {func}: {e}")
621668
continue
622669

@@ -686,7 +733,7 @@ def print_stats(self, sort: str | int | tuple = -1) -> None:
686733

687734
console.print(Align.center(table))
688735

689-
except Exception as e:
736+
except Exception as e: # noqa: BLE001
690737
console.print(f"[bold red]Error in stats processing:[/bold red] {e}")
691738
console.print(f"Traced {self.trace_count:,} function calls")
692739
self.total_tt = 0
@@ -716,7 +763,7 @@ def create_stats(self) -> None:
716763

717764
def snapshot_stats(self) -> None:
718765
self.stats = {}
719-
for func, (cc, _ns, tt, ct, caller_dict) in self.timings.items():
766+
for func, (cc, _ns, tt, ct, caller_dict) in list(self.timings.items()):
720767
callers = caller_dict.copy()
721768
nc = 0
722769
for callcnt in callers.values():

0 commit comments

Comments
 (0)