Skip to content

Commit 2605e97

Browse files
committed
executemany
1 parent dbac51b commit 2605e97

File tree

1 file changed

+87
-59
lines changed

1 file changed

+87
-59
lines changed

codeflash/tracer.py

Lines changed: 87 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ class Tracer:
7272
Traces function calls, input arguments, and profiling info.
7373
"""
7474

75+
_path_cache: dict[str, bool] = {}
76+
_file_filter_cache: dict[str, bool] = {}
77+
78+
_function_call_buffer: list[dict[str, Any]] = []
79+
_buffer_size = 1000
80+
7581
def __init__(
7682
self,
7783
output: str = "codeflash.trace",
@@ -110,7 +116,7 @@ def __init__(
110116
return
111117
self.con = None
112118
self.output_file = Path(output).resolve()
113-
self.functions = functions
119+
self.functions = set(functions) # Convert to set for O(1) lookups
114120
self.function_modules: list[FunctionModules] = []
115121
self.function_count = defaultdict(int)
116122
self.current_file_path = Path(__file__).resolve()
@@ -124,11 +130,11 @@ def __init__(
124130
console.rule(f"Project Root: {self.project_root}", style="bold blue")
125131
self.ignored_functions = {"<listcomp>", "<genexpr>", "<dictcomp>", "<setcomp>", "<lambda>", "<module>"}
126132

127-
self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_") # noqa: SLF001
133+
self.file_being_called_from = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_")
128134

129135
assert timeout is None or timeout > 0, "Timeout should be greater than 0"
130136
self.timeout = timeout
131-
self.next_insert = 1000
137+
self.next_insert = self._buffer_size
132138
self.trace_count = 0
133139

134140
# Profiler variables
@@ -142,6 +148,15 @@ def __init__(
142148
assert "test_framework" in self.config, "Please specify 'test-framework' in pyproject.toml config file"
143149
self.t = self.timer()
144150

151+
def _flush_buffer(self) -> None:
152+
if not self._function_call_buffer:
153+
return
154+
155+
cur = self.con.cursor()
156+
cur.executemany("INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", self._function_call_buffer)
157+
self.con.commit()
158+
self._function_call_buffer.clear()
159+
145160
def __enter__(self) -> None:
146161
if self.disable:
147162
return
@@ -161,15 +176,18 @@ def __enter__(self) -> None:
161176

162177
self.con = sqlite3.connect(self.output_file, check_same_thread=False)
163178
cur = self.con.cursor()
164-
cur.execute("""PRAGMA synchronous = OFF""")
165-
cur.execute("""PRAGMA journal_mode = WAL""")
166-
# TODO: Check out if we need to export the function test name as well
179+
cur.execute("PRAGMA synchronous = OFF")
180+
cur.execute("PRAGMA journal_mode = WAL")
181+
cur.execute("PRAGMA temp_store = MEMORY")
182+
cur.execute("PRAGMA cache_size = 10000")
183+
184+
# Create table
167185
cur.execute(
168186
"CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, "
169187
"line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)"
170188
)
171189
console.rule("Codeflash: Traced Program Output Begin", style="bold blue")
172-
frame = sys._getframe(0) # Get this frame and simulate a call to it # noqa: SLF001
190+
frame = sys._getframe(0)
173191
self.dispatch["call"](self, frame, 0)
174192
self.start_time = time.time()
175193
sys.setprofile(self.trace_callback)
@@ -181,7 +199,10 @@ def __exit__(
181199
if self.disable:
182200
return
183201
sys.setprofile(None)
184-
self.con.commit()
202+
203+
# Flush any remaining buffered function calls
204+
self._flush_buffer()
205+
185206
console.rule("Codeflash: Traced Program Output End", style="bold blue")
186207
self.create_stats()
187208

@@ -191,12 +212,14 @@ def __exit__(
191212
"call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, "
192213
"cumulative_time_ns INTEGER, callers BLOB)"
193214
)
215+
pstats_data = []
194216
for func, (cc, nc, tt, ct, callers) in self.stats.items():
195217
remapped_callers = [{"key": k, "value": v} for k, v in callers.items()]
196-
cur.execute(
197-
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
198-
(str(Path(func[0]).resolve()), func[1], func[2], func[3], cc, nc, tt, ct, json.dumps(remapped_callers)),
218+
pstats_data.append(
219+
(str(Path(func[0]).resolve()), func[1], func[2], func[3], cc, nc, tt, ct, json.dumps(remapped_callers))
199220
)
221+
222+
cur.executemany("INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", pstats_data)
200223
self.con.commit()
201224

202225
self.make_pstats_compatible()
@@ -207,7 +230,7 @@ def __exit__(
207230
self.con.commit()
208231
self.con.close()
209232

210-
# filter any functions where we did not capture the return
233+
# Filter any functions where we did not capture the return
211234
self.function_modules = [
212235
function
213236
for function in self.function_modules
@@ -254,45 +277,62 @@ def tracer_logic(self, frame: FrameType, event: str) -> None:
254277
console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
255278
return
256279
code = frame.f_code
257-
file_name = Path(code.co_filename).resolve()
258280
# TODO : It currently doesn't log the last return call from the first function
259281

260282
if code.co_name in self.ignored_functions:
261283
return
284+
285+
file_name_str = code.co_filename
286+
if file_name_str not in self._path_cache:
287+
file_name = Path(file_name_str).resolve()
288+
self._path_cache[file_name_str] = file_name
289+
else:
290+
file_name = self._path_cache[file_name_str]
291+
262292
if not file_name.exists():
263293
return
294+
264295
if self.functions and code.co_name not in self.functions:
265296
return
266-
class_name = None
297+
298+
# Extract class information
267299
arguments = frame.f_locals
300+
class_name = None
268301
try:
269-
if (
270-
"self" in arguments
271-
and hasattr(arguments["self"], "__class__")
272-
and hasattr(arguments["self"].__class__, "__name__")
273-
):
274-
class_name = arguments["self"].__class__.__name__
302+
if "self" in arguments and hasattr(arguments["self"], "__class__"):
303+
cls = arguments["self"].__class__
304+
if hasattr(cls, "__name__"):
305+
class_name = cls.__name__
275306
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
276307
class_name = arguments["cls"].__name__
277308
except: # noqa: E722
278309
# someone can override the getattr method and raise an exception. I'm looking at you wrapt
279310
return
311+
280312
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
281313
if function_qualified_name in self.ignored_qualified_functions:
282314
return
283315
if function_qualified_name not in self.function_count:
284316
# seeing this function for the first time
285317
self.function_count[function_qualified_name] = 0
286-
file_valid = filter_files_optimized(
287-
file_path=file_name,
288-
tests_root=Path(self.config["tests_root"]),
289-
ignore_paths=[Path(p) for p in self.config["ignore_paths"]],
290-
module_root=Path(self.config["module_root"]),
291-
)
318+
319+
# Cache file filter results
320+
if file_name not in self._file_filter_cache:
321+
file_valid = filter_files_optimized(
322+
file_path=file_name,
323+
tests_root=Path(self.config["tests_root"]),
324+
ignore_paths=[Path(p) for p in self.config["ignore_paths"]],
325+
module_root=Path(self.config["module_root"]),
326+
)
327+
self._file_filter_cache[file_name] = file_valid
328+
else:
329+
file_valid = self._file_filter_cache[file_name]
330+
292331
if not file_valid:
293332
# we don't want to trace this function because it cannot be optimized
294333
self.ignored_qualified_functions.add(function_qualified_name)
295334
return
335+
296336
self.function_modules.append(
297337
FunctionModules(
298338
function_name=code.co_name,
@@ -308,63 +348,51 @@ def tracer_logic(self, frame: FrameType, event: str) -> None:
308348
self.ignored_qualified_functions.add(function_qualified_name)
309349
return
310350

311-
# TODO: Also check if this function arguments are unique from the values logged earlier
312-
313-
cur = self.con.cursor()
314-
351+
# Serialize function arguments
315352
t_ns = time.perf_counter_ns()
316353
original_recursion_limit = sys.getrecursionlimit()
317354
try:
318-
# pickling can be a recursive operator, so we need to increase the recursion limit
319-
sys.setrecursionlimit(10000)
320-
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
321-
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
322-
# leaks, bad references or side effects when unpickling.
323-
arguments = dict(arguments.items())
355+
# Make a copy to avoid side effects
356+
arguments_copy = dict(arguments.items())
324357
if class_name and code.co_name == "__init__":
325-
del arguments["self"]
326-
local_vars = pickle.dumps(arguments, protocol=pickle.HIGHEST_PROTOCOL)
358+
del arguments_copy["self"]
359+
360+
# Use protocol 5 for better performance with Python 3.8+
361+
local_vars = pickle.dumps(arguments_copy, protocol=5)
327362
sys.setrecursionlimit(original_recursion_limit)
328363
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
329-
# we retry with dill if pickle fails. It's slower but more comprehensive
330364
try:
331-
local_vars = dill.dumps(arguments, protocol=dill.HIGHEST_PROTOCOL)
365+
local_vars = dill.dumps(arguments_copy, protocol=dill.HIGHEST_PROTOCOL)
332366
sys.setrecursionlimit(original_recursion_limit)
333-
334367
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError):
335368
# give up
336369
self.function_count[function_qualified_name] -= 1
337370
return
338-
cur.execute(
339-
"INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)",
340-
(
341-
event,
342-
code.co_name,
343-
class_name,
344-
str(file_name),
345-
frame.f_lineno,
346-
frame.f_back.__hash__(),
347-
t_ns,
348-
local_vars,
349-
),
371+
372+
# Add to buffer instead of immediate DB insertion
373+
self._function_call_buffer.append(
374+
(event, code.co_name, class_name, str(file_name), frame.f_lineno, frame.f_back.__hash__(), t_ns, local_vars)
350375
)
376+
351377
self.trace_count += 1
352378
self.next_insert -= 1
353379
if self.next_insert == 0:
354-
self.next_insert = 1000
355-
self.con.commit()
380+
self._flush_buffer()
381+
self.next_insert = self._buffer_size
356382

357383
def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None:
358-
# profiler section
384+
# Profiler section
359385
timer = self.timer
360386
t = timer() - self.t - self.bias
361387
if event == "c_call":
362388
self.c_func_name = arg.__name__
363389

364390
prof_success = bool(self.dispatch[event](self, frame, t))
365-
# tracer section
366-
self.tracer_logic(frame, event)
367-
# measure the time as the last thing before return
391+
392+
# Only process 'call' events for tracing to reduce overhead
393+
if event == "call":
394+
self.tracer_logic(frame, event)
395+
368396
if prof_success:
369397
self.t = timer()
370398
else:

0 commit comments

Comments
 (0)