From 673a6384e6ef716a5f5df2b6144bdebd97e2b7bf Mon Sep 17 00:00:00 2001
From: Kevin Turcios
Date: Sun, 9 Mar 2025 01:40:53 -0800
Subject: [PATCH 01/13] add threading trace_callback

---
 codeflash/tracer.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/codeflash/tracer.py b/codeflash/tracer.py
index 96c0202f1..15b8f5d95 100644
--- a/codeflash/tracer.py
+++ b/codeflash/tracer.py
@@ -26,6 +26,7 @@
 from copy import copy
 from io import StringIO
 from pathlib import Path
+import threading
 from types import FrameType
 from typing import Any, ClassVar, List
 
@@ -143,11 +144,13 @@ def __enter__(self) -> None:
         self.dispatch["call"](self, frame, 0)
         self.start_time = time.time()
         sys.setprofile(self.trace_callback)
+        threading.setprofile(self.trace_callback)
 
     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         if self.disable:
             return
         sys.setprofile(None)
+        threading.setprofile(None)
         self.con.commit()
 
         self.create_stats()

From 6a232046c8bb285baa1613091f2d23433dcf2385 Mon Sep 17 00:00:00 2001
From: Kevin Turcios
Date: Sun, 9 Mar 2025 01:42:04 -0800
Subject: [PATCH 02/13] Create testbench.py

---
 testbench.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 testbench.py

diff --git a/testbench.py b/testbench.py
new file mode 100644
index 000000000..e3856033d
--- /dev/null
+++ b/testbench.py
@@ -0,0 +1,54 @@
+from concurrent.futures import ThreadPoolExecutor
+
+
+def add_numbers(a: int, b: int) -> int:
+    print(f"[ADD_NUMBERS] Starting with parameters: a={a}, b={b}")
+    result = a + b
+    print(f"[ADD_NUMBERS] Returning result: {result}")
+    return result
+
+
+def test_threadpool() -> None:
+    print("[TEST_THREADPOOL] Starting thread pool execution")
+    pool = ThreadPoolExecutor(max_workers=3)
+    numbers = [(10, 20), (30, 40), (50, 60)]
+    print("[TEST_THREADPOOL] Submitting tasks to thread pool")
+    result = pool.map(add_numbers, *zip(*numbers))
+
+    print("[TEST_THREADPOOL] Processing results")
+    for r in result:
+        print(f"[TEST_THREADPOOL] Thread result: {r}")
+    print("[TEST_THREADPOOL] Finished thread pool execution")
+
+
+def multiply_numbers(a: int, b: int) -> int:
+    print(f"[MULTIPLY_NUMBERS] Starting with parameters: a={a}, b={b}")
+    result = a * b
+    print(f"[MULTIPLY_NUMBERS] Returning result: {result}")
+    return result
+
+
+if __name__ == "__main__":
+    print("[MAIN] Starting testbench execution")
+
+    print("[MAIN] Calling test_threadpool()")
+    test_threadpool()
+    print("[MAIN] Finished test_threadpool()")
+
+    print("[MAIN] Calling add_numbers(5, 10)")
+    result1 = add_numbers(5, 10)
+    print(f"[MAIN] add_numbers result: {result1}")
+
+    print("[MAIN] Calling add_numbers(15, 25)")
+    result2 = add_numbers(15, 25)
+    print(f"[MAIN] add_numbers result: {result2}")
+
+    print("[MAIN] Calling multiply_numbers(3, 7)")
+    result3 = multiply_numbers(3, 7)
+    print(f"[MAIN] multiply_numbers result: {result3}")
+
+    print("[MAIN] Calling multiply_numbers(5, 9)")
+    result4 = multiply_numbers(5, 9)
+    print(f"[MAIN] multiply_numbers result: {result4}")
+
+    print("[MAIN] Testbench execution completed")

From 85b6c73910db4eee3701428dfb59ef31b209a3e9 Mon Sep 17 00:00:00 2001
From: Kevin Turcios
Date: Sun, 9 Mar 2025 21:59:11 -0700
Subject: [PATCH 03/13] db cleanup

---
 codeflash/tracer.py | 867 +++++++++++++++++++++++++++++++-------------
 1 file changed, 620 insertions(+), 247 deletions(-)

diff --git a/codeflash/tracer.py b/codeflash/tracer.py
index 15b8f5d95..1bab6f0e6 100644
--- a/codeflash/tracer.py
+++ b/codeflash/tracer.py
@@ -16,39 +16,66 @@ import
json import marshal import os -import pathlib import pickle -import re +import shutil import sqlite3 import sys +import tempfile +import threading import time +import uuid +from argparse import ArgumentParser from collections import defaultdict -from copy import copy -from io import StringIO from pathlib import Path -import threading -from types import FrameType -from typing import Any, ClassVar, List +from typing import TYPE_CHECKING, Any, Callable, ClassVar import dill import isort +from rich.align import Align +from rich.panel import Panel +from rich.table import Table +from rich.text import Text from codeflash.cli_cmds.cli import project_root_from_module_root from codeflash.cli_cmds.console import console -from codeflash.code_utils.code_utils import module_name_from_file_path +from codeflash.code_utils.code_utils import cleanup_paths, module_name_from_file_path from codeflash.code_utils.config_parser import parse_config_file from codeflash.discovery.functions_to_optimize import filter_files_optimized from codeflash.tracing.replay_test import create_trace_replay_test from codeflash.tracing.tracing_utils import FunctionModules from codeflash.verification.verification_utils import get_test_file_path +if TYPE_CHECKING: + from types import FrameType, TracebackType + + +class fake_code: # noqa: N801 + def __init__(self, filename: str, line: int, name: str) -> None: + self.co_filename = filename + self.co_line = line + self.co_name = name + self.co_firstlineno = 0 + + def __repr__(self) -> str: + return repr((self.co_filename, self.co_line, self.co_name, None)) + + +class fake_frame: # noqa: N801 + def __init__(self, code: fake_code, prior: fake_frame | None) -> None: + self.f_code = code + self.f_back = prior + self.f_locals = {} + # Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger. class Tracer: - """Use this class as a 'with' context manager to trace a function call, - input arguments, and profiling info. + """Use this class as a 'with' context manager to trace a function call. + + Traces function calls, input arguments, and profiling info. """ + used_once: ClassVar[bool] = False # Class variable to track if Tracer has been used + def __init__( self, output: str = "codeflash.trace", @@ -58,14 +85,7 @@ def __init__( max_function_count: int = 256, timeout: int | None = None, # seconds ) -> None: - """:param output: The path to the output trace file - :param functions: List of functions to trace. If None, trace all functions - :param disable: Disable the tracer if True - :param config_file_path: Path to the pyproject.toml file, if None then it will be auto-discovered - :param max_function_count: Maximum number of times to trace one function - :param timeout: Timeout in seconds for the tracer, if the traced code takes more than this time, then tracing - stops and normal execution continues. 
If this is None then no timeout applies - """ + """Initialize Tracer.""" if functions is None: functions = [] if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1": @@ -81,8 +101,14 @@ def __init__( ) self.disable = True return - self.con = None + + # Setup output paths self.output_file = Path(output).resolve() + self.output_dir = self.output_file.parent + self.output_base = self.output_file.stem + self.output_ext = self.output_file.suffix + self.thread_db_files: dict[int, Path] = {} # Thread ID to DB file path + self.functions = functions self.function_modules: list[FunctionModules] = [] self.function_count = defaultdict(int) @@ -94,90 +120,321 @@ def __init__( self.max_function_count = max_function_count self.config, found_config_path = parse_config_file(config_file_path) self.project_root = project_root_from_module_root(Path(self.config["module_root"]), found_config_path) - print("project_root", self.project_root) + console.print("project_root", self.project_root) self.ignored_functions = {"", "", "", "", "", ""} - self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_") + self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_") # noqa: SLF001 assert timeout is None or timeout > 0, "Timeout should be greater than 0" self.timeout = timeout self.next_insert = 1000 self.trace_count = 0 + self.db_lock = threading.RLock() + self.thread_local = threading.local() + + # Ensure output directory exists + self.output_dir.mkdir(parents=True, exist_ok=True) + # Profiler variables self.bias = 0 # calibration constant - self.timings = {} - self.cur = None - self.start_time = None + self.timings: dict[Any, Any] = {} + self.cur: Any = None + self.start_time: float | None = None self.timer = time.process_time_ns self.total_tt = 0 self.simulate_call("profiler") assert "test_framework" in self.config, "Please specify 'test-framework' in pyproject.toml config file" self.t = self.timer() + self.main_db_created = False # Flag to track main DB creation + + def get_thread_db_path(self) -> Path: + """Get the database path for the current thread.""" + thread_id = threading.get_ident() + if thread_id not in self.thread_db_files: + # Create a unique filename for this thread + unique_id = uuid.uuid4().hex[:8] + db_path = self.output_dir / f"{self.output_base}_{thread_id}_{unique_id}{self.output_ext}" + self.thread_db_files[thread_id] = db_path + return self.thread_db_files[thread_id] + + def get_connection(self) -> sqlite3.Connection: + """Get a dedicated connection for the current thread.""" + if not hasattr(self.thread_local, "con"): + db_path = self.get_thread_db_path() + self.thread_local.con = sqlite3.connect(db_path) + # Create the necessary tables if they don't exist + cur = self.thread_local.con.cursor() + cur.execute("""PRAGMA synchronous = OFF""") + cur.execute( + "CREATE TABLE IF NOT EXISTS function_calls(type TEXT, function TEXT, classname TEXT, " + "filename TEXT, line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)" + ) + self.thread_local.con.commit() + return self.thread_local.con + + def _create_main_db_if_not_exists(self) -> None: + """Create the main output database if it doesn't exist.""" + if not self.main_db_created: # Use flag to prevent redundant checks + if not self.output_file.exists(): + try: + main_con = sqlite3.connect(self.output_file) + main_cur = main_con.cursor() + main_cur.execute("""PRAGMA synchronous = OFF""") # Added pragma for main db too + 
main_cur.execute( + "CREATE TABLE IF NOT EXISTS function_calls(type TEXT, function TEXT, classname TEXT, " + "filename TEXT, line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)" + ) + main_con.commit() + main_con.close() + self.main_db_created = True # Set flag after successful creation + except Exception as e: # noqa: BLE001 + console.print(f"Error creating main database: {e}") + else: + self.main_db_created = True # Main DB already exists def __enter__(self) -> None: if self.disable: return - if getattr(Tracer, "used_once", False): + if Tracer.used_once: console.print( "Codeflash: Tracer can only be used once per program run. " "Please only enable the Tracer once. Skipping tracing this section." ) self.disable = True return - Tracer.used_once = True + Tracer.used_once = True # Mark Tracer as used at the start of __enter__ - if pathlib.Path(self.output_file).exists(): + # Clean up any existing trace files + if self.output_file.exists(): console.print("Codeflash: Removing existing trace file") - pathlib.Path(self.output_file).unlink(missing_ok=True) - - self.con = sqlite3.connect(self.output_file) - cur = self.con.cursor() - cur.execute("""PRAGMA synchronous = OFF""") - # TODO: Check out if we need to export the function test name as well - cur.execute( - "CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, " - "line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)" - ) + cleanup_paths([self.output_file]) + + self._create_main_db_if_not_exists() + self.con = sqlite3.connect(self.output_file) # Keep connection open during tracing console.print("Codeflash: Tracing started!") - frame = sys._getframe(0) # Get this frame and simulate a call to it + console.rule("Program Output Begin", style="bold blue") + frame = sys._getframe(0) # Get this frame and simulate a call to it # noqa: SLF001 self.dispatch["call"](self, frame, 0) self.start_time = time.time() sys.setprofile(self.trace_callback) threading.setprofile(self.trace_callback) - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def _close_thread_connection(self) -> None: + """Close thread-local connection and handle potential errors.""" + if hasattr(self.thread_local, "con") and self.thread_local.con: + try: + self.thread_local.con.commit() + self.thread_local.con.close() + del self.thread_local.con + except Exception as e: # noqa: BLE001 + console.print(f"Error closing current thread's connection: {e}") + + def _merge_thread_dbs(self) -> int: + total_rows_copied = 0 + processed_files: list[Path] = [] + + for thread_id, db_path in self.thread_db_files.items(): + if not db_path.exists(): + console.print(f"Thread database for thread {thread_id} not found, skipping.") + continue + + rows_copied = self._process_thread_db(thread_id, db_path) + if rows_copied >= 0: # _process_thread_db returns -1 on failure + total_rows_copied += rows_copied + processed_files.append(db_path) + else: + console.print(f"Failed to merge from thread database {thread_id}") + + for thread_id, db_path in self.thread_db_files.items(): + if db_path in processed_files or not db_path.exists(): + continue + + rows_copied = self._process_thread_db_with_copy(thread_id, db_path) + if rows_copied >= 0: + total_rows_copied += rows_copied + processed_files.append(db_path) + else: + console.print(f"Failed to merge from thread database {thread_id} even with copy approach.") + + return total_rows_copied + + def _process_thread_db(self, thread_id: int, db_path: Path) -> int: + try: + 
thread_con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True, timeout=2.0) + thread_cur = thread_con.cursor() + + thread_cur.execute("SELECT * FROM function_calls") + main_cur = self.con.cursor() + + self.con.execute("BEGIN TRANSACTION") + + batch_size = 100 + batch = thread_cur.fetchmany(batch_size) + rows_processed = 0 + + while batch: + for row in batch: + try: + main_cur.execute("INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", row) + rows_processed += 1 + except sqlite3.Error as e: # noqa: PERF203 + console.print(f"Error inserting row {rows_processed} from thread {thread_id}: {e}") + batch = thread_cur.fetchmany(batch_size) + + self.con.commit() + thread_con.close() + + except sqlite3.Error as e: + console.print(f"Could not open thread database {thread_id} directly: {e}") + return -1 + else: + return rows_processed + + def _process_thread_db_with_copy(self, thread_id: int, db_path: Path) -> int: + console.print(f"Attempting file copy approach for thread {thread_id}...") + + temp_dir = tempfile.gettempdir() + temp_db_path = Path(temp_dir) / f"codeflash_temp_{uuid.uuid4().hex}.trace" + rows_processed = 0 + + try: + shutil.copy2(db_path, temp_db_path) + + temp_con = sqlite3.connect(temp_db_path) + temp_cur = temp_con.cursor() + + temp_cur.execute("SELECT COUNT(*) FROM function_calls") + row_count = temp_cur.fetchone()[0] + + if row_count > 0: + temp_cur.execute("SELECT * FROM function_calls") + main_cur = self.con.cursor() + + self.con.execute("BEGIN TRANSACTION") + batch_size = 100 + batch = temp_cur.fetchmany(batch_size) + + while batch: + for row in batch: + try: + main_cur.execute("INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", row) + rows_processed += 1 + except sqlite3.Error as e: + console.print(f"Error inserting row from thread {thread_id} copy: {e}") + + batch = temp_cur.fetchmany(batch_size) + + self.con.commit() + + temp_con.close() + cleanup_paths([temp_db_path]) + console.print(f"Successfully merged {rows_processed} rows from thread {thread_id} (via copy)") + except Exception as e: # noqa: BLE001 + console.print(f"Error with file copy approach for thread {thread_id}: {e}") + cleanup_paths([temp_db_path]) + return -1 + + else: + return rows_processed + + def _generate_stats_and_replay_test(self) -> None: + """Generate statistics, pstats compatible data, print stats and create replay test.""" + try: + self.create_stats() + + try: + main_cur = self.con.cursor() + main_cur.execute( + "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, " + "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, " + "cumulative_time_ns INTEGER, callers BLOB)" + ) + + for func, (cc, nc, tt, ct, callers) in self.stats.items(): + remapped_callers = [{"key": k, "value": v} for k, v in callers.items()] + main_cur.execute( + "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + str(Path(func[0]).resolve()), + func[1], + func[2], + func[3], + cc, + nc, + tt, + ct, + json.dumps(remapped_callers), + ), + ) + + self.con.commit() # Use main DB connection + + self.make_pstats_compatible() + self.print_stats("tottime") + + main_cur.execute("CREATE TABLE total_time (time_ns INTEGER)") + main_cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,)) + self.con.commit() # Use main DB connection + + except Exception as e: # noqa: BLE001 + console.print(f"Error generating stats tables: {e}") + import traceback + + traceback.print_exc() + + except Exception as e: # noqa: BLE001 + console.print(f"Error during 
stats generation: {e}") + console.print_exception() + + # Generate the replay test + try: + replay_test = create_trace_replay_test( + trace_file=self.output_file, + functions=self.function_modules, + test_framework=self.config["test_framework"], + max_run_count=self.max_function_count, + ) + function_path = "_".join(self.functions) if self.functions else self.file_being_called_from + test_file_path = get_test_file_path( + test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay" + ) + replay_test = isort.code(replay_test) + + with test_file_path.open("w", encoding="utf8") as file: + file.write(replay_test) + + console.print( + f"Codeflash: Traced {self.trace_count} function calls successfully and replay test created at - {test_file_path}", + crop=False, + soft_wrap=False, + overflow="ignore", + ) + except Exception as e: # noqa: BLE001 + console.print(f"Error creating replay test: {e}") + console.print_exception() + + def __exit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> None: if self.disable: return + console.rule("Program Output End", style="bold blue") sys.setprofile(None) threading.setprofile(None) - self.con.commit() - - self.create_stats() - - cur = self.con.cursor() - cur.execute( - "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, " - "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, " - "cumulative_time_ns INTEGER, callers BLOB)" - ) - for func, (cc, nc, tt, ct, callers) in self.stats.items(): - remapped_callers = [{"key": k, "value": v} for k, v in callers.items()] - cur.execute( - "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", - (str(Path(func[0]).resolve()), func[1], func[2], func[3], cc, nc, tt, ct, json.dumps(remapped_callers)), - ) - self.con.commit() - self.make_pstats_compatible() - self.print_stats("tottime") - cur = self.con.cursor() - cur.execute("CREATE TABLE total_time (time_ns INTEGER)") - cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,)) - self.con.commit() - self.con.close() + self._close_thread_connection() + + # Give threads time to complete their database operations + time.sleep(1) + + self._merge_thread_dbs() + self._generate_stats_and_replay_test() - # filter any functions where we did not capture the return + all_db_paths = list(self.thread_db_files.values()) + cleanup_paths(all_db_paths) + + # Filter any functions where we did not capture the return - moved to replay test generation for clarity self.function_modules = [ function for function in self.function_modules @@ -190,26 +447,9 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: > 0 ] - replay_test = create_trace_replay_test( - trace_file=self.output_file, - functions=self.function_modules, - test_framework=self.config["test_framework"], - max_run_count=self.max_function_count, - ) - function_path = "_".join(self.functions) if self.functions else self.file_being_called_from - test_file_path = get_test_file_path( - test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay" - ) - replay_test = isort.code(replay_test) - with open(test_file_path, "w", encoding="utf8") as file: - file.write(replay_test) - - console.print( - f"Codeflash: Traced {self.trace_count} function calls successfully and replay test created at - {test_file_path}", - crop=False, - soft_wrap=False, - overflow="ignore", - ) + if self.con: + self.con.close() + self.con = None def 
tracer_logic(self, frame: FrameType, event: str) -> None: if event != "call": @@ -239,9 +479,10 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: class_name = arguments["self"].__class__.__name__ elif "cls" in arguments and hasattr(arguments["cls"], "__name__"): class_name = arguments["cls"].__name__ - except: + except: # noqa: E722 # someone can override the getattr method and raise an exception. I'm looking at you wrapt return + function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}" if function_qualified_name in self.ignored_qualified_functions: return @@ -273,9 +514,9 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: self.ignored_qualified_functions.add(function_qualified_name) return - # TODO: Also check if this function arguments are unique from the values logged earlier - - cur = self.con.cursor() + # Get thread-specific connection + conn = self.get_connection() + cur = conn.cursor() t_ns = time.perf_counter_ns() original_recursion_limit = sys.getrecursionlimit() @@ -300,26 +541,34 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # give up self.function_count[function_qualified_name] -= 1 return - cur.execute( - "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", - ( - event, - code.co_name, - class_name, - str(file_name), - frame.f_lineno, - frame.f_back.__hash__(), - t_ns, - local_vars, - ), - ) - self.trace_count += 1 - self.next_insert -= 1 - if self.next_insert == 0: - self.next_insert = 1000 - self.con.commit() + try: + cur.execute( + "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", + ( + event, + code.co_name, + class_name, + str(file_name), + frame.f_lineno, + frame.f_back.__hash__(), + t_ns, + local_vars, + ), + ) - def trace_callback(self, frame: FrameType, event: str, arg: Any) -> None: + # Add thread-safe counter increment for trace_count + with self.db_lock: + self.trace_count += 1 + + self.next_insert -= 1 + if self.next_insert == 0: + self.next_insert = 1000 + conn.commit() + except sqlite3.Error as e: + thread_id = threading.get_ident() + console.print(f"SQLite error in tracer (thread {thread_id}): {e}") + + def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None: # profiler section timer = self.timer t = timer() - self.t - self.bias @@ -335,45 +584,60 @@ def trace_callback(self, frame: FrameType, event: str, arg: Any) -> None: else: self.t = timer() - t # put back unrecorded delta - def trace_dispatch_call(self, frame, t) -> int: - if self.cur and frame.f_back is not self.cur[-2]: - rpt, rit, ret, rfn, rframe, rcur = self.cur - if not isinstance(rframe, Tracer.fake_frame): - assert rframe.f_back is frame.f_back, ("Bad call", rfn, rframe, rframe.f_back, frame, frame.f_back) - self.trace_dispatch_return(rframe, 0) - assert self.cur is None or frame.f_back is self.cur[-2], ("Bad call", self.cur[-3]) - fcode = frame.f_code - arguments = frame.f_locals - class_name = None + def trace_dispatch_call(self, frame: FrameType, t: int) -> int: + """Handle call events in the profiler.""" try: - if ( - "self" in arguments - and hasattr(arguments["self"], "__class__") - and hasattr(arguments["self"].__class__, "__name__") - ): - class_name = arguments["self"].__class__.__name__ - elif "cls" in arguments and hasattr(arguments["cls"], "__name__"): - class_name = arguments["cls"].__name__ - except: - pass - fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name) - self.cur = (t, 0, 0, fn, frame, self.cur) - timings = self.timings - if 
fn in timings: - cc, ns, tt, ct, callers = timings[fn] - timings[fn] = cc, ns + 1, tt, ct, callers - else: - timings[fn] = 0, 0, 0, 0, {} - return 1 + # In multi-threaded contexts, we need to be more careful about frame comparisons + if self.cur and frame.f_back is not self.cur[-2]: + # This happens when we're in a different thread + rpt, rit, ret, rfn, rframe, rcur = self.cur + + # Only attempt to handle the frame mismatch if we have a valid rframe + if ( + not isinstance(rframe, Tracer.fake_frame) + and hasattr(rframe, "f_back") + and hasattr(frame, "f_back") + and rframe.f_back is frame.f_back + ): + self.trace_dispatch_return(rframe, 0) + + # Get function information + fcode = frame.f_code + arguments = frame.f_locals + class_name = None + try: + if ( + "self" in arguments + and hasattr(arguments["self"], "__class__") + and hasattr(arguments["self"].__class__, "__name__") + ): + class_name = arguments["self"].__class__.__name__ + elif "cls" in arguments and hasattr(arguments["cls"], "__name__"): + class_name = arguments["cls"].__name__ + except Exception: # noqa: BLE001, S110 + pass + + fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name) + self.cur = (t, 0, 0, fn, frame, self.cur) + timings = self.timings + if fn in timings: + cc, ns, tt, ct, callers = timings[fn] + timings[fn] = cc, ns + 1, tt, ct, callers + else: + timings[fn] = 0, 0, 0, 0, {} + return 1 # noqa: TRY300 + except Exception: # noqa: BLE001 + # Handle any errors gracefully + return 0 - def trace_dispatch_exception(self, frame, t): + def trace_dispatch_exception(self, frame: FrameType, t: int) -> int: rpt, rit, ret, rfn, rframe, rcur = self.cur if (rframe is not frame) and rcur: return self.trace_dispatch_return(rframe, t) self.cur = rpt, rit + t, ret, rfn, rframe, rcur return 1 - def trace_dispatch_c_call(self, frame, t) -> int: + def trace_dispatch_c_call(self, frame: FrameType, t: int) -> int: fn = ("", 0, self.c_func_name, None) self.cur = (t, 0, 0, fn, frame, self.cur) timings = self.timings @@ -384,44 +648,57 @@ def trace_dispatch_c_call(self, frame, t) -> int: timings[fn] = 0, 0, 0, 0, {} return 1 - def trace_dispatch_return(self, frame, t) -> int: - if frame is not self.cur[-2]: - assert frame is self.cur[-2].f_back, ("Bad return", self.cur[-3]) - self.trace_dispatch_return(self.cur[-2], 0) - - # Prefix "r" means part of the Returning or exiting frame. - # Prefix "p" means part of the Previous or Parent or older frame. - - rpt, rit, ret, rfn, frame, rcur = self.cur - rit = rit + t - frame_total = rit + ret - - ppt, pit, pet, pfn, pframe, pcur = rcur - self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur - - timings = self.timings - cc, ns, tt, ct, callers = timings[rfn] - if not ns: - # This is the only occurrence of the function on the stack. - # Else this is a (directly or indirectly) recursive call, and - # its cumulative time will get updated when the topmost call to - # it returns. - ct = ct + frame_total - cc = cc + 1 - - if pfn in callers: - callers[pfn] = callers[pfn] + 1 # hack: gather more - # stats such as the amount of time added to ct courtesy - # of this specific call, and the contribution to cc - # courtesy of this call. 
- else: - callers[pfn] = 1 + def trace_dispatch_return(self, frame: FrameType, t: int) -> int: + """Handle return events in the profiler.""" + try: + # Check if we have a valid current frame + if not self.cur or not self.cur[-2]: + return 0 + + # In multi-threaded environments, frames can get mismatched + if frame is not self.cur[-2]: + # Don't assert in threaded environments - frames can legitimately differ + if hasattr(frame, "f_back") and hasattr(self.cur[-2], "f_back") and frame is self.cur[-2].f_back: + self.trace_dispatch_return(self.cur[-2], 0) + else: + # We're in a different thread or context, can't continue with this frame + return 0 + + rpt, rit, ret, rfn, frame, rcur = self.cur + rit = rit + t + frame_total = rit + ret + + # Guard against invalid rcur (w threading) + if not rcur: + return 0 + + ppt, pit, pet, pfn, pframe, pcur = rcur + self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur + + timings = self.timings + if rfn not in timings: + # w threading, rfn can be missing + timings[rfn] = 0, 0, 0, 0, {} + + cc, ns, tt, ct, callers = timings[rfn] + if not ns: + # This is the only occurrence of the function on the stack. + ct = ct + frame_total + cc = cc + 1 + + if pfn in callers: + callers[pfn] = callers[pfn] + 1 + else: + callers[pfn] = 1 - timings[rfn] = cc, ns - 1, tt + rit, ct, callers + timings[rfn] = cc, ns - 1, tt + rit, ct, callers - return 1 + return 1 + except Exception: + # Handle errors gracefully + return 0 - dispatch: ClassVar[dict[str, callable]] = { + dispatch: ClassVar[dict[str, Callable[[Tracer, FrameType, int], int]]] = { "call": trace_dispatch_call, "exception": trace_dispatch_exception, "return": trace_dispatch_return, @@ -430,26 +707,10 @@ def trace_dispatch_return(self, frame, t) -> int: "c_return": trace_dispatch_return, } - class fake_code: - def __init__(self, filename, line, name) -> None: - self.co_filename = filename - self.co_line = line - self.co_name = name - self.co_firstlineno = 0 - - def __repr__(self) -> str: - return repr((self.co_filename, self.co_line, self.co_name, None)) - - class fake_frame: - def __init__(self, code, prior) -> None: - self.f_code = code - self.f_back = prior - self.f_locals = {} - def simulate_call(self, name) -> None: - code = self.fake_code("profiler", 0, name) + code = fake_code("profiler", 0, name) pframe = self.cur[-2] if self.cur else None - frame = self.fake_frame(code, pframe) + frame = fake_frame(code, pframe) self.dispatch["call"](self, frame, 0) def simulate_cmd_complete(self) -> None: @@ -462,58 +723,172 @@ def simulate_cmd_complete(self) -> None: t = 0 self.t = get_time() - t - def print_stats(self, sort=-1) -> None: - import pstats + def print_stats(self, sort: str | int | tuple = -1) -> None: + if not self.stats: + console.print("Codeflash: No stats available to print") + self.total_tt = 0 + return if not isinstance(sort, tuple): sort = (sort,) - # The following code customizes the default printing behavior to - # print in milliseconds. 
- s = StringIO() - stats_obj = pstats.Stats(copy(self), stream=s) - stats_obj.strip_dirs().sort_stats(*sort).print_stats(25) - self.total_tt = stats_obj.total_tt - console.print("total_tt", self.total_tt) - raw_stats = s.getvalue() - m = re.search(r"function calls?.*in (\d+)\.\d+ (seconds?)", raw_stats) - total_time = None - if m: - total_time = int(m.group(1)) - if total_time is None: - console.print("Failed to get total time from stats") - total_time_ms = total_time / 1e6 - raw_stats = re.sub( - r"(function calls?.*)in (\d+)\.\d+ (seconds?)", rf"\1 in {total_time_ms:.3f} milliseconds", raw_stats - ) - match_pattern = r"^ *[\d\/]+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +" - m = re.findall(match_pattern, raw_stats, re.MULTILINE) - ms_times = [] - for tottime, percall, cumtime, percall_cum in m: - tottime_ms = int(tottime) / 1e6 - percall_ms = int(percall) / 1e6 - cumtime_ms = int(cumtime) / 1e6 - percall_cum_ms = int(percall_cum) / 1e6 - ms_times.append([tottime_ms, percall_ms, cumtime_ms, percall_cum_ms]) - split_stats = raw_stats.split("\n") - new_stats = [] - - replace_pattern = r"^( *[\d\/]+) +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(.*)" - times_index = 0 - for line in split_stats: - if times_index >= len(ms_times): - replaced = line - else: - replaced, n = re.subn( - replace_pattern, - rf"\g<1>{ms_times[times_index][0]:8.3f} {ms_times[times_index][1]:8.3f} {ms_times[times_index][2]:8.3f} {ms_times[times_index][3]:8.3f} \g<6>", - line, - count=1, + + # First, convert stats to make them pstats-compatible + try: + # Initialize empty collections for pstats + self.files = [] + self.top_level = [] + + # Create entirely new dictionaries instead of modifying existing ones + new_stats = {} + new_timings = {} + + # Convert stats dictionary + stats_items = list(self.stats.items()) + for func, stats_data in stats_items: + try: + # Make sure we have 5 elements in stats_data + if len(stats_data) != 5: + console.print(f"Skipping malformed stats data for {func}: {stats_data}") + continue + + cc, nc, tt, ct, callers = stats_data + + if len(func) == 4: + file_name, line_num, func_name, class_name = func + new_func_name = f"{class_name}.{func_name}" if class_name else func_name + new_func = (file_name, line_num, new_func_name) + else: + new_func = func # Keep as is if already in correct format + + new_callers = {} + callers_items = list(callers.items()) + for caller_func, count in callers_items: + if isinstance(caller_func, tuple): + if len(caller_func) == 4: + caller_file, caller_line, caller_name, caller_class = caller_func + caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name + new_caller_func = (caller_file, caller_line, caller_new_name) + else: + new_caller_func = caller_func + else: + console.print(f"Unexpected caller format: {caller_func}") + new_caller_func = str(caller_func) + + new_callers[new_caller_func] = count + + # Store with new format + new_stats[new_func] = (cc, nc, tt, ct, new_callers) + except Exception as e: # noqa: BLE001 + console.print(f"Error converting stats for {func}: {e}") + continue + + timings_items = list(self.timings.items()) + for func, timing_data in timings_items: + try: + if len(timing_data) != 5: + console.print(f"Skipping malformed timing data for {func}: {timing_data}") + continue + + cc, ns, tt, ct, callers = timing_data + + if len(func) == 4: + file_name, line_num, func_name, class_name = func + new_func_name = f"{class_name}.{func_name}" if class_name else func_name + new_func = (file_name, line_num, 
new_func_name) + else: + new_func = func + + new_callers = {} + callers_items = list(callers.items()) + for caller_func, count in callers_items: + if isinstance(caller_func, tuple): + if len(caller_func) == 4: + caller_file, caller_line, caller_name, caller_class = caller_func + caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name + new_caller_func = (caller_file, caller_line, caller_new_name) + else: + new_caller_func = caller_func + else: + console.print(f"Unexpected caller format: {caller_func}") + new_caller_func = str(caller_func) + + new_callers[new_caller_func] = count + + new_timings[new_func] = (cc, ns, tt, ct, new_callers) + except Exception as e: # noqa: BLE001 + console.print(f"Error converting timings for {func}: {e}") + continue + + self.stats = new_stats + self.timings = new_timings + + self.total_tt = sum(tt for _, _, tt, _, _ in self.stats.values()) + + total_calls = sum(cc for cc, _, _, _, _ in self.stats.values()) + total_primitive = sum(nc for _, nc, _, _, _ in self.stats.values()) + + summary = Text.assemble( + f"{total_calls:,} function calls ", + ("(" + f"{total_primitive:,} primitive calls" + ")", "dim"), + f" in {self.total_tt / 1e6:.3f}milliseconds", + ) + + console.print(Align.center(Panel(summary, border_style="blue", width=80, padding=(0, 2), expand=False))) + + table = Table( + show_header=True, + header_style="bold magenta", + border_style="blue", + title="[bold]Function Profile[/bold] (ordered by internal time)", + title_style="cyan", + caption=f"Showing top 25 of {len(self.stats)} functions", + ) + + table.add_column("Calls", justify="right", style="green", width=10) + table.add_column("Time (ms)", justify="right", style="cyan", width=10) + table.add_column("Per Call", justify="right", style="cyan", width=10) + table.add_column("Cum (ms)", justify="right", style="yellow", width=10) + table.add_column("Cum/Call", justify="right", style="yellow", width=10) + table.add_column("Function", style="blue") + + sorted_stats = sorted( + ((func, stats) for func, stats in self.stats.items() if isinstance(func, tuple) and len(func) == 3), + key=lambda x: x[1][2], # Sort by tt (internal time) + reverse=True, + )[:25] # Limit to top 25 + + # Format and add each row to the table + for func, (cc, nc, tt, ct, _) in sorted_stats: + filename, lineno, funcname = func + + # Format calls - show recursive format if different + calls_str = f"{cc}/{nc}" if cc != nc else f"{cc:,}" + + # Convert to milliseconds + tt_ms = tt / 1e6 + ct_ms = ct / 1e6 + + # Calculate per-call times + per_call = tt_ms / cc if cc > 0 else 0 + cum_per_call = ct_ms / nc if nc > 0 else 0 + base_filename = Path(filename).name + file_link = f"[link=file://{filename}]{base_filename}[/link]" + + table.add_row( + calls_str, + f"{tt_ms:.3f}", + f"{per_call:.3f}", + f"{ct_ms:.3f}", + f"{cum_per_call:.3f}", + f"{funcname} [dim]({file_link}:{lineno})[/dim]", ) - if n > 0: - times_index += 1 - new_stats.append(replaced) - console.print("\n".join(new_stats)) + console.print(Align.center(table)) + + except Exception as e: # noqa: BLE001 + console.print(f"[bold red]Error in stats processing:[/bold red] {e}") + console.print(f"Traced {self.trace_count:,} function calls") + self.total_tt = 0 def make_pstats_compatible(self) -> None: # delete the extra class_name item from the function tuple @@ -530,8 +905,8 @@ def make_pstats_compatible(self) -> None: self.stats = new_stats self.timings = new_timings - def dump_stats(self, file) -> None: - with open(file, "wb") as f: + def dump_stats(self, 
file: str) -> None: + with Path(file).open("wb") as f: self.create_stats() marshal.dump(self.stats, f) @@ -541,25 +916,23 @@ def create_stats(self) -> None: def snapshot_stats(self) -> None: self.stats = {} - for func, (cc, _ns, tt, ct, callers) in self.timings.items(): - callers = callers.copy() + for func, (cc, _ns, tt, ct, caller_dict) in self.timings.items(): + callers = caller_dict.copy() nc = 0 for callcnt in callers.values(): nc += callcnt self.stats[func] = cc, nc, tt, ct, callers - def runctx(self, cmd, globals, locals): + def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, Any]) -> Tracer | None: self.__enter__() try: - exec(cmd, globals, locals) + exec(cmd, global_vars, local_vars) # noqa: S102 finally: self.__exit__(None, None, None) return self -def main(): - from argparse import ArgumentParser - +def main() -> ArgumentParser: parser = ArgumentParser(allow_abbrev=False) parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to ", required=True) parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None) From 53406de3c66c5e9fbbc2fe40d32c7359ae444cc5 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 11 Mar 2025 12:06:36 -0700 Subject: [PATCH 04/13] remove spaghetti code --- codeflash/tracer.py | 559 ++++++++++++++------------------------------ 1 file changed, 169 insertions(+), 390 deletions(-) diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 1bab6f0e6..9e31a5cf3 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -16,18 +16,15 @@ import json import marshal import os +import pathlib import pickle -import shutil import sqlite3 import sys -import tempfile import threading import time -import uuid -from argparse import ArgumentParser from collections import defaultdict from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar import dill import isort @@ -38,7 +35,7 @@ from codeflash.cli_cmds.cli import project_root_from_module_root from codeflash.cli_cmds.console import console -from codeflash.code_utils.code_utils import cleanup_paths, module_name_from_file_path +from codeflash.code_utils.code_utils import module_name_from_file_path from codeflash.code_utils.config_parser import parse_config_file from codeflash.discovery.functions_to_optimize import filter_files_optimized from codeflash.tracing.replay_test import create_trace_replay_test @@ -49,24 +46,6 @@ from types import FrameType, TracebackType -class fake_code: # noqa: N801 - def __init__(self, filename: str, line: int, name: str) -> None: - self.co_filename = filename - self.co_line = line - self.co_name = name - self.co_firstlineno = 0 - - def __repr__(self) -> str: - return repr((self.co_filename, self.co_line, self.co_name, None)) - - -class fake_frame: # noqa: N801 - def __init__(self, code: fake_code, prior: fake_frame | None) -> None: - self.f_code = code - self.f_back = prior - self.f_locals = {} - - # Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger. class Tracer: """Use this class as a 'with' context manager to trace a function call. @@ -74,8 +53,6 @@ class Tracer: Traces function calls, input arguments, and profiling info. 
""" - used_once: ClassVar[bool] = False # Class variable to track if Tracer has been used - def __init__( self, output: str = "codeflash.trace", @@ -85,11 +62,20 @@ def __init__( max_function_count: int = 256, timeout: int | None = None, # seconds ) -> None: - """Initialize Tracer.""" + """Use this class to trace function calls. + + :param output: The path to the output trace file + :param functions: List of functions to trace. If None, trace all functions + :param disable: Disable the tracer if True + :param config_file_path: Path to the pyproject.toml file, if None then it will be auto-discovered + :param max_function_count: Maximum number of times to trace one function + :param timeout: Timeout in seconds for the tracer, if the traced code takes more than this time, then tracing + stops and normal execution continues. If this is None then no timeout applies + """ if functions is None: functions = [] if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1": - console.print("Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE") + console.rule("Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE", style="bold red") disable = True self.disable = disable if self.disable: @@ -101,14 +87,8 @@ def __init__( ) self.disable = True return - - # Setup output paths + self.con = None self.output_file = Path(output).resolve() - self.output_dir = self.output_file.parent - self.output_base = self.output_file.stem - self.output_ext = self.output_file.suffix - self.thread_db_files: dict[int, Path] = {} # Thread ID to DB file path - self.functions = functions self.function_modules: list[FunctionModules] = [] self.function_count = defaultdict(int) @@ -120,7 +100,7 @@ def __init__( self.max_function_count = max_function_count self.config, found_config_path = parse_config_file(config_file_path) self.project_root = project_root_from_module_root(Path(self.config["module_root"]), found_config_path) - console.print("project_root", self.project_root) + console.rule(f"Project Root: {self.project_root}", style="bold blue") self.ignored_functions = {"", "", "", "", "", ""} self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_") # noqa: SLF001 @@ -130,89 +110,43 @@ def __init__( self.next_insert = 1000 self.trace_count = 0 - self.db_lock = threading.RLock() - self.thread_local = threading.local() - - # Ensure output directory exists - self.output_dir.mkdir(parents=True, exist_ok=True) - # Profiler variables self.bias = 0 # calibration constant - self.timings: dict[Any, Any] = {} - self.cur: Any = None - self.start_time: float | None = None + self.timings = {} + self.cur = None + self.start_time = None self.timer = time.process_time_ns self.total_tt = 0 self.simulate_call("profiler") assert "test_framework" in self.config, "Please specify 'test-framework' in pyproject.toml config file" self.t = self.timer() - self.main_db_created = False # Flag to track main DB creation - - def get_thread_db_path(self) -> Path: - """Get the database path for the current thread.""" - thread_id = threading.get_ident() - if thread_id not in self.thread_db_files: - # Create a unique filename for this thread - unique_id = uuid.uuid4().hex[:8] - db_path = self.output_dir / f"{self.output_base}_{thread_id}_{unique_id}{self.output_ext}" - self.thread_db_files[thread_id] = db_path - return self.thread_db_files[thread_id] - - def get_connection(self) -> sqlite3.Connection: - """Get a dedicated connection for the current thread.""" - if not 
hasattr(self.thread_local, "con"): - db_path = self.get_thread_db_path() - self.thread_local.con = sqlite3.connect(db_path) - # Create the necessary tables if they don't exist - cur = self.thread_local.con.cursor() - cur.execute("""PRAGMA synchronous = OFF""") - cur.execute( - "CREATE TABLE IF NOT EXISTS function_calls(type TEXT, function TEXT, classname TEXT, " - "filename TEXT, line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)" - ) - self.thread_local.con.commit() - return self.thread_local.con - - def _create_main_db_if_not_exists(self) -> None: - """Create the main output database if it doesn't exist.""" - if not self.main_db_created: # Use flag to prevent redundant checks - if not self.output_file.exists(): - try: - main_con = sqlite3.connect(self.output_file) - main_cur = main_con.cursor() - main_cur.execute("""PRAGMA synchronous = OFF""") # Added pragma for main db too - main_cur.execute( - "CREATE TABLE IF NOT EXISTS function_calls(type TEXT, function TEXT, classname TEXT, " - "filename TEXT, line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)" - ) - main_con.commit() - main_con.close() - self.main_db_created = True # Set flag after successful creation - except Exception as e: # noqa: BLE001 - console.print(f"Error creating main database: {e}") - else: - self.main_db_created = True # Main DB already exists def __enter__(self) -> None: if self.disable: return - if Tracer.used_once: + if getattr(Tracer, "used_once", False): console.print( "Codeflash: Tracer can only be used once per program run. " "Please only enable the Tracer once. Skipping tracing this section." ) self.disable = True return - Tracer.used_once = True # Mark Tracer as used at the start of __enter__ - - # Clean up any existing trace files - if self.output_file.exists(): - console.print("Codeflash: Removing existing trace file") - cleanup_paths([self.output_file]) - - self._create_main_db_if_not_exists() - self.con = sqlite3.connect(self.output_file) # Keep connection open during tracing - console.print("Codeflash: Tracing started!") + Tracer.used_once = True + + if pathlib.Path(self.output_file).exists(): + console.rule("Removing existing trace file", style="bold red") + console.rule() + pathlib.Path(self.output_file).unlink(missing_ok=True) + + self.con = sqlite3.connect(self.output_file, check_same_thread=False) + cur = self.con.cursor() + cur.execute("""PRAGMA synchronous = OFF""") + cur.execute("""PRAGMA journal_mode = WAL""") + # TODO: Check out if we need to export the function test name as well + cur.execute( + "CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, " + "line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)" + ) console.rule("Program Output Begin", style="bold blue") frame = sys._getframe(0) # Get this frame and simulate a call to it # noqa: SLF001 self.dispatch["call"](self, frame, 0) @@ -220,221 +154,39 @@ def __enter__(self) -> None: sys.setprofile(self.trace_callback) threading.setprofile(self.trace_callback) - def _close_thread_connection(self) -> None: - """Close thread-local connection and handle potential errors.""" - if hasattr(self.thread_local, "con") and self.thread_local.con: - try: - self.thread_local.con.commit() - self.thread_local.con.close() - del self.thread_local.con - except Exception as e: # noqa: BLE001 - console.print(f"Error closing current thread's connection: {e}") - - def _merge_thread_dbs(self) -> int: - total_rows_copied = 0 - processed_files: list[Path] = [] - 
- for thread_id, db_path in self.thread_db_files.items(): - if not db_path.exists(): - console.print(f"Thread database for thread {thread_id} not found, skipping.") - continue - - rows_copied = self._process_thread_db(thread_id, db_path) - if rows_copied >= 0: # _process_thread_db returns -1 on failure - total_rows_copied += rows_copied - processed_files.append(db_path) - else: - console.print(f"Failed to merge from thread database {thread_id}") - - for thread_id, db_path in self.thread_db_files.items(): - if db_path in processed_files or not db_path.exists(): - continue - - rows_copied = self._process_thread_db_with_copy(thread_id, db_path) - if rows_copied >= 0: - total_rows_copied += rows_copied - processed_files.append(db_path) - else: - console.print(f"Failed to merge from thread database {thread_id} even with copy approach.") - - return total_rows_copied - - def _process_thread_db(self, thread_id: int, db_path: Path) -> int: - try: - thread_con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True, timeout=2.0) - thread_cur = thread_con.cursor() - - thread_cur.execute("SELECT * FROM function_calls") - main_cur = self.con.cursor() - - self.con.execute("BEGIN TRANSACTION") - - batch_size = 100 - batch = thread_cur.fetchmany(batch_size) - rows_processed = 0 - - while batch: - for row in batch: - try: - main_cur.execute("INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", row) - rows_processed += 1 - except sqlite3.Error as e: # noqa: PERF203 - console.print(f"Error inserting row {rows_processed} from thread {thread_id}: {e}") - batch = thread_cur.fetchmany(batch_size) - - self.con.commit() - thread_con.close() - - except sqlite3.Error as e: - console.print(f"Could not open thread database {thread_id} directly: {e}") - return -1 - else: - return rows_processed - - def _process_thread_db_with_copy(self, thread_id: int, db_path: Path) -> int: - console.print(f"Attempting file copy approach for thread {thread_id}...") - - temp_dir = tempfile.gettempdir() - temp_db_path = Path(temp_dir) / f"codeflash_temp_{uuid.uuid4().hex}.trace" - rows_processed = 0 - - try: - shutil.copy2(db_path, temp_db_path) - - temp_con = sqlite3.connect(temp_db_path) - temp_cur = temp_con.cursor() - - temp_cur.execute("SELECT COUNT(*) FROM function_calls") - row_count = temp_cur.fetchone()[0] - - if row_count > 0: - temp_cur.execute("SELECT * FROM function_calls") - main_cur = self.con.cursor() - - self.con.execute("BEGIN TRANSACTION") - batch_size = 100 - batch = temp_cur.fetchmany(batch_size) - - while batch: - for row in batch: - try: - main_cur.execute("INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", row) - rows_processed += 1 - except sqlite3.Error as e: - console.print(f"Error inserting row from thread {thread_id} copy: {e}") - - batch = temp_cur.fetchmany(batch_size) - - self.con.commit() - - temp_con.close() - cleanup_paths([temp_db_path]) - console.print(f"Successfully merged {rows_processed} rows from thread {thread_id} (via copy)") - except Exception as e: # noqa: BLE001 - console.print(f"Error with file copy approach for thread {thread_id}: {e}") - cleanup_paths([temp_db_path]) - return -1 - - else: - return rows_processed - - def _generate_stats_and_replay_test(self) -> None: - """Generate statistics, pstats compatible data, print stats and create replay test.""" - try: - self.create_stats() - - try: - main_cur = self.con.cursor() - main_cur.execute( - "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, " - "call_count_nonrecursive INTEGER, 
num_callers INTEGER, total_time_ns INTEGER, " - "cumulative_time_ns INTEGER, callers BLOB)" - ) - - for func, (cc, nc, tt, ct, callers) in self.stats.items(): - remapped_callers = [{"key": k, "value": v} for k, v in callers.items()] - main_cur.execute( - "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", - ( - str(Path(func[0]).resolve()), - func[1], - func[2], - func[3], - cc, - nc, - tt, - ct, - json.dumps(remapped_callers), - ), - ) - - self.con.commit() # Use main DB connection - - self.make_pstats_compatible() - self.print_stats("tottime") - - main_cur.execute("CREATE TABLE total_time (time_ns INTEGER)") - main_cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,)) - self.con.commit() # Use main DB connection - - except Exception as e: # noqa: BLE001 - console.print(f"Error generating stats tables: {e}") - import traceback - - traceback.print_exc() - - except Exception as e: # noqa: BLE001 - console.print(f"Error during stats generation: {e}") - console.print_exception() - - # Generate the replay test - try: - replay_test = create_trace_replay_test( - trace_file=self.output_file, - functions=self.function_modules, - test_framework=self.config["test_framework"], - max_run_count=self.max_function_count, - ) - function_path = "_".join(self.functions) if self.functions else self.file_being_called_from - test_file_path = get_test_file_path( - test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay" - ) - replay_test = isort.code(replay_test) - - with test_file_path.open("w", encoding="utf8") as file: - file.write(replay_test) - - console.print( - f"Codeflash: Traced {self.trace_count} function calls successfully and replay test created at - {test_file_path}", - crop=False, - soft_wrap=False, - overflow="ignore", - ) - except Exception as e: # noqa: BLE001 - console.print(f"Error creating replay test: {e}") - console.print_exception() - def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: if self.disable: return - console.rule("Program Output End", style="bold blue") sys.setprofile(None) - threading.setprofile(None) - - self._close_thread_connection() - - # Give threads time to complete their database operations - time.sleep(1) - - self._merge_thread_dbs() - self._generate_stats_and_replay_test() + self.con.commit() + console.rule("Program Output End", style="bold blue") + self.create_stats() + + cur = self.con.cursor() + cur.execute( + "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, " + "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, " + "cumulative_time_ns INTEGER, callers BLOB)" + ) + for func, (cc, nc, tt, ct, callers) in self.stats.items(): + remapped_callers = [{"key": k, "value": v} for k, v in callers.items()] + cur.execute( + "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + (str(Path(func[0]).resolve()), func[1], func[2], func[3], cc, nc, tt, ct, json.dumps(remapped_callers)), + ) + self.con.commit() - all_db_paths = list(self.thread_db_files.values()) - cleanup_paths(all_db_paths) + self.make_pstats_compatible() + self.print_stats("tottime") + cur = self.con.cursor() + cur.execute("CREATE TABLE total_time (time_ns INTEGER)") + cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,)) + self.con.commit() + self.con.close() - # Filter any functions where we did not capture the return - moved to replay test generation for clarity + # filter any functions where we did 
not capture the return self.function_modules = [ function for function in self.function_modules @@ -447,9 +199,26 @@ def __exit__( > 0 ] - if self.con: - self.con.close() - self.con = None + replay_test = create_trace_replay_test( + trace_file=self.output_file, + functions=self.function_modules, + test_framework=self.config["test_framework"], + max_run_count=self.max_function_count, + ) + function_path = "_".join(self.functions) if self.functions else self.file_being_called_from + test_file_path = get_test_file_path( + test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay" + ) + replay_test = isort.code(replay_test) + with open(test_file_path, "w", encoding="utf8") as file: + file.write(replay_test) + + console.print( + f"Codeflash: Traced {self.trace_count} function calls successfully and replay test created at - {test_file_path}", + crop=False, + soft_wrap=False, + overflow="ignore", + ) def tracer_logic(self, frame: FrameType, event: str) -> None: if event != "call": @@ -479,10 +248,9 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: class_name = arguments["self"].__class__.__name__ elif "cls" in arguments and hasattr(arguments["cls"], "__name__"): class_name = arguments["cls"].__name__ - except: # noqa: E722 + except: # someone can override the getattr method and raise an exception. I'm looking at you wrapt return - function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}" if function_qualified_name in self.ignored_qualified_functions: return @@ -514,9 +282,9 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: self.ignored_qualified_functions.add(function_qualified_name) return - # Get thread-specific connection - conn = self.get_connection() - cur = conn.cursor() + # TODO: Also check if this function arguments are unique from the values logged earlier + + cur = self.con.cursor() t_ns = time.perf_counter_ns() original_recursion_limit = sys.getrecursionlimit() @@ -541,32 +309,24 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # give up self.function_count[function_qualified_name] -= 1 return - try: - cur.execute( - "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", - ( - event, - code.co_name, - class_name, - str(file_name), - frame.f_lineno, - frame.f_back.__hash__(), - t_ns, - local_vars, - ), - ) - - # Add thread-safe counter increment for trace_count - with self.db_lock: - self.trace_count += 1 - - self.next_insert -= 1 - if self.next_insert == 0: - self.next_insert = 1000 - conn.commit() - except sqlite3.Error as e: - thread_id = threading.get_ident() - console.print(f"SQLite error in tracer (thread {thread_id}): {e}") + cur.execute( + "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", + ( + event, + code.co_name, + class_name, + str(file_name), + frame.f_lineno, + frame.f_back.__hash__(), + t_ns, + local_vars, + ), + ) + self.trace_count += 1 + self.next_insert -= 1 + if self.next_insert == 0: + self.next_insert = 1000 + self.con.commit() def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None: # profiler section @@ -649,56 +409,58 @@ def trace_dispatch_c_call(self, frame: FrameType, t: int) -> int: return 1 def trace_dispatch_return(self, frame: FrameType, t: int) -> int: - """Handle return events in the profiler.""" - try: - # Check if we have a valid current frame - if not self.cur or not self.cur[-2]: - return 0 + if not self.cur or not self.cur[-2]: + return 0 - # In multi-threaded environments, frames can get 
mismatched - if frame is not self.cur[-2]: - # Don't assert in threaded environments - frames can legitimately differ - if hasattr(frame, "f_back") and hasattr(self.cur[-2], "f_back") and frame is self.cur[-2].f_back: - self.trace_dispatch_return(self.cur[-2], 0) - else: - # We're in a different thread or context, can't continue with this frame - return 0 - - rpt, rit, ret, rfn, frame, rcur = self.cur - rit = rit + t - frame_total = rit + ret - - # Guard against invalid rcur (w threading) - if not rcur: + # In multi-threaded environments, frames can get mismatched + if frame is not self.cur[-2]: + # Don't assert in threaded environments - frames can legitimately differ + if hasattr(frame, "f_back") and hasattr(self.cur[-2], "f_back") and frame is self.cur[-2].f_back: + self.trace_dispatch_return(self.cur[-2], 0) + else: + # We're in a different thread or context, can't continue with this frame return 0 + # Prefix "r" means part of the Returning or exiting frame. + # Prefix "p" means part of the Previous or Parent or older frame. - ppt, pit, pet, pfn, pframe, pcur = rcur - self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur + rpt, rit, ret, rfn, frame, rcur = self.cur - timings = self.timings - if rfn not in timings: - # w threading, rfn can be missing - timings[rfn] = 0, 0, 0, 0, {} - - cc, ns, tt, ct, callers = timings[rfn] - if not ns: - # This is the only occurrence of the function on the stack. - ct = ct + frame_total - cc = cc + 1 - - if pfn in callers: - callers[pfn] = callers[pfn] + 1 - else: - callers[pfn] = 1 + # Guard against invalid rcur (w threading) + if not rcur: + return 0 - timings[rfn] = cc, ns - 1, tt + rit, ct, callers + rit = rit + t + frame_total = rit + ret - return 1 - except Exception: - # Handle errors gracefully - return 0 + ppt, pit, pet, pfn, pframe, pcur = rcur + self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur + + timings = self.timings + if rfn not in timings: + # w threading, rfn can be missing + timings[rfn] = 0, 0, 0, 0, {} + cc, ns, tt, ct, callers = timings[rfn] + if not ns: + # This is the only occurrence of the function on the stack. + # Else this is a (directly or indirectly) recursive call, and + # its cumulative time will get updated when the topmost call to + # it returns. + ct = ct + frame_total + cc = cc + 1 + + if pfn in callers: + callers[pfn] = callers[pfn] + 1 # hack: gather more + # stats such as the amount of time added to ct courtesy + # of this specific call, and the contribution to cc + # courtesy of this call. 
+ else: + callers[pfn] = 1 + + timings[rfn] = cc, ns - 1, tt + rit, ct, callers - dispatch: ClassVar[dict[str, Callable[[Tracer, FrameType, int], int]]] = { + return 1 + + dispatch: ClassVar[dict[str, callable]] = { "call": trace_dispatch_call, "exception": trace_dispatch_exception, "return": trace_dispatch_return, @@ -707,10 +469,26 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int: "c_return": trace_dispatch_return, } + class fake_code: + def __init__(self, filename, line, name) -> None: + self.co_filename = filename + self.co_line = line + self.co_name = name + self.co_firstlineno = 0 + + def __repr__(self) -> str: + return repr((self.co_filename, self.co_line, self.co_name, None)) + + class fake_frame: + def __init__(self, code, prior) -> None: + self.f_code = code + self.f_back = prior + self.f_locals = {} + def simulate_call(self, name) -> None: - code = fake_code("profiler", 0, name) + code = self.fake_code("profiler", 0, name) pframe = self.cur[-2] if self.cur else None - frame = fake_frame(code, pframe) + frame = self.fake_frame(code, pframe) self.dispatch["call"](self, frame, 0) def simulate_cmd_complete(self) -> None: @@ -907,7 +685,6 @@ def make_pstats_compatible(self) -> None: def dump_stats(self, file: str) -> None: with Path(file).open("wb") as f: - self.create_stats() marshal.dump(self.stats, f) def create_stats(self) -> None: @@ -932,7 +709,9 @@ def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, An return self -def main() -> ArgumentParser: +def main(): + from argparse import ArgumentParser + parser = ArgumentParser(allow_abbrev=False) parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to ", required=True) parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None) From e9539c1561ed458bb23e4dde3917a9eda3d2936f Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 11 Mar 2025 12:13:08 -0700 Subject: [PATCH 05/13] ruff format --- codeflash/tracer.py | 60 ++++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 9e31a5cf3..960003f40 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -22,9 +22,10 @@ import sys import threading import time +from argparse import ArgumentParser from collections import defaultdict from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, Callable, ClassVar import dill import isort @@ -46,6 +47,24 @@ from types import FrameType, TracebackType +class FakeCode: + def __init__(self, filename: str, line: int, name: str) -> None: + self.co_filename = filename + self.co_line = line + self.co_name = name + self.co_firstlineno = 0 + + def __repr__(self) -> str: + return repr((self.co_filename, self.co_line, self.co_name, None)) + + +class FakeFrame: + def __init__(self, code: FakeCode, prior: FakeFrame | None) -> None: + self.f_code = code + self.f_back = prior + self.f_locals: dict = {} + + # Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger. class Tracer: """Use this class as a 'with' context manager to trace a function call. 
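
The module-level FakeCode and FakeFrame classes added in the hunk above give the profiler lightweight stand-ins for real code and frame objects: simulate_call (further down in this patch) builds a FakeCode("profiler", 0, name), wraps it in a FakeFrame whose parent is whatever frame currently sits on top of the profiler's record (or None), and pushes it through the same dispatch table as a genuine call event. A minimal sketch of that construction, assuming the patched codeflash.tracer module is importable; the "simulated_entry" label is purely illustrative:

from codeflash.tracer import FakeCode, FakeFrame

code = FakeCode("profiler", 0, "simulated_entry")  # (co_filename, co_line, co_name); name is an arbitrary label
root = FakeFrame(code, None)                       # prior=None, so this frame has no caller
assert root.f_code.co_name == "simulated_entry"
assert root.f_back is None and root.f_locals == {}
# Tracer.simulate_call() builds frames like this and hands them to dispatch["call"],
# so the timing bookkeeping always has a synthetic root entry to attach callers to.

Hoisting these classes to module level (instead of nesting them inside Tracer) is also what lets trace_dispatch_call use a plain isinstance(rframe, FakeFrame) check, as the hunk below does.
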
@@ -75,7 +94,9 @@ def __init__( if functions is None: functions = [] if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1": - console.rule("Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE", style="bold red") + console.rule( + "Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE", style="bold red" + ) disable = True self.disable = disable if self.disable: @@ -210,7 +231,8 @@ def __exit__( test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay" ) replay_test = isort.code(replay_test) - with open(test_file_path, "w", encoding="utf8") as file: + + with Path(test_file_path).open("w", encoding="utf8") as file: file.write(replay_test) console.print( @@ -248,7 +270,7 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: class_name = arguments["self"].__class__.__name__ elif "cls" in arguments and hasattr(arguments["cls"], "__name__"): class_name = arguments["cls"].__name__ - except: + except: # noqa: E722 # someone can override the getattr method and raise an exception. I'm looking at you wrapt return function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}" @@ -354,7 +376,7 @@ def trace_dispatch_call(self, frame: FrameType, t: int) -> int: # Only attempt to handle the frame mismatch if we have a valid rframe if ( - not isinstance(rframe, Tracer.fake_frame) + not isinstance(rframe, FakeFrame) and hasattr(rframe, "f_back") and hasattr(frame, "f_back") and rframe.f_back is frame.f_back @@ -460,7 +482,7 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int: return 1 - dispatch: ClassVar[dict[str, callable]] = { + dispatch: ClassVar[dict[str, Callable[[Tracer, FrameType, int], int]]] = { "call": trace_dispatch_call, "exception": trace_dispatch_exception, "return": trace_dispatch_return, @@ -469,26 +491,10 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int: "c_return": trace_dispatch_return, } - class fake_code: - def __init__(self, filename, line, name) -> None: - self.co_filename = filename - self.co_line = line - self.co_name = name - self.co_firstlineno = 0 - - def __repr__(self) -> str: - return repr((self.co_filename, self.co_line, self.co_name, None)) - - class fake_frame: - def __init__(self, code, prior) -> None: - self.f_code = code - self.f_back = prior - self.f_locals = {} - - def simulate_call(self, name) -> None: - code = self.fake_code("profiler", 0, name) + def simulate_call(self, name: str) -> None: + code = FakeCode("profiler", 0, name) pframe = self.cur[-2] if self.cur else None - frame = self.fake_frame(code, pframe) + frame = FakeFrame(code, pframe) self.dispatch["call"](self, frame, 0) def simulate_cmd_complete(self) -> None: @@ -709,9 +715,7 @@ def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, An return self -def main(): - from argparse import ArgumentParser - +def main() -> ArgumentParser: parser = ArgumentParser(allow_abbrev=False) parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to ", required=True) parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None) From 7f2167ad506b80b3730b1b8e3a098d0f9f3a4236 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 11 Mar 2025 13:09:03 -0700 Subject: [PATCH 06/13] integrate testbench as E2E replay test --- ...d-to-end-test-tracer-replay_testbench.yaml | 41 ++++++++++++++ .../simple_tracer_e2e/testbench.py | 24 +++++++++ .../scripts/end_to_end_test_tracer_replay.py | 4 +- 
...end_to_end_test_tracer_replay_testbench.py | 25 +++++++++ tests/scripts/end_to_end_test_utilities.py | 54 ++++++++++++++++++- 5 files changed, 145 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/end-to-end-test-tracer-replay_testbench.yaml create mode 100644 code_to_optimize/code_directories/simple_tracer_e2e/testbench.py create mode 100644 tests/scripts/end_to_end_test_tracer_replay_testbench.py diff --git a/.github/workflows/end-to-end-test-tracer-replay_testbench.yaml b/.github/workflows/end-to-end-test-tracer-replay_testbench.yaml new file mode 100644 index 000000000..b0148a345 --- /dev/null +++ b/.github/workflows/end-to-end-test-tracer-replay_testbench.yaml @@ -0,0 +1,41 @@ +name: end-to-end-test + +on: + pull_request: + workflow_dispatch: + +jobs: + tracer-replay-testbench: + runs-on: ubuntu-latest + env: + CODEFLASH_AIS_SERVER: prod + POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }} + CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }} + COLUMNS: 110 + MAX_RETRIES: 3 + RETRY_DELAY: 5 + EXPECTED_IMPROVEMENT_PCT: 10 + CODEFLASH_END_TO_END: 1 + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Set up Python 3.11 for CLI + uses: astral-sh/setup-uv@v5 + with: + python-version: 3.11.6 + + - name: Install dependencies (CLI) + run: | + uv tool install poetry + uv venv + source .venv/bin/activate + poetry install --with dev + + - name: Run Codeflash to optimize code + id: optimize_code + run: | + source .venv/bin/activate + poetry run python tests/scripts/end_to_end_test_tracer_replay_testbench.py \ No newline at end of file diff --git a/code_to_optimize/code_directories/simple_tracer_e2e/testbench.py b/code_to_optimize/code_directories/simple_tracer_e2e/testbench.py new file mode 100644 index 000000000..d8a43895e --- /dev/null +++ b/code_to_optimize/code_directories/simple_tracer_e2e/testbench.py @@ -0,0 +1,24 @@ +from concurrent.futures import ThreadPoolExecutor + +def funcA(number): + k = 0 + for i in range(number * 100): + k += i + # Simplify the for loop by using sum with a range object + j = sum(range(number)) + + # Use a generator expression directly in join for more efficiency + return " ".join(str(i) for i in range(number)) + + +def test_threadpool() -> None: + pool = ThreadPoolExecutor(max_workers=3) + args = list(range(10, 31, 10)) + result = pool.map(funcA, args) + + for r in result: + print(r) + + +if __name__ == "__main__": + test_threadpool() diff --git a/tests/scripts/end_to_end_test_tracer_replay.py b/tests/scripts/end_to_end_test_tracer_replay.py index 03c778be9..9185dce21 100644 --- a/tests/scripts/end_to_end_test_tracer_replay.py +++ b/tests/scripts/end_to_end_test_tracer_replay.py @@ -7,10 +7,11 @@ def run_test(expected_improvement_pct: int) -> bool: config = TestConfig( trace_mode=True, + trace_load="workload", min_improvement_x=0.1, expected_unit_tests=1, coverage_expectations=[ - CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[2, 3, 4, 6, 9]) + CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[2, 3, 4, 6, 9]), ], ) cwd = ( @@ -18,6 +19,5 @@ def run_test(expected_improvement_pct: int) -> bool: ).resolve() return run_codeflash_command(cwd, config, expected_improvement_pct) - if __name__ == "__main__": exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10)))) diff --git a/tests/scripts/end_to_end_test_tracer_replay_testbench.py b/tests/scripts/end_to_end_test_tracer_replay_testbench.py new file 
mode 100644 index 000000000..9524d2c44 --- /dev/null +++ b/tests/scripts/end_to_end_test_tracer_replay_testbench.py @@ -0,0 +1,25 @@ +import os +import pathlib + +from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries + + +def run_test(expected_improvement_pct: int) -> bool: + config = TestConfig( + trace_mode=True, + trace_load="testbench", + min_improvement_x=0.1, + expected_unit_tests=1, + coverage_expectations=[ + CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[4, 5, 6, 8, 11]) + ], + ) + cwd = ( + pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "code_directories" / "simple_tracer_e2e" + ).resolve() + return run_codeflash_command(cwd, config, expected_improvement_pct) + + + +if __name__ == "__main__": + exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10)))) diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py index 23a67a84a..43751dc38 100644 --- a/tests/scripts/end_to_end_test_utilities.py +++ b/tests/scripts/end_to_end_test_utilities.py @@ -25,6 +25,7 @@ class TestConfig: expected_unit_tests: Optional[int] = None min_improvement_x: float = 0.1 trace_mode: bool = False + trace_load: str = "workload" coverage_expectations: list[CoverageExpectation] = field(default_factory=list) @@ -80,7 +81,10 @@ def run_codeflash_command( ) -> bool: logging.basicConfig(level=logging.INFO) if config.trace_mode: - return run_trace_test(cwd, config, expected_improvement_pct) + if config.trace_load == "workload": + return run_trace_test(cwd, config, expected_improvement_pct) + if config.trace_load == "testbench": + return run_trace_test2(cwd, config, expected_improvement_pct) path_to_file = cwd / config.file_path file_contents = path_to_file.read_text("utf-8") @@ -228,6 +232,54 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p return validate_output(stdout, return_code, expected_improvement_pct, config) +def run_trace_test2(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool: + # First command: Run the tracer + test_root = cwd / "tests" / (config.test_framework or "") + clear_directory(test_root) + command = ["python", "-m", "codeflash.tracer", "-o", "codeflash.trace", "testbench.py"] + process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy() + ) + + output = [] + for line in process.stdout: + logging.info(line.strip()) + output.append(line) + + return_code = process.wait() + stdout = "".join(output) + + if return_code != 0: + logging.error(f"Tracer command returned exit code {return_code}") + return False + + functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout) + if not functions_traced or int(functions_traced.group(1)) != 5: + logging.error("Expected 5 traced functions") + return False + + replay_test_path = pathlib.Path(functions_traced.group(2)) + if not replay_test_path.exists(): + logging.error(f"Replay test file missing at {replay_test_path}") + return False + + # Second command: Run optimization + command = ["python", "../../../codeflash/main.py", "--replay-test", str(replay_test_path), "--no-pr"] + process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy() + ) + + output = [] + for line in process.stdout: + logging.info(line.strip()) + 
output.append(line) + + return_code = process.wait() + stdout = "".join(output) + + return validate_output(stdout, return_code, expected_improvement_pct, config) + + def run_with_retries(test_func, *args, **kwargs) -> bool: max_retries = int(os.getenv("MAX_RETRIES", 3)) retry_delay = int(os.getenv("RETRY_DELAY", 5)) From a2d4c4bcce99ab79f7d0df8b74cea6d1caef5cf2 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 11 Mar 2025 14:56:47 -0700 Subject: [PATCH 07/13] consolidate logic around tracer --- tests/scripts/end_to_end_test_utilities.py | 59 +++------------------- 1 file changed, 6 insertions(+), 53 deletions(-) diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py index 43751dc38..3d13702f6 100644 --- a/tests/scripts/end_to_end_test_utilities.py +++ b/tests/scripts/end_to_end_test_utilities.py @@ -81,10 +81,7 @@ def run_codeflash_command( ) -> bool: logging.basicConfig(level=logging.INFO) if config.trace_mode: - if config.trace_load == "workload": - return run_trace_test(cwd, config, expected_improvement_pct) - if config.trace_load == "testbench": - return run_trace_test2(cwd, config, expected_improvement_pct) + return run_trace_test(cwd, config, expected_improvement_pct) path_to_file = cwd / config.file_path file_contents = path_to_file.read_text("utf-8") @@ -188,55 +185,11 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p # First command: Run the tracer test_root = cwd / "tests" / (config.test_framework or "") clear_directory(test_root) - command = ["python", "-m", "codeflash.tracer", "-o", "codeflash.trace", "workload.py"] - process = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy() - ) - - output = [] - for line in process.stdout: - logging.info(line.strip()) - output.append(line) - - return_code = process.wait() - stdout = "".join(output) - - if return_code != 0: - logging.error(f"Tracer command returned exit code {return_code}") - return False - - functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout) - if not functions_traced or int(functions_traced.group(1)) != 3: - logging.error("Expected 3 traced functions") - return False - - replay_test_path = pathlib.Path(functions_traced.group(2)) - if not replay_test_path.exists(): - logging.error(f"Replay test file missing at {replay_test_path}") - return False - # Second command: Run optimization - command = ["python", "../../../codeflash/main.py", "--replay-test", str(replay_test_path), "--no-pr"] - process = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy() - ) + trace_script = "workload.py" if config.trace_load == "workload" else "testbench.py" + expected_traced_functions = 3 if config.trace_load == "workload" else 5 - output = [] - for line in process.stdout: - logging.info(line.strip()) - output.append(line) - - return_code = process.wait() - stdout = "".join(output) - - return validate_output(stdout, return_code, expected_improvement_pct, config) - - -def run_trace_test2(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool: - # First command: Run the tracer - test_root = cwd / "tests" / (config.test_framework or "") - clear_directory(test_root) - command = ["python", "-m", "codeflash.tracer", "-o", "codeflash.trace", "testbench.py"] + command = ["python", "-m", "codeflash.tracer", "-o", "codeflash.trace", 
trace_script] process = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy() ) @@ -254,8 +207,8 @@ def run_trace_test2(cwd: pathlib.Path, config: TestConfig, expected_improvement_ return False functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout) - if not functions_traced or int(functions_traced.group(1)) != 5: - logging.error("Expected 5 traced functions") + if not functions_traced or int(functions_traced.group(1)) != expected_traced_functions: + logging.error(f"Expected {expected_traced_functions} traced functions") return False replay_test_path = pathlib.Path(functions_traced.group(2)) From 122f9205ec9c088da3e16a8eb26ac5057513b736 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 11 Mar 2025 15:00:11 -0700 Subject: [PATCH 08/13] workload expects 4, not 5 funcs --- tests/scripts/end_to_end_test_utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py index 3d13702f6..12e4e3eb8 100644 --- a/tests/scripts/end_to_end_test_utilities.py +++ b/tests/scripts/end_to_end_test_utilities.py @@ -187,7 +187,7 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p clear_directory(test_root) trace_script = "workload.py" if config.trace_load == "workload" else "testbench.py" - expected_traced_functions = 3 if config.trace_load == "workload" else 5 + expected_traced_functions = 3 if config.trace_load == "workload" else 4 command = ["python", "-m", "codeflash.tracer", "-o", "codeflash.trace", trace_script] process = subprocess.Popen( From b383d094c48c4f923c0f891791b82d231233c239 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 11 Mar 2025 15:16:19 -0700 Subject: [PATCH 09/13] single tracer test --- .../simple_tracer_e2e/testbench.py | 24 ------------------ .../simple_tracer_e2e/workload.py | 12 +++++++-- ...end_to_end_test_tracer_replay_testbench.py | 25 ------------------- tests/scripts/end_to_end_test_utilities.py | 13 +++------- 4 files changed, 14 insertions(+), 60 deletions(-) delete mode 100644 code_to_optimize/code_directories/simple_tracer_e2e/testbench.py delete mode 100644 tests/scripts/end_to_end_test_tracer_replay_testbench.py diff --git a/code_to_optimize/code_directories/simple_tracer_e2e/testbench.py b/code_to_optimize/code_directories/simple_tracer_e2e/testbench.py deleted file mode 100644 index d8a43895e..000000000 --- a/code_to_optimize/code_directories/simple_tracer_e2e/testbench.py +++ /dev/null @@ -1,24 +0,0 @@ -from concurrent.futures import ThreadPoolExecutor - -def funcA(number): - k = 0 - for i in range(number * 100): - k += i - # Simplify the for loop by using sum with a range object - j = sum(range(number)) - - # Use a generator expression directly in join for more efficiency - return " ".join(str(i) for i in range(number)) - - -def test_threadpool() -> None: - pool = ThreadPoolExecutor(max_workers=3) - args = list(range(10, 31, 10)) - result = pool.map(funcA, args) - - for r in result: - print(r) - - -if __name__ == "__main__": - test_threadpool() diff --git a/code_to_optimize/code_directories/simple_tracer_e2e/workload.py b/code_to_optimize/code_directories/simple_tracer_e2e/workload.py index 053e25904..1fe6af823 100644 --- a/code_to_optimize/code_directories/simple_tracer_e2e/workload.py +++ b/code_to_optimize/code_directories/simple_tracer_e2e/workload.py @@ -1,3 +1,4 @@ +from 
concurrent.futures import ThreadPoolExecutor def funcA(number): k = 0 for i in range(number * 100): @@ -8,7 +9,14 @@ def funcA(number): # Use a generator expression directly in join for more efficiency return " ".join(str(i) for i in range(number)) +def test_threadpool() -> None: + pool = ThreadPoolExecutor(max_workers=3) + args = list(range(10, 31, 10)) + result = pool.map(funcA, args) + + for r in result: + print(r) + if __name__ == "__main__": - for i in range(10, 31, 10): - funcA(10) + test_threadpool() \ No newline at end of file diff --git a/tests/scripts/end_to_end_test_tracer_replay_testbench.py b/tests/scripts/end_to_end_test_tracer_replay_testbench.py deleted file mode 100644 index 9524d2c44..000000000 --- a/tests/scripts/end_to_end_test_tracer_replay_testbench.py +++ /dev/null @@ -1,25 +0,0 @@ -import os -import pathlib - -from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries - - -def run_test(expected_improvement_pct: int) -> bool: - config = TestConfig( - trace_mode=True, - trace_load="testbench", - min_improvement_x=0.1, - expected_unit_tests=1, - coverage_expectations=[ - CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[4, 5, 6, 8, 11]) - ], - ) - cwd = ( - pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "code_directories" / "simple_tracer_e2e" - ).resolve() - return run_codeflash_command(cwd, config, expected_improvement_pct) - - - -if __name__ == "__main__": - exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10)))) diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py index 12e4e3eb8..ec4a25382 100644 --- a/tests/scripts/end_to_end_test_utilities.py +++ b/tests/scripts/end_to_end_test_utilities.py @@ -25,7 +25,6 @@ class TestConfig: expected_unit_tests: Optional[int] = None min_improvement_x: float = 0.1 trace_mode: bool = False - trace_load: str = "workload" coverage_expectations: list[CoverageExpectation] = field(default_factory=list) @@ -185,11 +184,7 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p # First command: Run the tracer test_root = cwd / "tests" / (config.test_framework or "") clear_directory(test_root) - - trace_script = "workload.py" if config.trace_load == "workload" else "testbench.py" - expected_traced_functions = 3 if config.trace_load == "workload" else 4 - - command = ["python", "-m", "codeflash.tracer", "-o", "codeflash.trace", trace_script] + command = ["python", "-m", "codeflash.tracer", "-o", "codeflash.trace", "workload.py"] process = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy() ) @@ -207,8 +202,8 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p return False functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout) - if not functions_traced or int(functions_traced.group(1)) != expected_traced_functions: - logging.error(f"Expected {expected_traced_functions} traced functions") + if not functions_traced or int(functions_traced.group(1)) != 3: + logging.error("Expected 3 traced functions") return False replay_test_path = pathlib.Path(functions_traced.group(2)) @@ -254,4 +249,4 @@ def run_with_retries(test_func, *args, **kwargs) -> bool: logging.error("Test failed after all retries") return 1 - return 1 + return 1 \ No newline at end of file From 
b95822a36607392846d4ba911d9cc3454ce028bf Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 11 Mar 2025 15:18:05 -0700 Subject: [PATCH 10/13] cleanup --- ...d-to-end-test-tracer-replay_testbench.yaml | 41 ------------------- .../scripts/end_to_end_test_tracer_replay.py | 1 - 2 files changed, 42 deletions(-) delete mode 100644 .github/workflows/end-to-end-test-tracer-replay_testbench.yaml diff --git a/.github/workflows/end-to-end-test-tracer-replay_testbench.yaml b/.github/workflows/end-to-end-test-tracer-replay_testbench.yaml deleted file mode 100644 index b0148a345..000000000 --- a/.github/workflows/end-to-end-test-tracer-replay_testbench.yaml +++ /dev/null @@ -1,41 +0,0 @@ -name: end-to-end-test - -on: - pull_request: - workflow_dispatch: - -jobs: - tracer-replay-testbench: - runs-on: ubuntu-latest - env: - CODEFLASH_AIS_SERVER: prod - POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }} - CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }} - COLUMNS: 110 - MAX_RETRIES: 3 - RETRY_DELAY: 5 - EXPECTED_IMPROVEMENT_PCT: 10 - CODEFLASH_END_TO_END: 1 - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - token: ${{ secrets.GITHUB_TOKEN }} - - - name: Set up Python 3.11 for CLI - uses: astral-sh/setup-uv@v5 - with: - python-version: 3.11.6 - - - name: Install dependencies (CLI) - run: | - uv tool install poetry - uv venv - source .venv/bin/activate - poetry install --with dev - - - name: Run Codeflash to optimize code - id: optimize_code - run: | - source .venv/bin/activate - poetry run python tests/scripts/end_to_end_test_tracer_replay_testbench.py \ No newline at end of file diff --git a/tests/scripts/end_to_end_test_tracer_replay.py b/tests/scripts/end_to_end_test_tracer_replay.py index 9185dce21..0bd69f431 100644 --- a/tests/scripts/end_to_end_test_tracer_replay.py +++ b/tests/scripts/end_to_end_test_tracer_replay.py @@ -7,7 +7,6 @@ def run_test(expected_improvement_pct: int) -> bool: config = TestConfig( trace_mode=True, - trace_load="workload", min_improvement_x=0.1, expected_unit_tests=1, coverage_expectations=[ From ab5409fc5ac9e911e074986ecd9eb29665496c65 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 11 Mar 2025 15:30:34 -0700 Subject: [PATCH 11/13] adjust traced expectations --- tests/scripts/end_to_end_test_utilities.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py index ec4a25382..c961b6fd1 100644 --- a/tests/scripts/end_to_end_test_utilities.py +++ b/tests/scripts/end_to_end_test_utilities.py @@ -202,8 +202,8 @@ def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_p return False functions_traced = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", stdout) - if not functions_traced or int(functions_traced.group(1)) != 3: - logging.error("Expected 3 traced functions") + if not functions_traced or int(functions_traced.group(1)) != 5: + logging.error("Expected 5 traced functions") return False replay_test_path = pathlib.Path(functions_traced.group(2)) From e93bd84522e0ada815f89ad2cbff072b3e4cfd64 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 11 Mar 2025 15:38:52 -0700 Subject: [PATCH 12/13] Update tracer.py --- codeflash/tracer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 960003f40..eb4df84d4 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -168,7 +168,7 @@ def __enter__(self) -> None: "CREATE 
TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, "
             "line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)"
         )
-        console.rule("Program Output Begin", style="bold blue")
+        console.rule("Codeflash: Traced Program Output Begin", style="bold blue")
         frame = sys._getframe(0)  # Get this frame and simulate a call to it  # noqa: SLF001
         self.dispatch["call"](self, frame, 0)
         self.start_time = time.time()
@@ -182,7 +182,7 @@ def __exit__(
             return
         sys.setprofile(None)
         self.con.commit()
-        console.rule("Program Output End", style="bold blue")
+        console.rule("Codeflash: Traced Program Output End", style="bold blue")
         self.create_stats()
 
         cur = self.con.cursor()

From c12071e2c8dded60503a2982e9187f3e92ffb192 Mon Sep 17 00:00:00 2001
From: Kevin Turcios
Date: Tue, 11 Mar 2025 15:45:42 -0700
Subject: [PATCH 13/13] adjust coverage expectations: with the added import and
 the moved func def, the coverage lines change (inspected manually)

---
 tests/scripts/end_to_end_test_tracer_replay.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/scripts/end_to_end_test_tracer_replay.py b/tests/scripts/end_to_end_test_tracer_replay.py
index 0bd69f431..58662448e 100644
--- a/tests/scripts/end_to_end_test_tracer_replay.py
+++ b/tests/scripts/end_to_end_test_tracer_replay.py
@@ -10,7 +10,7 @@ def run_test(expected_improvement_pct: int) -> bool:
         min_improvement_x=0.1,
         expected_unit_tests=1,
         coverage_expectations=[
-            CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[2, 3, 4, 6, 9]),
+            CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[3, 4, 5, 7, 10]),
         ],
     )
     cwd = (
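
With the expectations above in place, the replay pipeline these patches exercise boils down to two commands: trace the workload, then point the optimizer at the generated replay test. A rough sketch of that flow in Python, loosely mirroring run_trace_test in end_to_end_test_utilities.py (the utility merges stderr into stdout and asserts on the results; the optimizer command is left as a comment because its relative path depends on the working directory):

import re
import subprocess

# Step 1: trace the workload, writing the trace DB next to it (same command run_trace_test uses).
proc = subprocess.run(
    ["python", "-m", "codeflash.tracer", "-o", "codeflash.trace", "workload.py"],
    capture_output=True,
    text=True,
)

# Step 2: scrape the tracer's summary line for the call count and the replay-test path,
# using the same pattern the E2E utility searches for.
match = re.search(r"Traced (\d+) function calls successfully and replay test created at - (.*)$", proc.stdout)
if match:
    traced_calls = int(match.group(1))   # the check adjusted in PATCH 11 expects 5 here
    replay_test_path = match.group(2)
    # Step 3: run the optimizer against the replay test, e.g.
    #   python ../../../codeflash/main.py --replay-test <replay_test_path> --no-pr
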