diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index db78a9415..545ea4a0a 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -341,7 +341,7 @@ def discover_unit_tests( cfg: TestConfig, discover_only_these_tests: list[Path] | None = None, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None, -) -> tuple[dict[str, set[FunctionCalledInTest]], int]: +) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]: framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest} strategy = framework_strategies.get(cfg.test_framework, None) if not strategy: @@ -352,8 +352,10 @@ def discover_unit_tests( functions_to_optimize = None if file_to_funcs_to_optimize: functions_to_optimize = [func for funcs_list in file_to_funcs_to_optimize.values() for func in funcs_list] - function_to_tests, num_discovered_tests = strategy(cfg, discover_only_these_tests, functions_to_optimize) - return function_to_tests, num_discovered_tests + function_to_tests, num_discovered_tests, num_discovered_replay_tests = strategy( + cfg, discover_only_these_tests, functions_to_optimize + ) + return function_to_tests, num_discovered_tests, num_discovered_replay_tests def discover_tests_pytest( @@ -515,6 +517,7 @@ def process_test_files( function_to_test_map = defaultdict(set) num_discovered_tests = 0 + num_discovered_replay_tests = 0 jedi_project = jedi.Project(path=project_root_path) with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( @@ -661,6 +664,9 @@ def process_test_files( position=CodePosition(line_no=name.line, col_no=name.column), ) ) + if test_func.test_type == TestType.REPLAY_TEST: + num_discovered_replay_tests += 1 + num_discovered_tests += 1 except Exception as e: logger.debug(str(e)) @@ -668,4 +674,4 @@ def process_test_files( progress.advance(task_id) - return dict(function_to_test_map), num_discovered_tests + return dict(function_to_test_map), num_discovered_tests, num_discovered_replay_tests diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 357d61537..4fcafc50b 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -304,7 +304,7 @@ def get_all_replay_test_functions( logger.error("Could not find trace_file_path in replay test files.") exit_with_message("Could not find trace_file_path in replay test files.") - function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=replay_test) + function_tests, _, _ = discover_unit_tests(test_cfg, discover_only_these_tests=replay_test) # Get the absolute file paths for each function, excluding class name if present filtered_valid_functions = defaultdict(list) file_to_functions_map = defaultdict(list) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index f794fc01b..63a32199e 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -226,12 +226,12 @@ def discover_tests( console.rule() start_time = time.time() - function_to_tests, num_discovered_tests = discover_unit_tests( + function_to_tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests( self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize ) console.rule() logger.info( - f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - 
start_time):.1f}s at {self.test_cfg.tests_root}" + f"Discovered {num_discovered_tests} existing unit tests and {num_discovered_replay_tests} replay tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}" ) console.rule() ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests}) diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 85c6f3b68..46c73f819 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -11,792 +11,22 @@ # from __future__ import annotations -import contextlib -import datetime -import importlib.machinery -import io import json -import marshal -import os -import pathlib import pickle -import sqlite3 +import subprocess import sys -import threading -import time from argparse import ArgumentParser -from collections import defaultdict from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ClassVar - -import isort -from rich.align import Align -from rich.panel import Panel -from rich.table import Table -from rich.text import Text +from typing import TYPE_CHECKING from codeflash.cli_cmds.cli import project_root_from_module_root from codeflash.cli_cmds.console import console -from codeflash.code_utils.code_utils import module_name_from_file_path +from codeflash.code_utils.code_utils import get_run_tmp_file +from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE from codeflash.code_utils.config_parser import parse_config_file -from codeflash.discovery.functions_to_optimize import filter_files_optimized -from codeflash.picklepatch.pickle_patcher import PicklePatcher -from codeflash.tracing.replay_test import create_trace_replay_test -from codeflash.tracing.tracing_utils import FunctionModules -from codeflash.verification.verification_utils import get_test_file_path if TYPE_CHECKING: from argparse import Namespace - from types import FrameType, TracebackType - - -class FakeCode: - def __init__(self, filename: str, line: int, name: str) -> None: - self.co_filename = filename - self.co_line = line - self.co_name = name - self.co_firstlineno = 0 - - def __repr__(self) -> str: - return repr((self.co_filename, self.co_line, self.co_name, None)) - - -class FakeFrame: - def __init__(self, code: FakeCode, prior: FakeFrame | None) -> None: - self.f_code = code - self.f_back = prior - self.f_locals: dict = {} - - -# Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger. -class Tracer: - """Use this class as a 'with' context manager to trace a function call. - - Traces function calls, input arguments, and profiling info. - """ - - def __init__( - self, - output: str = "codeflash.trace", - functions: list[str] | None = None, - disable: bool = False, # noqa: FBT001, FBT002 - config_file_path: Path | None = None, - max_function_count: int = 256, - timeout: int | None = None, # seconds - command: str | None = None, - ) -> None: - """Use this class to trace function calls. - - :param output: The path to the output trace file - :param functions: List of functions to trace. If None, trace all functions - :param disable: Disable the tracer if True - :param config_file_path: Path to the pyproject.toml file, if None then it will be auto-discovered - :param max_function_count: Maximum number of times to trace one function - :param timeout: Timeout in seconds for the tracer, if the traced code takes more than this time, then tracing - stops and normal execution continues. 
If this is None then no timeout applies - :param command: The command that initiated the tracing (for metadata storage) - """ - if functions is None: - functions = [] - if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1": - console.rule( - "Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE", style="bold red" - ) - disable = True - self.disable = disable - self._db_lock: threading.Lock | None = None - if self.disable: - return - if sys.getprofile() is not None or sys.gettrace() is not None: - console.print( - "WARNING - Codeflash: Another profiler, debugger or coverage tool is already running. " - "Please disable it before starting the Codeflash Tracer, both can't run. Codeflash Tracer is DISABLED." - ) - self.disable = True - return - - self._db_lock = threading.Lock() - - self.con = None - self.output_file = Path(output).resolve() - self.functions = functions - self.function_modules: list[FunctionModules] = [] - self.function_count = defaultdict(int) - self.current_file_path = Path(__file__).resolve() - self.ignored_qualified_functions = { - f"{self.current_file_path}:Tracer.__exit__", - f"{self.current_file_path}:Tracer.__enter__", - } - self.max_function_count = max_function_count - self.config, found_config_path = parse_config_file(config_file_path) - self.project_root = project_root_from_module_root(Path(self.config["module_root"]), found_config_path) - console.rule(f"Project Root: {self.project_root}", style="bold blue") - self.ignored_functions = {"", "", "", "", "", ""} - - self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_") # noqa: SLF001 - self.replay_test_file_path: Path | None = None - - assert timeout is None or timeout > 0, "Timeout should be greater than 0" - self.timeout = timeout - self.next_insert = 1000 - self.trace_count = 0 - self.path_cache = {} # Cache for resolved file paths - - # Profiler variables - self.bias = 0 # calibration constant - self.timings = {} - self.cur = None - self.start_time = None - self.timer = time.process_time_ns - self.total_tt = 0 - self.simulate_call("profiler") - assert "test_framework" in self.config, "Please specify 'test-framework' in pyproject.toml config file" - self.t = self.timer() - - # Store command information for metadata table - self.command = command if command else " ".join(sys.argv) - - def __enter__(self) -> None: - if self.disable: - return - if getattr(Tracer, "used_once", False): - console.print( - "Codeflash: Tracer can only be used once per program run. " - "Please only enable the Tracer once. Skipping tracing this section." 
- ) - self.disable = True - return - Tracer.used_once = True - - if pathlib.Path(self.output_file).exists(): - console.rule("Removing existing trace file", style="bold red") - console.rule() - pathlib.Path(self.output_file).unlink(missing_ok=True) - - self.con = sqlite3.connect(self.output_file, check_same_thread=False) - cur = self.con.cursor() - cur.execute("""PRAGMA synchronous = OFF""") - cur.execute("""PRAGMA journal_mode = WAL""") - # TODO: Check out if we need to export the function test name as well - cur.execute( - "CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, " - "line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)" - ) - - # Create metadata table to store command information - cur.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)") - - # Store command metadata - cur.execute("INSERT INTO metadata VALUES (?, ?)", ("command", self.command)) - cur.execute("INSERT INTO metadata VALUES (?, ?)", ("program_name", self.file_being_called_from)) - cur.execute( - "INSERT INTO metadata VALUES (?, ?)", - ("functions_filter", json.dumps(self.functions) if self.functions else None), - ) - cur.execute( - "INSERT INTO metadata VALUES (?, ?)", - ("timestamp", datetime.datetime.now(datetime.timezone.utc).isoformat()), - ) - cur.execute("INSERT INTO metadata VALUES (?, ?)", ("project_root", str(self.project_root))) - console.rule("Codeflash: Traced Program Output Begin", style="bold blue") - frame = sys._getframe(0) # Get this frame and simulate a call to it # noqa: SLF001 - self.dispatch["call"](self, frame, 0) - self.start_time = time.time() - sys.setprofile(self.trace_callback) - threading.setprofile(self.trace_callback) - - def __exit__( - self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None - ) -> None: - if self.disable or self._db_lock is None: - return - sys.setprofile(None) - threading.setprofile(None) - - with self._db_lock: - if self.con is None: - return - - self.con.commit() # Commit any pending from tracer_logic - console.rule("Codeflash: Traced Program Output End", style="bold blue") - self.create_stats() # This calls snapshot_stats which uses self.timings - - cur = self.con.cursor() - cur.execute( - "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, " - "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, " - "cumulative_time_ns INTEGER, callers BLOB)" - ) - # self.stats is populated by snapshot_stats() called within create_stats() - # Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data - # For now, assuming self.stats is primarily in-memory after create_stats() - for func, (cc, nc, tt, ct, callers) in self.stats.items(): - remapped_callers = [{"key": k, "value": v} for k, v in callers.items()] - cur.execute( - "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", - ( - str(Path(func[0]).resolve()), - func[1], - func[2], - func[3], - cc, - nc, - tt, - ct, - json.dumps(remapped_callers), - ), - ) - self.con.commit() - - self.make_pstats_compatible() # Modifies self.stats and self.timings in-memory - self.print_stats("tottime") # Uses self.stats, prints to console - - cur = self.con.cursor() # New cursor - cur.execute("CREATE TABLE total_time (time_ns INTEGER)") - cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,)) - self.con.commit() - self.con.close() - self.con = None # Mark connection as closed - - # filter any functions where we 
did not capture the return - self.function_modules = [ - function - for function in self.function_modules - if self.function_count[ - str(function.file_name) - + ":" - + (function.class_name + "." if function.class_name else "") - + function.function_name - ] - > 0 - ] - - replay_test = create_trace_replay_test( - trace_file=self.output_file, - functions=self.function_modules, - test_framework=self.config["test_framework"], - max_run_count=self.max_function_count, - ) - function_path = "_".join(self.functions) if self.functions else self.file_being_called_from - test_file_path = get_test_file_path( - test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay" - ) - replay_test = isort.code(replay_test) - - with Path(test_file_path).open("w", encoding="utf8") as file: - file.write(replay_test) - self.replay_test_file_path = test_file_path - - console.print( - f"Codeflash: Traced {self.trace_count} function calls successfully and replay test created at - {test_file_path}", - crop=False, - soft_wrap=False, - overflow="ignore", - ) - - def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911 - if event != "call": - return - if None is not self.timeout and (time.time() - self.start_time) > self.timeout: - sys.setprofile(None) - threading.setprofile(None) - console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.") - return - if self.disable or self._db_lock is None or self.con is None: - return - - code = frame.f_code - - # Check function name first before resolving path - if code.co_name in self.ignored_functions: - return - - # Now resolve file path only if we need it - co_filename = code.co_filename - if co_filename in self.path_cache: - file_name = self.path_cache[co_filename] - else: - file_name = Path(co_filename).resolve() - self.path_cache[co_filename] = file_name - # TODO : It currently doesn't log the last return call from the first function - - if not file_name.is_relative_to(self.project_root): - return - if not file_name.exists(): - return - if self.functions and code.co_name not in self.functions: - return - class_name = None - arguments = frame.f_locals - try: - self_arg = arguments.get("self") - if self_arg is not None: - try: - class_name = self_arg.__class__.__name__ - except AttributeError: - cls_arg = arguments.get("cls") - if cls_arg is not None: - with contextlib.suppress(AttributeError): - class_name = cls_arg.__name__ - else: - cls_arg = arguments.get("cls") - if cls_arg is not None: - with contextlib.suppress(AttributeError): - class_name = cls_arg.__name__ - except: # noqa: E722 - # someone can override the getattr method and raise an exception. I'm looking at you wrapt - return - - # Extract class name from co_qualname for static methods that lack self/cls - if class_name is None and "." in getattr(code, "co_qualname", ""): - qualname_parts = code.co_qualname.split(".") - if len(qualname_parts) >= 2: - class_name = qualname_parts[-2] - - try: - function_qualified_name = f"{file_name}:{code.co_qualname}" - except AttributeError: - function_qualified_name = f"{file_name}:{(class_name + '.' 
if class_name else '')}{code.co_name}" - if function_qualified_name in self.ignored_qualified_functions: - return - if function_qualified_name not in self.function_count: - # seeing this function for the first time - self.function_count[function_qualified_name] = 1 - file_valid = filter_files_optimized( - file_path=file_name, - tests_root=Path(self.config["tests_root"]), - ignore_paths=[Path(p) for p in self.config["ignore_paths"]], - module_root=Path(self.config["module_root"]), - ) - if not file_valid: - # we don't want to trace this function because it cannot be optimized - self.ignored_qualified_functions.add(function_qualified_name) - return - self.function_modules.append( - FunctionModules( - function_name=code.co_name, - file_name=file_name, - module_name=module_name_from_file_path(file_name, project_root_path=self.project_root), - class_name=class_name, - line_no=code.co_firstlineno, - ) - ) - else: - self.function_count[function_qualified_name] += 1 - if self.function_count[function_qualified_name] >= self.max_function_count: - self.ignored_qualified_functions.add(function_qualified_name) - return - - # TODO: Also check if this function arguments are unique from the values logged earlier - - with self._db_lock: - # Check connection again inside lock, in case __exit__ closed it. - if self.con is None: - return - - cur = self.con.cursor() - - t_ns = time.perf_counter_ns() - original_recursion_limit = sys.getrecursionlimit() - try: - # pickling can be a recursive operator, so we need to increase the recursion limit - sys.setrecursionlimit(10000) - # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class - # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory - # leaks, bad references or side effects when unpickling. 
- arguments_copy = dict(arguments.items()) # Use the local 'arguments' from frame.f_locals - if class_name and code.co_name == "__init__" and "self" in arguments_copy: - del arguments_copy["self"] - local_vars = PicklePatcher.dumps(arguments_copy, protocol=pickle.HIGHEST_PROTOCOL) - sys.setrecursionlimit(original_recursion_limit) - except Exception: - self.function_count[function_qualified_name] -= 1 - sys.setrecursionlimit(original_recursion_limit) - return - - cur.execute( - "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", - ( - event, - code.co_name, - class_name, - str(file_name), - frame.f_lineno, - frame.f_back.__hash__(), - t_ns, - local_vars, - ), - ) - self.trace_count += 1 - self.next_insert -= 1 - if self.next_insert == 0: - self.next_insert = 1000 - self.con.commit() - - def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None: - # profiler section - timer = self.timer - t = timer() - self.t - self.bias - if event == "c_call": - self.c_func_name = arg.__name__ - - prof_success = bool(self.dispatch[event](self, frame, t)) - # tracer section - self.tracer_logic(frame, event) - # measure the time as the last thing before return - if prof_success: - self.t = timer() - else: - self.t = timer() - t # put back unrecorded delta - - def trace_dispatch_call(self, frame: FrameType, t: int) -> int: - """Handle call events in the profiler.""" - try: - # In multi-threaded contexts, we need to be more careful about frame comparisons - if self.cur and frame.f_back is not self.cur[-2]: - # This happens when we're in a different thread - rpt, rit, ret, rfn, rframe, rcur = self.cur - - # Only attempt to handle the frame mismatch if we have a valid rframe - if ( - not isinstance(rframe, FakeFrame) - and hasattr(rframe, "f_back") - and hasattr(frame, "f_back") - and rframe.f_back is frame.f_back - ): - self.trace_dispatch_return(rframe, 0) - - # Get function information - fcode = frame.f_code - arguments = frame.f_locals - class_name = None - try: - if ( - "self" in arguments - and hasattr(arguments["self"], "__class__") - and hasattr(arguments["self"].__class__, "__name__") - ): - class_name = arguments["self"].__class__.__name__ - elif "cls" in arguments and hasattr(arguments["cls"], "__name__"): - class_name = arguments["cls"].__name__ - except Exception: # noqa: S110 - pass - - fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name) - self.cur = (t, 0, 0, fn, frame, self.cur) - timings = self.timings - if fn in timings: - cc, ns, tt, ct, callers = timings[fn] - timings[fn] = cc, ns + 1, tt, ct, callers - else: - timings[fn] = 0, 0, 0, 0, {} - return 1 # noqa: TRY300 - except Exception: - # Handle any errors gracefully - return 0 - - def trace_dispatch_exception(self, frame: FrameType, t: int) -> int: - rpt, rit, ret, rfn, rframe, rcur = self.cur - if (rframe is not frame) and rcur: - return self.trace_dispatch_return(rframe, t) - self.cur = rpt, rit + t, ret, rfn, rframe, rcur - return 1 - - def trace_dispatch_c_call(self, frame: FrameType, t: int) -> int: - fn = ("", 0, self.c_func_name, None) - self.cur = (t, 0, 0, fn, frame, self.cur) - timings = self.timings - if fn in timings: - cc, ns, tt, ct, callers = timings[fn] - timings[fn] = cc, ns + 1, tt, ct, callers - else: - timings[fn] = 0, 0, 0, 0, {} - return 1 - - def trace_dispatch_return(self, frame: FrameType, t: int) -> int: - if not self.cur or not self.cur[-2]: - return 0 - - # In multi-threaded environments, frames can get mismatched - if frame is not self.cur[-2]: - # Don't assert 
in threaded environments - frames can legitimately differ - if hasattr(frame, "f_back") and hasattr(self.cur[-2], "f_back") and frame is self.cur[-2].f_back: - self.trace_dispatch_return(self.cur[-2], 0) - else: - # We're in a different thread or context, can't continue with this frame - return 0 - # Prefix "r" means part of the Returning or exiting frame. - # Prefix "p" means part of the Previous or Parent or older frame. - - rpt, rit, ret, rfn, frame, rcur = self.cur - - # Guard against invalid rcur (w threading) - if not rcur: - return 0 - - rit = rit + t - frame_total = rit + ret - - ppt, pit, pet, pfn, pframe, pcur = rcur - self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur - - timings = self.timings - if rfn not in timings: - # w threading, rfn can be missing - timings[rfn] = 0, 0, 0, 0, {} - cc, ns, tt, ct, callers = timings[rfn] - if not ns: - # This is the only occurrence of the function on the stack. - # Else this is a (directly or indirectly) recursive call, and - # its cumulative time will get updated when the topmost call to - # it returns. - ct = ct + frame_total - cc = cc + 1 - - if pfn in callers: - # Increment call count between these functions - callers[pfn] = callers[pfn] + 1 - # Note: This tracks stats such as the amount of time added to ct - # of this specific call, and the contribution to cc - # courtesy of this call. - else: - callers[pfn] = 1 - - timings[rfn] = cc, ns - 1, tt + rit, ct, callers - - return 1 - - dispatch: ClassVar[dict[str, Callable[[Tracer, FrameType, int], int]]] = { - "call": trace_dispatch_call, - "exception": trace_dispatch_exception, - "return": trace_dispatch_return, - "c_call": trace_dispatch_c_call, - "c_exception": trace_dispatch_return, # the C function returned - "c_return": trace_dispatch_return, - } - - def simulate_call(self, name: str) -> None: - code = FakeCode("profiler", 0, name) - pframe = self.cur[-2] if self.cur else None - frame = FakeFrame(code, pframe) - self.dispatch["call"](self, frame, 0) - - def simulate_cmd_complete(self) -> None: - get_time = self.timer - t = get_time() - self.t - while self.cur[-1]: - # We *can* cause assertion errors here if - # dispatch_trace_return checks for a frame match! 
- self.dispatch["return"](self, self.cur[-2], t) - t = 0 - self.t = get_time() - t - - def print_stats(self, sort: str | int | tuple = -1) -> None: - if not self.stats: - console.print("Codeflash: No stats available to print") - self.total_tt = 0 - return - - if not isinstance(sort, tuple): - sort = (sort,) - - # First, convert stats to make them pstats-compatible - try: - # Initialize empty collections for pstats - self.files = [] - self.top_level = [] - - # Create entirely new dictionaries instead of modifying existing ones - new_stats = {} - new_timings = {} - - # Convert stats dictionary - stats_items = list(self.stats.items()) - for func, stats_data in stats_items: - try: - # Make sure we have 5 elements in stats_data - if len(stats_data) != 5: - console.print(f"Skipping malformed stats data for {func}: {stats_data}") - continue - - cc, nc, tt, ct, callers = stats_data - - if len(func) == 4: - file_name, line_num, func_name, class_name = func - new_func_name = f"{class_name}.{func_name}" if class_name else func_name - new_func = (file_name, line_num, new_func_name) - else: - new_func = func # Keep as is if already in correct format - - new_callers = {} - callers_items = list(callers.items()) - for caller_func, count in callers_items: - if isinstance(caller_func, tuple): - if len(caller_func) == 4: - caller_file, caller_line, caller_name, caller_class = caller_func - caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name - new_caller_func = (caller_file, caller_line, caller_new_name) - else: - new_caller_func = caller_func - else: - console.print(f"Unexpected caller format: {caller_func}") - new_caller_func = str(caller_func) - - new_callers[new_caller_func] = count - - # Store with new format - new_stats[new_func] = (cc, nc, tt, ct, new_callers) - except Exception as e: - console.print(f"Error converting stats for {func}: {e}") - continue - - timings_items = list(self.timings.items()) - for func, timing_data in timings_items: - try: - if len(timing_data) != 5: - console.print(f"Skipping malformed timing data for {func}: {timing_data}") - continue - - cc, ns, tt, ct, callers = timing_data - - if len(func) == 4: - file_name, line_num, func_name, class_name = func - new_func_name = f"{class_name}.{func_name}" if class_name else func_name - new_func = (file_name, line_num, new_func_name) - else: - new_func = func - - new_callers = {} - callers_items = list(callers.items()) - for caller_func, count in callers_items: - if isinstance(caller_func, tuple): - if len(caller_func) == 4: - caller_file, caller_line, caller_name, caller_class = caller_func - caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name - new_caller_func = (caller_file, caller_line, caller_new_name) - else: - new_caller_func = caller_func - else: - console.print(f"Unexpected caller format: {caller_func}") - new_caller_func = str(caller_func) - - new_callers[new_caller_func] = count - - new_timings[new_func] = (cc, ns, tt, ct, new_callers) - except Exception as e: - console.print(f"Error converting timings for {func}: {e}") - continue - - self.stats = new_stats - self.timings = new_timings - - self.total_tt = sum(tt for _, _, tt, _, _ in self.stats.values()) - - total_calls = sum(cc for cc, _, _, _, _ in self.stats.values()) - total_primitive = sum(nc for _, nc, _, _, _ in self.stats.values()) - - summary = Text.assemble( - f"{total_calls:,} function calls ", - ("(" + f"{total_primitive:,} primitive calls" + ")", "dim"), - f" in {self.total_tt / 1e6:.3f}milliseconds", - 
) - - console.print(Align.center(Panel(summary, border_style="blue", width=80, padding=(0, 2), expand=False))) - - table = Table( - show_header=True, - header_style="bold magenta", - border_style="blue", - title="[bold]Function Profile[/bold] (ordered by internal time)", - title_style="cyan", - caption=f"Showing top {min(25, len(self.stats))} of {len(self.stats)} functions", - ) - - table.add_column("Calls", justify="right", style="green", width=10) - table.add_column("Time (ms)", justify="right", style="cyan", width=10) - table.add_column("Per Call", justify="right", style="cyan", width=10) - table.add_column("Cum (ms)", justify="right", style="yellow", width=10) - table.add_column("Cum/Call", justify="right", style="yellow", width=10) - table.add_column("Function", style="blue") - - sorted_stats = sorted( - ((func, stats) for func, stats in self.stats.items() if isinstance(func, tuple) and len(func) == 3), - key=lambda x: x[1][2], # Sort by tt (internal time) - reverse=True, - )[:25] # Limit to top 25 - - # Format and add each row to the table - for func, (cc, nc, tt, ct, _) in sorted_stats: - filename, lineno, funcname = func - - # Format calls - show recursive format if different - calls_str = f"{cc}/{nc}" if cc != nc else f"{cc:,}" - - # Convert to milliseconds - tt_ms = tt / 1e6 - ct_ms = ct / 1e6 - - # Calculate per-call times - per_call = tt_ms / cc if cc > 0 else 0 - cum_per_call = ct_ms / nc if nc > 0 else 0 - base_filename = Path(filename).name - file_link = f"[link=file://{filename}]{base_filename}[/link]" - - table.add_row( - calls_str, - f"{tt_ms:.3f}", - f"{per_call:.3f}", - f"{ct_ms:.3f}", - f"{cum_per_call:.3f}", - f"{funcname} [dim]({file_link}:{lineno})[/dim]", - ) - - console.print(Align.center(table)) - - except Exception as e: - console.print(f"[bold red]Error in stats processing:[/bold red] {e}") - console.print(f"Traced {self.trace_count:,} function calls") - self.total_tt = 0 - - def make_pstats_compatible(self) -> None: - # delete the extra class_name item from the function tuple - self.files = [] - self.top_level = [] - new_stats = {} - for func, (cc, ns, tt, ct, callers) in self.stats.items(): - new_callers = {(k[0], k[1], k[2]): v for k, v in callers.items()} - new_stats[(func[0], func[1], func[2])] = (cc, ns, tt, ct, new_callers) - new_timings = {} - for func, (cc, ns, tt, ct, callers) in self.timings.items(): - new_callers = {(k[0], k[1], k[2]): v for k, v in callers.items()} - new_timings[(func[0], func[1], func[2])] = (cc, ns, tt, ct, new_callers) - self.stats = new_stats - self.timings = new_timings - - def dump_stats(self, file: str) -> None: - with Path(file).open("wb") as f: - marshal.dump(self.stats, f) - - def create_stats(self) -> None: - self.simulate_cmd_complete() - self.snapshot_stats() - - def snapshot_stats(self) -> None: - self.stats = {} - for func, (cc, _ns, tt, ct, caller_dict) in list(self.timings.items()): - callers = caller_dict.copy() - nc = 0 - for callcnt in callers.values(): - nc += callcnt - self.stats[func] = cc, nc, tt, ct, callers - - def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, Any]) -> Tracer | None: - self.__enter__() - try: - exec(cmd, global_vars, local_vars) # noqa: S102 - finally: - self.__exit__(None, None, None) - return self def main(args: Namespace | None = None) -> ArgumentParser: @@ -853,40 +83,46 @@ def main(args: Namespace | None = None) -> ArgumentParser: if parsed_args.outfile is not None: parsed_args.outfile = Path(parsed_args.outfile).resolve() outfile = parsed_args.outfile - + 
config, found_config_path = parse_config_file(parsed_args.codeflash_config) + project_root = project_root_from_module_root(Path(config["module_root"]), found_config_path) if len(unknown_args) > 0: - if parsed_args.module: - import runpy - - code = "run_module(modname, run_name='__main__')" - globs = {"run_module": runpy.run_module, "modname": unknown_args[0]} - else: - progname = unknown_args[0] - sys.path.insert(0, str(Path(progname).resolve().parent)) - with io.open_code(progname) as fp: - code = compile(fp.read(), progname, "exec") - spec = importlib.machinery.ModuleSpec(name="__main__", loader=None, origin=progname) - globs = { - "__spec__": spec, - "__file__": spec.origin, - "__name__": spec.name, - "__package__": None, - "__cached__": None, - } try: - tracer = Tracer( - output=parsed_args.outfile, - functions=parsed_args.only_functions, - max_function_count=parsed_args.max_function_count, - timeout=parsed_args.tracer_timeout, - config_file_path=parsed_args.codeflash_config, - command=" ".join(sys.argv), + result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl") + args_dict = { + "result_pickle_file_path": str(result_pickle_file_path), + "output": str(parsed_args.outfile), + "functions": parsed_args.only_functions, + "disable": False, + "project_root": str(project_root), + "max_function_count": parsed_args.max_function_count, + "timeout": parsed_args.tracer_timeout, + "command": " ".join(sys.argv), + "progname": unknown_args[0], + "config": config, + "module": parsed_args.module, + } + + subprocess.run( + [ + SAFE_SYS_EXECUTABLE, + Path(__file__).parent / "tracing" / "tracing_new_process.py", + *sys.argv, + json.dumps(args_dict), + ], + cwd=Path.cwd(), + check=False, ) - tracer.runctx(code, globs, None) - replay_test_path = tracer.replay_test_file_path - if not parsed_args.trace_only and replay_test_path is not None: - del tracer + try: + with result_pickle_file_path.open(mode="rb") as f: + data = pickle.load(f) + except Exception: + console.print("❌ Failed to trace. 
Exiting...") + sys.exit(1) + finally: + result_pickle_file_path.unlink(missing_ok=True) + replay_test_path = data["replay_test_file_path"] + if not parsed_args.trace_only and replay_test_path is not None: from codeflash.cli_cmds.cli import parse_args, process_pyproject_config from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO from codeflash.cli_cmds.console import paneled_text diff --git a/codeflash/tracing/tracing_new_process.py b/codeflash/tracing/tracing_new_process.py new file mode 100644 index 000000000..fa3ae2f9c --- /dev/null +++ b/codeflash/tracing/tracing_new_process.py @@ -0,0 +1,853 @@ +from __future__ import annotations + +import contextlib +import datetime +import importlib +import io +import json +import os +import pickle +import re +import sqlite3 +import sys +import threading +import time +from collections import defaultdict +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, ClassVar + +from rich.align import Align +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from codeflash.cli_cmds.console import console +from codeflash.picklepatch.pickle_patcher import PicklePatcher +from codeflash.tracing.tracing_utils import FunctionModules, filter_files_optimized, module_name_from_file_path + +if TYPE_CHECKING: + from types import FrameType, TracebackType + + +class FakeCode: + def __init__(self, filename: str, line: int, name: str) -> None: + self.co_filename = filename + self.co_line = line + self.co_name = name + self.co_firstlineno = 0 + + def __repr__(self) -> str: + return repr((self.co_filename, self.co_line, self.co_name, None)) + + +class FakeFrame: + def __init__(self, code: FakeCode, prior: FakeFrame | None) -> None: + self.f_code = code + self.f_back = prior + self.f_locals: dict = {} + + +# Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger. +class Tracer: + """Use this class as a 'with' context manager to trace a function call. + + Traces function calls, input arguments, and profiling info. + """ + + def __init__( + self, + config: dict, + result_pickle_file_path: Path, + output: str = "codeflash.trace", + functions: list[str] | None = None, + disable: bool = False, # noqa: FBT001, FBT002 + project_root: Path | None = None, + max_function_count: int = 256, + timeout: int | None = None, # seconds + command: str = "", + ) -> None: + """Use this class to trace function calls. + + :param output: The path to the output trace file + :param functions: List of functions to trace. If None, trace all functions + :param disable: Disable the tracer if True + :param max_function_count: Maximum number of times to trace one function + :param timeout: Timeout in seconds for the tracer, if the traced code takes more than this time, then tracing + stops and normal execution continues. If this is None then no timeout applies + :param command: The command that initiated the tracing (for metadata storage) + """ + if functions is None: + functions = [] + if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1": + console.rule( + "Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE", style="bold red" + ) + disable = True + self.disable = disable + self._db_lock: threading.Lock | None = None + if self.disable: + return + if sys.getprofile() is not None or sys.gettrace() is not None: + console.print( + "WARNING - Codeflash: Another profiler, debugger or coverage tool is already running. 
" + "Please disable it before starting the Codeflash Tracer, both can't run. Codeflash Tracer is DISABLED." + ) + self.disable = True + return + + self._db_lock = threading.Lock() + + self.con = None + self.output_file = Path(output).resolve() + self.functions = functions + self.function_modules: list[FunctionModules] = [] + self.function_count = defaultdict(int) + self.current_file_path = Path(__file__).resolve() + self.ignored_qualified_functions = { + f"{self.current_file_path}:Tracer.__exit__", + f"{self.current_file_path}:Tracer.__enter__", + } + self.max_function_count = max_function_count + self.config = config + self.project_root = project_root + console.rule(f"Project Root: {self.project_root}", style="bold blue") + self.ignored_functions = {"", "", "", "", "", ""} + + self.sanitized_filename = self.sanitize_to_filename(command) + self.result_pickle_file_path = result_pickle_file_path + + assert timeout is None or timeout > 0, "Timeout should be greater than 0" + self.timeout = timeout + self.next_insert = 1000 + self.trace_count = 0 + self.path_cache = {} # Cache for resolved file paths + + # Profiler variables + self.bias = 0 # calibration constant + self.timings = {} + self.cur = None + self.start_time = None + self.timer = time.process_time_ns + self.total_tt = 0 + self.simulate_call("profiler") + assert "test_framework" in self.config, "Please specify 'test-framework' in pyproject.toml config file" + self.t = self.timer() + + # Store command information for metadata table + self.command = command + + def __enter__(self) -> None: + if self.disable: + return + if getattr(Tracer, "used_once", False): + console.print( + "Codeflash: Tracer can only be used once per program run. " + "Please only enable the Tracer once. Skipping tracing this section." 
+ ) + self.disable = True + return + Tracer.used_once = True + + if Path(self.output_file).exists(): + console.rule("Removing existing trace file", style="bold red") + console.rule() + Path(self.output_file).unlink(missing_ok=True) + + self.con = sqlite3.connect(self.output_file, check_same_thread=False) + cur = self.con.cursor() + cur.execute("""PRAGMA synchronous = OFF""") + cur.execute("""PRAGMA journal_mode = WAL""") + # TODO: Check out if we need to export the function test name as well + cur.execute( + "CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, " + "line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)" + ) + + # Create metadata table to store command information + cur.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)") + + # Store command metadata + cur.execute("INSERT INTO metadata VALUES (?, ?)", ("command", self.command)) + cur.execute("INSERT INTO metadata VALUES (?, ?)", ("program_name", self.sanitized_filename)) + cur.execute( + "INSERT INTO metadata VALUES (?, ?)", + ("functions_filter", json.dumps(self.functions) if self.functions else None), + ) + cur.execute( + "INSERT INTO metadata VALUES (?, ?)", + ("timestamp", datetime.datetime.now(datetime.timezone.utc).isoformat()), + ) + cur.execute("INSERT INTO metadata VALUES (?, ?)", ("project_root", str(self.project_root))) + console.rule("Codeflash: Traced Program Output Begin", style="bold blue") + frame = sys._getframe(0) # Get this frame and simulate a call to it # noqa: SLF001 + self.dispatch["call"](self, frame, 0) + self.start_time = time.time() + sys.setprofile(self.trace_callback) + threading.setprofile(self.trace_callback) + + def __exit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> None: + if self.disable or self._db_lock is None: + return + sys.setprofile(None) + threading.setprofile(None) + + with self._db_lock: + if self.con is None: + return + + self.con.commit() # Commit any pending from tracer_logic + console.rule("Codeflash: Traced Program Output End", style="bold blue") + self.create_stats() # This calls snapshot_stats which uses self.timings + + cur = self.con.cursor() + cur.execute( + "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, " + "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, " + "cumulative_time_ns INTEGER, callers BLOB)" + ) + # self.stats is populated by snapshot_stats() called within create_stats() + # Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data + # For now, assuming self.stats is primarily in-memory after create_stats() + for func, (cc, nc, tt, ct, callers) in self.stats.items(): + remapped_callers = [{"key": k, "value": v} for k, v in callers.items()] + cur.execute( + "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + str(Path(func[0]).resolve()), + func[1], + func[2], + func[3], + cc, + nc, + tt, + ct, + json.dumps(remapped_callers), + ), + ) + self.con.commit() + + self.make_pstats_compatible() # Modifies self.stats and self.timings in-memory + self.print_stats("tottime") # Uses self.stats, prints to console + + cur = self.con.cursor() # New cursor + cur.execute("CREATE TABLE total_time (time_ns INTEGER)") + cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,)) + self.con.commit() + self.con.close() + self.con = None # Mark connection as closed + + # filter any functions where we did not capture the 
return + self.function_modules = [ + function + for function in self.function_modules + if self.function_count[ + str(function.file_name) + + ":" + + (function.class_name + "." if function.class_name else "") + + function.function_name + ] + > 0 + ] + + # These modules have been imported here now the tracer is done. It is safe to import codeflash and external modules here + + import isort + + from codeflash.tracing.replay_test import create_trace_replay_test + from codeflash.verification.verification_utils import get_test_file_path + + replay_test = create_trace_replay_test( + trace_file=self.output_file, + functions=self.function_modules, + test_framework=self.config["test_framework"], + max_run_count=self.max_function_count, + ) + function_path = "_".join(self.functions) if self.functions else self.sanitized_filename + test_file_path = get_test_file_path( + test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay" + ) + replay_test = isort.code(replay_test) + + with Path(test_file_path).open("w", encoding="utf8") as file: + file.write(replay_test) + self.replay_test_file_path = test_file_path + + console.print( + f"Codeflash: Traced {self.trace_count} function calls successfully and replay test created at - {test_file_path}", + crop=False, + soft_wrap=False, + overflow="ignore", + ) + pickle_data = {"replay_test_file_path": self.replay_test_file_path} + import pickle + + with self.result_pickle_file_path.open("wb") as file: + pickle.dump(pickle_data, file) + + def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911 + if event != "call": + return + if None is not self.timeout and (time.time() - self.start_time) > self.timeout: + sys.setprofile(None) + threading.setprofile(None) + console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.") + return + if self.disable or self._db_lock is None or self.con is None: + return + + code = frame.f_code + + # Check function name first before resolving path + if code.co_name in self.ignored_functions: + return + + # Now resolve file path only if we need it + co_filename = code.co_filename + if co_filename in self.path_cache: + file_name = self.path_cache[co_filename] + else: + file_name = Path(co_filename).resolve() + self.path_cache[co_filename] = file_name + # TODO : It currently doesn't log the last return call from the first function + + if not file_name.is_relative_to(self.project_root): + return + if not file_name.exists(): + return + if self.functions and code.co_name not in self.functions: + return + class_name = None + arguments = frame.f_locals + try: + self_arg = arguments.get("self") + if self_arg is not None: + try: + class_name = self_arg.__class__.__name__ + except AttributeError: + cls_arg = arguments.get("cls") + if cls_arg is not None: + with contextlib.suppress(AttributeError): + class_name = cls_arg.__name__ + else: + cls_arg = arguments.get("cls") + if cls_arg is not None: + with contextlib.suppress(AttributeError): + class_name = cls_arg.__name__ + except: # noqa: E722 + # someone can override the getattr method and raise an exception. I'm looking at you wrapt + return + + # Extract class name from co_qualname for static methods that lack self/cls + if class_name is None and "." 
in getattr(code, "co_qualname", ""): + qualname_parts = code.co_qualname.split(".") + if len(qualname_parts) >= 2: + class_name = qualname_parts[-2] + + try: + function_qualified_name = f"{file_name}:{code.co_qualname}" + except AttributeError: + function_qualified_name = f"{file_name}:{(class_name + '.' if class_name else '')}{code.co_name}" + if function_qualified_name in self.ignored_qualified_functions: + return + if function_qualified_name not in self.function_count: + # seeing this function for the first time + self.function_count[function_qualified_name] = 1 + file_valid = filter_files_optimized( + file_path=file_name, + tests_root=Path(self.config["tests_root"]), + ignore_paths=[Path(p) for p in self.config["ignore_paths"]], + module_root=Path(self.config["module_root"]), + ) + if not file_valid: + # we don't want to trace this function because it cannot be optimized + self.ignored_qualified_functions.add(function_qualified_name) + return + self.function_modules.append( + FunctionModules( + function_name=code.co_name, + file_name=file_name, + module_name=module_name_from_file_path(file_name, project_root_path=self.project_root), + class_name=class_name, + line_no=code.co_firstlineno, + ) + ) + else: + self.function_count[function_qualified_name] += 1 + if self.function_count[function_qualified_name] >= self.max_function_count: + self.ignored_qualified_functions.add(function_qualified_name) + return + + # TODO: Also check if this function arguments are unique from the values logged earlier + + with self._db_lock: + # Check connection again inside lock, in case __exit__ closed it. + if self.con is None: + return + + cur = self.con.cursor() + + t_ns = time.perf_counter_ns() + original_recursion_limit = sys.getrecursionlimit() + try: + # pickling can be a recursive operator, so we need to increase the recursion limit + sys.setrecursionlimit(10000) + # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class + # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory + # leaks, bad references or side effects when unpickling. 
+ arguments_copy = dict(arguments.items()) # Use the local 'arguments' from frame.f_locals + if class_name and code.co_name == "__init__" and "self" in arguments_copy: + del arguments_copy["self"] + local_vars = PicklePatcher.dumps(arguments_copy, protocol=pickle.HIGHEST_PROTOCOL) + sys.setrecursionlimit(original_recursion_limit) + except Exception: + self.function_count[function_qualified_name] -= 1 + sys.setrecursionlimit(original_recursion_limit) + return + + cur.execute( + "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", + ( + event, + code.co_name, + class_name, + str(file_name), + frame.f_lineno, + frame.f_back.__hash__(), + t_ns, + local_vars, + ), + ) + self.trace_count += 1 + self.next_insert -= 1 + if self.next_insert == 0: + self.next_insert = 1000 + self.con.commit() + + def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None: + # profiler section + timer = self.timer + t = timer() - self.t - self.bias + if event == "c_call": + self.c_func_name = arg.__name__ + + prof_success = bool(self.dispatch[event](self, frame, t)) + # tracer section + self.tracer_logic(frame, event) + # measure the time as the last thing before return + if prof_success: + self.t = timer() + else: + self.t = timer() - t # put back unrecorded delta + + def trace_dispatch_call(self, frame: FrameType, t: int) -> int: + """Handle call events in the profiler.""" + try: + # In multi-threaded contexts, we need to be more careful about frame comparisons + if self.cur and frame.f_back is not self.cur[-2]: + # This happens when we're in a different thread + rpt, rit, ret, rfn, rframe, rcur = self.cur + + # Only attempt to handle the frame mismatch if we have a valid rframe + if ( + not isinstance(rframe, FakeFrame) + and hasattr(rframe, "f_back") + and hasattr(frame, "f_back") + and rframe.f_back is frame.f_back + ): + self.trace_dispatch_return(rframe, 0) + + # Get function information + fcode = frame.f_code + arguments = frame.f_locals + class_name = None + try: + if ( + "self" in arguments + and hasattr(arguments["self"], "__class__") + and hasattr(arguments["self"].__class__, "__name__") + ): + class_name = arguments["self"].__class__.__name__ + elif "cls" in arguments and hasattr(arguments["cls"], "__name__"): + class_name = arguments["cls"].__name__ + except Exception: # noqa: S110 + pass + + fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name) + self.cur = (t, 0, 0, fn, frame, self.cur) + timings = self.timings + if fn in timings: + cc, ns, tt, ct, callers = timings[fn] + timings[fn] = cc, ns + 1, tt, ct, callers + else: + timings[fn] = 0, 0, 0, 0, {} + return 1 # noqa: TRY300 + except Exception: + # Handle any errors gracefully + return 0 + + def trace_dispatch_exception(self, frame: FrameType, t: int) -> int: + rpt, rit, ret, rfn, rframe, rcur = self.cur + if (rframe is not frame) and rcur: + return self.trace_dispatch_return(rframe, t) + self.cur = rpt, rit + t, ret, rfn, rframe, rcur + return 1 + + def trace_dispatch_c_call(self, frame: FrameType, t: int) -> int: + fn = ("", 0, self.c_func_name, None) + self.cur = (t, 0, 0, fn, frame, self.cur) + timings = self.timings + if fn in timings: + cc, ns, tt, ct, callers = timings[fn] + timings[fn] = cc, ns + 1, tt, ct, callers + else: + timings[fn] = 0, 0, 0, 0, {} + return 1 + + def trace_dispatch_return(self, frame: FrameType, t: int) -> int: + if not self.cur or not self.cur[-2]: + return 0 + + # In multi-threaded environments, frames can get mismatched + if frame is not self.cur[-2]: + # Don't assert 
in threaded environments - frames can legitimately differ + if hasattr(frame, "f_back") and hasattr(self.cur[-2], "f_back") and frame is self.cur[-2].f_back: + self.trace_dispatch_return(self.cur[-2], 0) + else: + # We're in a different thread or context, can't continue with this frame + return 0 + # Prefix "r" means part of the Returning or exiting frame. + # Prefix "p" means part of the Previous or Parent or older frame. + + rpt, rit, ret, rfn, frame, rcur = self.cur + + # Guard against invalid rcur (w threading) + if not rcur: + return 0 + + rit = rit + t + frame_total = rit + ret + + ppt, pit, pet, pfn, pframe, pcur = rcur + self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur + + timings = self.timings + if rfn not in timings: + # w threading, rfn can be missing + timings[rfn] = 0, 0, 0, 0, {} + cc, ns, tt, ct, callers = timings[rfn] + if not ns: + # This is the only occurrence of the function on the stack. + # Else this is a (directly or indirectly) recursive call, and + # its cumulative time will get updated when the topmost call to + # it returns. + ct = ct + frame_total + cc = cc + 1 + + if pfn in callers: + # Increment call count between these functions + callers[pfn] = callers[pfn] + 1 + # Note: This tracks stats such as the amount of time added to ct + # of this specific call, and the contribution to cc + # courtesy of this call. + else: + callers[pfn] = 1 + + timings[rfn] = cc, ns - 1, tt + rit, ct, callers + + return 1 + + dispatch: ClassVar[dict[str, Callable[[Tracer, FrameType, int], int]]] = { + "call": trace_dispatch_call, + "exception": trace_dispatch_exception, + "return": trace_dispatch_return, + "c_call": trace_dispatch_c_call, + "c_exception": trace_dispatch_return, # the C function returned + "c_return": trace_dispatch_return, + } + + def simulate_call(self, name: str) -> None: + code = FakeCode("profiler", 0, name) + pframe = self.cur[-2] if self.cur else None + frame = FakeFrame(code, pframe) + self.dispatch["call"](self, frame, 0) + + def simulate_cmd_complete(self) -> None: + get_time = self.timer + t = get_time() - self.t + while self.cur[-1]: + # We *can* cause assertion errors here if + # dispatch_trace_return checks for a frame match! 
+ self.dispatch["return"](self, self.cur[-2], t) + t = 0 + self.t = get_time() - t + + def print_stats(self, sort: str | int | tuple = -1) -> None: + if not self.stats: + console.print("Codeflash: No stats available to print") + self.total_tt = 0 + return + + if not isinstance(sort, tuple): + sort = (sort,) + + # First, convert stats to make them pstats-compatible + try: + # Initialize empty collections for pstats + self.files = [] + self.top_level = [] + + # Create entirely new dictionaries instead of modifying existing ones + new_stats = {} + new_timings = {} + + # Convert stats dictionary + stats_items = list(self.stats.items()) + for func, stats_data in stats_items: + try: + # Make sure we have 5 elements in stats_data + if len(stats_data) != 5: + console.print(f"Skipping malformed stats data for {func}: {stats_data}") + continue + + cc, nc, tt, ct, callers = stats_data + + if len(func) == 4: + file_name, line_num, func_name, class_name = func + new_func_name = f"{class_name}.{func_name}" if class_name else func_name + new_func = (file_name, line_num, new_func_name) + else: + new_func = func # Keep as is if already in correct format + + new_callers = {} + callers_items = list(callers.items()) + for caller_func, count in callers_items: + if isinstance(caller_func, tuple): + if len(caller_func) == 4: + caller_file, caller_line, caller_name, caller_class = caller_func + caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name + new_caller_func = (caller_file, caller_line, caller_new_name) + else: + new_caller_func = caller_func + else: + console.print(f"Unexpected caller format: {caller_func}") + new_caller_func = str(caller_func) + + new_callers[new_caller_func] = count + + # Store with new format + new_stats[new_func] = (cc, nc, tt, ct, new_callers) + except Exception as e: + console.print(f"Error converting stats for {func}: {e}") + continue + + timings_items = list(self.timings.items()) + for func, timing_data in timings_items: + try: + if len(timing_data) != 5: + console.print(f"Skipping malformed timing data for {func}: {timing_data}") + continue + + cc, ns, tt, ct, callers = timing_data + + if len(func) == 4: + file_name, line_num, func_name, class_name = func + new_func_name = f"{class_name}.{func_name}" if class_name else func_name + new_func = (file_name, line_num, new_func_name) + else: + new_func = func + + new_callers = {} + callers_items = list(callers.items()) + for caller_func, count in callers_items: + if isinstance(caller_func, tuple): + if len(caller_func) == 4: + caller_file, caller_line, caller_name, caller_class = caller_func + caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name + new_caller_func = (caller_file, caller_line, caller_new_name) + else: + new_caller_func = caller_func + else: + console.print(f"Unexpected caller format: {caller_func}") + new_caller_func = str(caller_func) + + new_callers[new_caller_func] = count + + new_timings[new_func] = (cc, ns, tt, ct, new_callers) + except Exception as e: + console.print(f"Error converting timings for {func}: {e}") + continue + + self.stats = new_stats + self.timings = new_timings + + self.total_tt = sum(tt for _, _, tt, _, _ in self.stats.values()) + + total_calls = sum(cc for cc, _, _, _, _ in self.stats.values()) + total_primitive = sum(nc for _, nc, _, _, _ in self.stats.values()) + + summary = Text.assemble( + f"{total_calls:,} function calls ", + ("(" + f"{total_primitive:,} primitive calls" + ")", "dim"), + f" in {self.total_tt / 1e6:.3f}milliseconds", + 
) + + console.print(Align.center(Panel(summary, border_style="blue", width=80, padding=(0, 2), expand=False))) + + table = Table( + show_header=True, + header_style="bold magenta", + border_style="blue", + title="[bold]Function Profile[/bold] (ordered by internal time)", + title_style="cyan", + caption=f"Showing top {min(25, len(self.stats))} of {len(self.stats)} functions", + ) + + table.add_column("Calls", justify="right", style="green", width=10) + table.add_column("Time (ms)", justify="right", style="cyan", width=10) + table.add_column("Per Call", justify="right", style="cyan", width=10) + table.add_column("Cum (ms)", justify="right", style="yellow", width=10) + table.add_column("Cum/Call", justify="right", style="yellow", width=10) + table.add_column("Function", style="blue") + + sorted_stats = sorted( + ((func, stats) for func, stats in self.stats.items() if isinstance(func, tuple) and len(func) == 3), + key=lambda x: x[1][2], # Sort by tt (internal time) + reverse=True, + )[:25] # Limit to top 25 + + # Format and add each row to the table + for func, (cc, nc, tt, ct, _) in sorted_stats: + filename, lineno, funcname = func + + # Format calls - show recursive format if different + calls_str = f"{cc}/{nc}" if cc != nc else f"{cc:,}" + + # Convert to milliseconds + tt_ms = tt / 1e6 + ct_ms = ct / 1e6 + + # Calculate per-call times + per_call = tt_ms / cc if cc > 0 else 0 + cum_per_call = ct_ms / nc if nc > 0 else 0 + base_filename = Path(filename).name + file_link = f"[link=file://{filename}]{base_filename}[/link]" + + table.add_row( + calls_str, + f"{tt_ms:.3f}", + f"{per_call:.3f}", + f"{ct_ms:.3f}", + f"{cum_per_call:.3f}", + f"{funcname} [dim]({file_link}:{lineno})[/dim]", + ) + + console.print(Align.center(table)) + + except Exception as e: + console.print(f"[bold red]Error in stats processing:[/bold red] {e}") + console.print(f"Traced {self.trace_count:,} function calls") + self.total_tt = 0 + + def make_pstats_compatible(self) -> None: + # delete the extra class_name item from the function tuple + self.files = [] + self.top_level = [] + new_stats = {} + for func, (cc, ns, tt, ct, callers) in self.stats.items(): + new_callers = {(k[0], k[1], k[2]): v for k, v in callers.items()} + new_stats[(func[0], func[1], func[2])] = (cc, ns, tt, ct, new_callers) + new_timings = {} + for func, (cc, ns, tt, ct, callers) in self.timings.items(): + new_callers = {(k[0], k[1], k[2]): v for k, v in callers.items()} + new_timings[(func[0], func[1], func[2])] = (cc, ns, tt, ct, new_callers) + self.stats = new_stats + self.timings = new_timings + + def dump_stats(self, file: str) -> None: + import marshal + + with Path(file).open("wb") as f: + marshal.dump(self.stats, f) + + def create_stats(self) -> None: + self.simulate_cmd_complete() + self.snapshot_stats() + + def snapshot_stats(self) -> None: + self.stats = {} + for func, (cc, _ns, tt, ct, caller_dict) in list(self.timings.items()): + callers = caller_dict.copy() + nc = 0 + for callcnt in callers.values(): + nc += callcnt + self.stats[func] = cc, nc, tt, ct, callers + + def sanitize_to_filename(self, arg: str) -> str: + # Replace newlines with underscores + arg = arg.replace("\n", "_").replace("\r", "_") + + # Replace contiguous whitespace (including tabs and multiple spaces) with a single underscore + # Limit to 5 whitespace splits + parts = re.split(r"\s+", arg) + if len(parts) > 5: + parts = parts[:5] + + arg = "_".join(parts) + + # Remove all characters that are not alphanumeric, underscore, or dot + arg = re.sub(r"[^\w._]", "", arg) + + # 
Avoid filenames starting or ending with a dot or underscore + arg = arg.strip("._") + + # Limit to 100 characters + arg = arg[:100] + + # Fallback if resulting name is empty + return arg or "untitled" + + def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, Any]) -> Tracer | None: + self.__enter__() + try: + exec(cmd, global_vars, local_vars) # noqa: S102 + finally: + self.__exit__(None, None, None) + return self + + +if __name__ == "__main__": + args_dict = json.loads(sys.argv[-1]) + sys.argv = sys.argv[1:-1] + if args_dict["module"]: + import runpy + + code = "run_module(modname, run_name='__main__')" + globs = {"run_module": runpy.run_module, "modname": args_dict["progname"]} + else: + sys.path.insert(0, str(Path(args_dict["progname"]).resolve().parent)) + with io.open_code(args_dict["progname"]) as fp: + code = compile(fp.read(), args_dict["progname"], "exec") + spec = importlib.machinery.ModuleSpec(name="__main__", loader=None, origin=args_dict["progname"]) + globs = { + "__spec__": spec, + "__file__": spec.origin, + "__name__": spec.name, + "__package__": None, + "__cached__": None, + } + args_dict["config"]["module_root"] = Path(args_dict["config"]["module_root"]) + args_dict["config"]["tests_root"] = Path(args_dict["config"]["tests_root"]) + tracer = Tracer( + config=args_dict["config"], + output=Path(args_dict["output"]), + functions=args_dict["functions"], + max_function_count=args_dict["max_function_count"], + timeout=args_dict["timeout"], + command=args_dict["command"], + disable=args_dict["disable"], + result_pickle_file_path=Path(args_dict["result_pickle_file_path"]), + project_root=Path(args_dict["project_root"]), + ) + tracer.runctx(code, globs, None) diff --git a/codeflash/tracing/tracing_utils.py b/codeflash/tracing/tracing_utils.py index 2e7096963..0114ea01c 100644 --- a/codeflash/tracing/tracing_utils.py +++ b/codeflash/tracing/tracing_utils.py @@ -1,15 +1,69 @@ from __future__ import annotations +import site +from dataclasses import dataclass +from functools import cache from pathlib import Path -from typing import Optional +from typing import Optional, cast -from pydantic import dataclasses +import git -@dataclasses.dataclass +# This can't be pydantic dataclass because then conflicts with the logfire pytest plugin +# for pydantic tracing. We want to not use pydantic in the tracing code. 
+@dataclass class FunctionModules: function_name: str file_name: Path module_name: str class_name: Optional[str] = None line_no: Optional[int] = None + + +def path_belongs_to_site_packages(file_path: Path) -> bool: + site_packages = [Path(p) for p in site.getsitepackages()] + return any(file_path.resolve().is_relative_to(site_package_path) for site_package_path in site_packages) + + +def is_git_repo(file_path: str) -> bool: + try: + git.Repo(file_path, search_parent_directories=True) + return True # noqa: TRY300 + except git.InvalidGitRepositoryError: + return False + + +@cache +def ignored_submodule_paths(module_root: str) -> list[Path]: + if is_git_repo(module_root): + git_repo = git.Repo(module_root, search_parent_directories=True) + working_tree_dir = cast("Path", git_repo.working_tree_dir) + return [Path(working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules] + return [] + + +def module_name_from_file_path(file_path: Path, project_root_path: Path) -> str: + relative_path = file_path.relative_to(project_root_path) + return relative_path.with_suffix("").as_posix().replace("/", ".") + + +def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list[Path], module_root: Path) -> bool: + """Optimized version of the filter_functions function above. + + Takes in file paths and returns the count of files that are to be optimized. + """ + submodule_paths = None + if file_path.is_relative_to(tests_root): + return False + if file_path in ignore_paths or any(file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths): + return False + if path_belongs_to_site_packages(file_path): + return False + if not file_path.is_relative_to(module_root): + return False + if submodule_paths is None: + submodule_paths = ignored_submodule_paths(module_root) + return not ( + file_path in submodule_paths + or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths) + ) diff --git a/codeflash/verification/concolic_testing.py b/codeflash/verification/concolic_testing.py index be8bc09da..3fe72a764 100644 --- a/codeflash/verification/concolic_testing.py +++ b/codeflash/verification/concolic_testing.py @@ -84,7 +84,7 @@ def generate_concolic_tests( test_framework=args.test_framework, pytest_cmd=args.pytest_cmd, ) - function_to_concolic_tests, num_discovered_concolic_tests = discover_unit_tests(concolic_test_cfg) + function_to_concolic_tests, num_discovered_concolic_tests, _ = discover_unit_tests(concolic_test_cfg) logger.info( f"Created {num_discovered_concolic_tests} " f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} " diff --git a/tests/test_test_runner.py b/tests/test_test_runner.py index 5dc6df678..8c879b823 100644 --- a/tests/test_test_runner.py +++ b/tests/test_test_runner.py @@ -98,7 +98,7 @@ def test_sort(): assert results[0].did_pass, "Test did not pass as expected" result_file.unlink(missing_ok=True) - code = """import torch + code = """import torch_does_not_exist def sorter(arr): print(torch.ones(1)) arr.sort() @@ -143,5 +143,5 @@ def test_sort(): test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process ) match = ImportErrorPattern.search(process.stdout).group() - assert match == "ModuleNotFoundError: No module named 'torch'" + assert match == "ModuleNotFoundError: No module named 'torch_does_not_exist'" result_file.unlink(missing_ok=True) diff --git a/tests/test_tracer.py b/tests/test_tracer.py index 8708ebd32..f9c2ae23a 100644 --- a/tests/test_tracer.py +++ 
b/tests/test_tracer.py @@ -1,4 +1,6 @@ import contextlib +import dataclasses +import pickle import sqlite3 import sys import tempfile @@ -10,8 +12,8 @@ from unittest.mock import patch import pytest - -from codeflash.tracer import FakeCode, FakeFrame, Tracer +from codeflash.code_utils.config_parser import parse_config_file +from codeflash.tracing.tracing_new_process import FakeCode, FakeFrame, Tracer class TestFakeCode: @@ -46,15 +48,24 @@ def test_fake_frame_with_prior(self) -> None: assert fake_frame2.f_back == fake_frame1 +@dataclasses.dataclass +class TraceConfig: + trace_file: Path + trace_config: dict[str, Any] + result_pickle_file_path: Path + project_root: Path + command: str + + class TestTracer: @pytest.fixture - def temp_config_file(self) -> Generator[Path, None, None]: + def trace_config(self) -> Generator[Path, None, None]: """Create a temporary pyproject.toml config file.""" # Create a temporary directory structure temp_dir = Path(tempfile.mkdtemp()) tests_dir = temp_dir / "tests" tests_dir.mkdir(exist_ok=True) - + # Use the current working directory as module root so test files are included current_dir = Path.cwd() @@ -67,18 +78,25 @@ def temp_config_file(self) -> Generator[Path, None, None]: ignore-paths = [] """) config_path = Path(f.name) - yield config_path - import shutil - shutil.rmtree(temp_dir, ignore_errors=True) - - @pytest.fixture - def temp_trace_file(self) -> Generator[Path, None, None]: - """Create a temporary trace file path.""" with tempfile.NamedTemporaryFile(suffix=".trace", delete=False) as f: trace_path = Path(f.name) trace_path.unlink(missing_ok=True) # Remove the file, we just want the path - yield trace_path + replay_test_pkl_path = temp_dir / "replay_test.pkl" + config, found_config_path = parse_config_file(config_path) + trace_config = TraceConfig( + trace_file=trace_path, + trace_config=config, + result_pickle_file_path=replay_test_pkl_path, + project_root=current_dir, + command="pytest random", + ) + + yield trace_config + import shutil + + shutil.rmtree(temp_dir, ignore_errors=True) trace_path.unlink(missing_ok=True) + replay_test_pkl_path.unlink(missing_ok=True) @pytest.fixture(autouse=True) def reset_tracer_state(self) -> Generator[None, None, None]: @@ -91,78 +109,89 @@ def reset_tracer_state(self) -> Generator[None, None, None]: if hasattr(Tracer, "used_once"): delattr(Tracer, "used_once") - def test_tracer_disabled_by_environment(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_tracer_disabled_by_environment(self, trace_config: TraceConfig) -> None: """Test that tracer is disabled when CODEFLASH_TRACER_DISABLE is set.""" with patch.dict("os.environ", {"CODEFLASH_TRACER_DISABLE": "1"}): tracer = Tracer( - output=str(temp_trace_file), - config_file_path=temp_config_file + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, ) assert tracer.disable is True - def test_tracer_disabled_with_existing_profiler(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_tracer_disabled_with_existing_profiler(self, trace_config: TraceConfig) -> None: """Test that tracer is disabled when another profiler is running.""" + def dummy_profiler(_frame: object, _event: str, _arg: object) -> object: return dummy_profiler sys.setprofile(dummy_profiler) try: tracer = Tracer( - output=str(temp_trace_file), - config_file_path=temp_config_file + output=str(trace_config.trace_file), + 
config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, ) assert tracer.disable is True finally: sys.setprofile(None) - def test_tracer_initialization_normal(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_tracer_initialization_normal(self, trace_config: TraceConfig) -> None: """Test normal tracer initialization.""" tracer = Tracer( - output=str(temp_trace_file), + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, functions=["test_func"], max_function_count=128, timeout=10, - config_file_path=temp_config_file ) assert tracer.disable is False - assert tracer.output_file == temp_trace_file.resolve() + assert tracer.output_file == trace_config.trace_file.resolve() assert tracer.functions == ["test_func"] assert tracer.max_function_count == 128 assert tracer.timeout == 10 assert hasattr(tracer, "_db_lock") - assert getattr(tracer, "_db_lock") is not None + assert tracer._db_lock is not None - def test_tracer_timeout_validation(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_tracer_timeout_validation(self, trace_config: TraceConfig) -> None: with pytest.raises(AssertionError): Tracer( - output=str(temp_trace_file), + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, timeout=0, - config_file_path=temp_config_file ) with pytest.raises(AssertionError): Tracer( - output=str(temp_trace_file), + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, timeout=-5, - config_file_path=temp_config_file ) - def test_tracer_context_manager_disabled(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_tracer_context_manager_disabled(self, trace_config: TraceConfig) -> None: tracer = Tracer( - output=str(temp_trace_file), + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, disable=True, - config_file_path=temp_config_file ) with tracer: pass - assert not temp_trace_file.exists() + assert not trace_config.trace_file.exists() - - - def test_tracer_function_filtering(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_tracer_function_filtering(self, trace_config: TraceConfig) -> None: """Test that tracer respects function filtering.""" if hasattr(Tracer, "used_once"): delattr(Tracer, "used_once") @@ -174,17 +203,19 @@ def other_function() -> int: return 24 tracer = Tracer( - output=str(temp_trace_file), + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, functions=["test_function"], - config_file_path=temp_config_file ) with tracer: test_function() other_function() - if temp_trace_file.exists(): - con = sqlite3.connect(temp_trace_file) + if trace_config.trace_file.exists(): + con = sqlite3.connect(trace_config.trace_file) cursor = con.cursor() cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'") @@ -197,38 +228,41 @@ def other_function() -> int: con.close() - - def 
test_tracer_max_function_count(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_tracer_max_function_count(self, trace_config: TraceConfig) -> None: def counting_function(n: int) -> int: return n * 2 tracer = Tracer( - output=str(temp_trace_file), + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, max_function_count=3, - config_file_path=temp_config_file ) with tracer: for i in range(5): counting_function(i) - + assert tracer.trace_count <= 3, "Tracer should limit the number of traced functions to max_function_count" - def test_tracer_timeout_functionality(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_tracer_timeout_functionality(self, trace_config: TraceConfig) -> None: def slow_function() -> str: time.sleep(0.1) return "done" tracer = Tracer( - output=str(temp_trace_file), + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, timeout=1, # 1 second timeout - config_file_path=temp_config_file ) with tracer: slow_function() - def test_tracer_threading_safety(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_tracer_threading_safety(self, trace_config: TraceConfig) -> None: """Test that tracer works correctly with threading.""" results = [] @@ -236,8 +270,10 @@ def thread_function(n: int) -> None: results.append(n * 2) tracer = Tracer( - output=str(temp_trace_file), - config_file_path=temp_config_file + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, ) with tracer: @@ -252,30 +288,36 @@ def thread_function(n: int) -> None: assert len(results) == 3 - def test_simulate_call(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_simulate_call(self, trace_config: TraceConfig) -> None: """Test the simulate_call method.""" tracer = Tracer( - output=str(temp_trace_file), - config_file_path=temp_config_file + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, ) tracer.simulate_call("test_simulation") - def test_simulate_cmd_complete(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_simulate_cmd_complete(self, trace_config: TraceConfig) -> None: """Test the simulate_cmd_complete method.""" tracer = Tracer( - output=str(temp_trace_file), - config_file_path=temp_config_file + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, ) tracer.simulate_call("test") tracer.simulate_cmd_complete() - def test_runctx_method(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_runctx_method(self, trace_config: TraceConfig) -> None: """Test the runctx method for executing code with tracing.""" tracer = Tracer( - output=str(temp_trace_file), - config_file_path=temp_config_file + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, ) global_vars = {"x": 10} @@ -286,12 +328,12 @@ def test_runctx_method(self, temp_config_file: Path, 
temp_trace_file: Path) -> N assert result == tracer assert local_vars["y"] == 20 - def test_tracer_handles_class_methods(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_tracer_handles_class_methods(self, trace_config: TraceConfig) -> None: """Test that tracer correctly handles class methods.""" # Ensure tracer hasn't been used yet in this test if hasattr(Tracer, "used_once"): delattr(Tracer, "used_once") - + class TestClass: def instance_method(self) -> str: return "instance" @@ -305,8 +347,10 @@ def static_method() -> str: return "static" tracer = Tracer( - output=str(temp_trace_file), - config_file_path=temp_config_file + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, ) with tracer: @@ -314,22 +358,20 @@ def static_method() -> str: instance_result = obj.instance_method() class_result = TestClass.class_method() static_result = TestClass.static_method() - - - if temp_trace_file.exists(): - con = sqlite3.connect(temp_trace_file) + if trace_config.trace_file.exists(): + con = sqlite3.connect(trace_config.trace_file) cursor = con.cursor() - + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'") if cursor.fetchone(): # Query for all function calls cursor.execute("SELECT function, classname FROM function_calls") calls = cursor.fetchall() - + function_names = [call[0] for call in calls] class_names = [call[1] for call in calls if call[1] is not None] - + assert "instance_method" in function_names assert "class_method" in function_names assert "static_method" in function_names @@ -338,34 +380,33 @@ def static_method() -> str: pytest.fail("No function_calls table found in trace file") con.close() - - - - - def test_tracer_handles_exceptions_gracefully(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def test_tracer_handles_exceptions_gracefully(self, trace_config: TraceConfig) -> None: """Test that tracer handles exceptions in traced code gracefully.""" + def failing_function() -> None: raise ValueError("Test exception") tracer = Tracer( - output=str(temp_trace_file), - config_file_path=temp_config_file + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, ) with tracer, contextlib.suppress(ValueError): failing_function() - - - - - def test_tracer_with_complex_arguments(self, temp_config_file: Path, temp_trace_file: Path) -> None: - def complex_function(data_dict: dict[str, Any], nested_list: list[list[int]], func_arg: object = lambda x: x) -> int: + def test_tracer_with_complex_arguments(self, trace_config: TraceConfig) -> None: + def complex_function( + data_dict: dict[str, Any], nested_list: list[list[int]], func_arg: object = lambda x: x + ) -> int: return len(data_dict) + len(nested_list) tracer = Tracer( - output=str(temp_trace_file), - config_file_path=temp_config_file + output=str(trace_config.trace_file), + config=trace_config.trace_config, + project_root=trace_config.project_root, + result_pickle_file_path=trace_config.result_pickle_file_path, ) expected_dict = {"key": "value", "nested": {"inner": "data"}} @@ -373,14 +414,13 @@ def complex_function(data_dict: dict[str, Any], nested_list: list[list[int]], fu expected_func = lambda x: x * 2 with tracer: - complex_function( - expected_dict, - expected_list, - func_arg=expected_func - ) + 
complex_function(expected_dict, expected_list, func_arg=expected_func) + assert trace_config.result_pickle_file_path.exists() + pickled = pickle.load(trace_config.result_pickle_file_path.open("rb")) + assert pickled["replay_test_file_path"].exists() - if temp_trace_file.exists(): - con = sqlite3.connect(temp_trace_file) + if trace_config.trace_file.exists(): + con = sqlite3.connect(trace_config.trace_file) cursor = con.cursor() cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'") @@ -388,15 +428,15 @@ def complex_function(data_dict: dict[str, Any], nested_list: list[list[int]], fu cursor.execute("SELECT args FROM function_calls WHERE function = 'complex_function'") result = cursor.fetchone() assert result is not None, "Function complex_function should be traced" - + # Deserialize the arguments - import pickle + traced_args = pickle.loads(result[0]) - + assert "data_dict" in traced_args assert "nested_list" in traced_args assert "func_arg" in traced_args - + assert traced_args["data_dict"] == expected_dict assert traced_args["nested_list"] == expected_list assert callable(traced_args["func_arg"]) diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index f985f60fb..101aa2671 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -7,9 +7,9 @@ discover_unit_tests, filter_test_files_by_imports, ) +from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import TestsInFile, TestType from codeflash.verification.verification_utils import TestConfig -from codeflash.discovery.functions_to_optimize import FunctionToOptimize def test_unit_test_discovery_pytest(): @@ -21,7 +21,7 @@ def test_unit_test_discovery_pytest(): test_framework="pytest", tests_project_rootdir=tests_path.parent, ) - tests, _ = discover_unit_tests(test_config) + tests, _, _ = discover_unit_tests(test_config) assert len(tests) > 0 @@ -34,8 +34,8 @@ def test_benchmark_test_discovery_pytest(): test_framework="pytest", tests_project_rootdir=tests_path.parent, ) - tests, _ = discover_unit_tests(test_config) - assert len(tests) == 1 # Should not discover benchmark tests + tests, _, _ = discover_unit_tests(test_config) + assert len(tests) == 1 # Should not discover benchmark tests def test_unit_test_discovery_unittest(): @@ -48,10 +48,11 @@ def test_unit_test_discovery_unittest(): tests_project_rootdir=project_path.parent, ) os.chdir(project_path) - tests, _ = discover_unit_tests(test_config) + tests, _, _ = discover_unit_tests(test_config) # assert len(tests) > 0 # Unittest discovery within a pytest environment does not work + def test_benchmark_unit_test_discovery_pytest(): with tempfile.TemporaryDirectory() as tmpdirname: # Create a dummy test file @@ -86,14 +87,15 @@ def sorter(arr): ) # Discover tests - tests, _ = discover_unit_tests(test_config) + tests, _, _ = discover_unit_tests(test_config) assert len(tests) == 1 - assert 'bubble_sort.sorter' in tests - assert len(tests['bubble_sort.sorter']) == 2 - functions = [test.tests_in_file.test_function for test in tests['bubble_sort.sorter']] - assert 'test_normal_test' in functions - assert 'test_normal_test2' in functions - assert 'test_benchmark_sort' not in functions + assert "bubble_sort.sorter" in tests + assert len(tests["bubble_sort.sorter"]) == 2 + functions = [test.tests_in_file.test_function for test in tests["bubble_sort.sorter"]] + assert "test_normal_test" in functions + assert "test_normal_test2" in functions + assert 
"test_benchmark_sort" not in functions + def test_discover_tests_pytest_with_temp_dir_root(): with tempfile.TemporaryDirectory() as tmpdirname: @@ -125,14 +127,17 @@ def test_discover_tests_pytest_with_temp_dir_root(): ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Check if the dummy test file is discovered assert len(discovered_tests) == 1 assert len(discovered_tests["dummy_code.dummy_function"]) == 2 dummy_tests = discovered_tests["dummy_code.dummy_function"] assert all(test.tests_in_file.test_file == test_file_path for test in dummy_tests) - assert {test.tests_in_file.test_function for test in dummy_tests} == {"test_dummy_parametrized_function[True]", "test_dummy_function"} + assert {test.tests_in_file.test_function for test in dummy_tests} == { + "test_dummy_parametrized_function[True]", + "test_dummy_function", + } def test_discover_tests_pytest_with_multi_level_dirs(): @@ -195,13 +200,14 @@ def test_discover_tests_pytest_with_multi_level_dirs(): ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Check if the test files at all levels are discovered assert len(discovered_tests) == 3 assert next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file == root_test_file_path assert ( - next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file == level1_test_file_path + next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file + == level1_test_file_path ) assert ( @@ -285,13 +291,14 @@ def test_discover_tests_pytest_dirs(): ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Check if the test files at all levels are discovered assert len(discovered_tests) == 4 assert next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file == root_test_file_path assert ( - next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file == level1_test_file_path + next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file + == level1_test_file_path ) assert ( next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file @@ -331,11 +338,14 @@ def test_discover_tests_pytest_with_class(): ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Check if the test class and method are discovered assert len(discovered_tests) == 1 - assert next(iter(discovered_tests["some_class_code.SomeClass.some_method"])).tests_in_file.test_file == test_file_path + assert ( + next(iter(discovered_tests["some_class_code.SomeClass.some_method"])).tests_in_file.test_file + == test_file_path + ) def test_discover_tests_pytest_with_double_nested_directories(): @@ -369,12 +379,14 @@ def test_discover_tests_pytest_with_double_nested_directories(): ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Check if the test class and method are discovered assert len(discovered_tests) == 1 assert ( - next(iter(discovered_tests["nested.more_nested.nested_class_code.NestedClass.nested_method"])).tests_in_file.test_file + next( + iter(discovered_tests["nested.more_nested.nested_class_code.NestedClass.nested_method"]) + 
).tests_in_file.test_file == test_file_path ) @@ -417,7 +429,7 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir(): ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Check if the test file is discovered and associated with the code file assert len(discovered_tests) == 1 @@ -430,10 +442,7 @@ def test_discover_tests_pytest_with_nested_class(): # Create a code file with a nested class code_file_path = path_obj_tmpdirname / "nested_class_code.py" code_file_content = ( - "class OuterClass:\n" - " class InnerClass:\n" - " def inner_method(self):\n" - " return True\n" + "class OuterClass:\n class InnerClass:\n def inner_method(self):\n return True\n" ) code_file_path.write_text(code_file_content) @@ -456,7 +465,7 @@ def test_discover_tests_pytest_with_nested_class(): ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Check if the test for the nested class method is discovered assert len(discovered_tests) == 1 @@ -496,7 +505,7 @@ def test_discover_tests_pytest_separate_moduledir(): ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Check if the test for the nested class method is discovered assert len(discovered_tests) == 1 @@ -538,7 +547,7 @@ def test_add(self): ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Verify the unittest was discovered assert len(discovered_tests) == 1 @@ -606,7 +615,7 @@ def test_add(self): ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Verify the unittest was discovered assert len(discovered_tests) == 2 @@ -652,7 +661,7 @@ def _test_add(self): # Private test method should not be discovered ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Verify no tests were discovered assert len(discovered_tests) == 0 @@ -704,7 +713,7 @@ def test_add_with_parameters(self): ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Verify the unittest was discovered assert len(discovered_tests) == 1 @@ -712,9 +721,7 @@ def test_add_with_parameters(self): assert len(discovered_tests["calculator.Calculator.add"]) == 1 calculator_test = next(iter(discovered_tests["calculator.Calculator.add"])) assert calculator_test.tests_in_file.test_file == test_file_path - assert ( - calculator_test.tests_in_file.test_function == "test_add_with_parameters" - ) + assert calculator_test.tests_in_file.test_function == "test_add_with_parameters" def test_unittest_discovery_with_pytest_parameterized(): @@ -787,7 +794,7 @@ def test_add_mixed(self, name, a, b, expected): ) # Discover tests - discovered_tests, _ = discover_unit_tests(test_config) + discovered_tests, _, _ = discover_unit_tests(test_config) # Verify the basic structure assert len(discovered_tests) == 2 # Should have tests for both add and multiply @@ -809,10 +816,10 @@ def test_target(): assert target_function() is True """ test_file.write_text(test_content) - + target_functions = {"target_function", "missing_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - + assert should_process is True @@ 
-827,12 +834,11 @@ def test_something(): assert something() is True """ test_file.write_text(test_content) - + target_functions = {"target_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - - assert should_process is False + assert should_process is False with tempfile.TemporaryDirectory() as tmpdirname: test_file = Path(tmpdirname) / "test_example.py" @@ -844,13 +850,11 @@ def test_target(): """ test_file.group test_file.write_text(test_content) - + target_functions = {"mymodule.target_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - - assert should_process is True - + assert should_process is True with tempfile.TemporaryDirectory() as tmpdirname: test_file = Path(tmpdirname) / "test_example.py" @@ -861,13 +865,12 @@ def test_target(): assert target_function_extended() is True """ test_file.write_text(test_content) - - # Should not match - target_function != target_function_extended + + # Should not match - target_function != target_function_extended target_functions = {"mymodule.target_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - - assert should_process is False + assert should_process is False with tempfile.TemporaryDirectory() as tmpdirname: test_file = Path(tmpdirname) / "test_example.py" @@ -879,12 +882,11 @@ def test_something(): assert x == 42 """ test_file.write_text(test_content) - + target_functions = {"mymodule.target_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - - assert should_process is False + assert should_process is False with tempfile.TemporaryDirectory() as tmpdirname: test_file = Path(tmpdirname) / "test_example.py" @@ -896,14 +898,13 @@ def test_something(): assert "target_function" in message """ test_file.write_text(test_content) - + target_functions = {"mymodule.target_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - + # String literals are ast.Constant nodes, not ast.Name nodes, so they don't match assert should_process is False - with tempfile.TemporaryDirectory() as tmpdirname: test_file = Path(tmpdirname) / "test_example.py" test_content = """ @@ -915,13 +916,11 @@ def test_target(): assert other_func() is True """ test_file.write_text(test_content) - + target_functions = {"mymodule.target_function", "othermodule.other_func"} should_process = analyze_imports_in_test_file(test_file, target_functions) - - assert should_process is True - + assert should_process is True def test_analyze_imports_module_import(): @@ -935,10 +934,10 @@ def test_target(): assert mymodule.target_function() is True """ test_file.write_text(test_content) - + target_functions = {"target_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - + assert should_process is True @@ -954,10 +953,10 @@ def test_dynamic(): assert module.target_function() is True """ test_file.write_text(test_content) - + target_functions = {"target_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - + assert should_process is True @@ -971,10 +970,10 @@ def test_builtin_import(): assert module.target_function() is True """ test_file.write_text(test_content) - + target_functions = {"target_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - + assert should_process is True @@ -989,7 +988,7 @@ def test_unrelated(): assert unrelated_function() is True """ test_file.write_text(test_content) - + target_functions = 
{"target_function", "another_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is False @@ -1005,13 +1004,12 @@ def test_target(): assert some_function() is True """ test_file.write_text(test_content) - + target_functions = {"target_module.some_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) assert should_process is True - def test_analyze_imports_syntax_error(): """Test handling of files with syntax errors.""" with tempfile.TemporaryDirectory() as tmpdirname: @@ -1023,10 +1021,10 @@ def test_target( assert target_function() is True """ test_file.write_text(test_content) - + target_functions = {"target_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - + # Should be conservative with unparseable files assert should_process is True @@ -1034,7 +1032,7 @@ def test_target( def test_filter_test_files_by_imports(): with tempfile.TemporaryDirectory() as tmpdirname: tmpdir = Path(tmpdirname) - + # Create test file that imports target function relevant_test = tmpdir / "test_relevant.py" relevant_test.write_text(""" @@ -1043,7 +1041,7 @@ def test_filter_test_files_by_imports(): def test_target(): assert target_function() is True """) - + # Create test file that doesn't import target function irrelevant_test = tmpdir / "test_irrelevant.py" irrelevant_test.write_text(""" @@ -1052,7 +1050,7 @@ def test_target(): def test_other(): assert other_function() is True """) - + # Create test file with star import (should not be processed) star_test = tmpdir / "test_star.py" star_test.write_text(""" @@ -1061,38 +1059,65 @@ def test_other(): def test_star(): assert something() is True """) - + file_to_test_map = { - relevant_test: [TestsInFile(test_file=relevant_test, test_function="test_target", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)], - irrelevant_test: [TestsInFile(test_file=irrelevant_test, test_function="test_other", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)], - star_test: [TestsInFile(test_file=star_test, test_function="test_star", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)], + relevant_test: [ + TestsInFile( + test_file=relevant_test, + test_function="test_target", + test_class=None, + test_type=TestType.EXISTING_UNIT_TEST, + ) + ], + irrelevant_test: [ + TestsInFile( + test_file=irrelevant_test, + test_function="test_other", + test_class=None, + test_type=TestType.EXISTING_UNIT_TEST, + ) + ], + star_test: [ + TestsInFile( + test_file=star_test, + test_function="test_star", + test_class=None, + test_type=TestType.EXISTING_UNIT_TEST, + ) + ], } - + target_functions = {"target_function"} filtered_map = filter_test_files_by_imports(file_to_test_map, target_functions) - + # Should filter out irrelevant_test assert len(filtered_map) == 1 assert relevant_test in filtered_map assert irrelevant_test not in filtered_map - def test_filter_test_files_no_target_functions(): """Test that filtering is skipped when no target functions are provided.""" with tempfile.TemporaryDirectory() as tmpdirname: tmpdir = Path(tmpdirname) - + test_file = tmpdir / "test_example.py" test_file.write_text("def test_something(): pass") - + file_to_test_map = { - test_file: [TestsInFile(test_file=test_file, test_function="test_something", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)] + test_file: [ + TestsInFile( + test_file=test_file, + test_function="test_something", + test_class=None, + test_type=TestType.EXISTING_UNIT_TEST, + ) + ] } - + # No target 
functions provided - filtered_map = filter_test_files_by_imports(file_to_test_map, set()) - + filtered_map = filter_test_files_by_imports(file_to_test_map, set()) + # Should return original map unchanged assert filtered_map == file_to_test_map @@ -1101,7 +1126,7 @@ def test_discover_unit_tests_with_import_filtering(): """Test the full discovery process with import filtering.""" with tempfile.TemporaryDirectory() as tmpdirname: tmpdir = Path(tmpdirname) - + # Create a code file code_file = tmpdir / "mycode.py" code_file.write_text(""" @@ -1111,7 +1136,7 @@ def target_function(): def other_function(): return False """) - + # Create relevant test file relevant_test = tmpdir / "test_relevant.py" relevant_test.write_text(""" @@ -1120,7 +1145,7 @@ def other_function(): def test_target(): assert target_function() is True """) - + # Create irrelevant test file irrelevant_test = tmpdir / "test_irrelevant.py" irrelevant_test.write_text(""" @@ -1129,26 +1154,18 @@ def test_target(): def test_other(): assert other_function() is False """) - + # Configure test discovery test_config = TestConfig( - tests_root=tmpdir, - project_root_path=tmpdir, - test_framework="pytest", - tests_project_rootdir=tmpdir.parent, - ) - - all_tests, _ = discover_unit_tests(test_config) - assert len(all_tests) == 2 - - - fto = FunctionToOptimize( - function_name="target_function", - file_path=code_file, - parents=[], + tests_root=tmpdir, project_root_path=tmpdir, test_framework="pytest", tests_project_rootdir=tmpdir.parent ) - filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={code_file: [fto]}) + all_tests, _, _ = discover_unit_tests(test_config) + assert len(all_tests) == 2 + + fto = FunctionToOptimize(function_name="target_function", file_path=code_file, parents=[]) + + filtered_tests, _, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={code_file: [fto]}) assert len(filtered_tests) >= 1 assert "mycode.target_function" in filtered_tests @@ -1164,10 +1181,10 @@ def test_conditional(): assert target_function() is True """ test_file.write_text(test_content) - + target_functions = {"target_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - + assert should_process is True @@ -1186,10 +1203,10 @@ def test_indirect(): assert result is True """ test_file.write_text(test_content) - + target_functions = {"target_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - + assert should_process is True @@ -1205,10 +1222,10 @@ def test_aliased(): assert of() is False """ test_file.write_text(test_content) - + target_functions = {"target_function", "missing_function"} should_process = analyze_imports_in_test_file(test_file, target_functions) - + assert should_process is True @@ -1222,31 +1239,32 @@ def test_bubble(): assert sort_function([3,1,2]) == [1,2,3] """ test_file.write_text(test_content) - + target_functions = {"bubble_sort"} should_process = analyze_imports_in_test_file(test_file, target_functions) - + assert should_process is False + def test_discover_unit_tests_filtering_different_modules(): """Test import filtering with test files from completely different modules.""" with tempfile.TemporaryDirectory() as tmpdirname: tmpdir = Path(tmpdirname) - + # Create target code file target_file = tmpdir / "target_module.py" target_file.write_text(""" def target_function(): return True """) - + # Create unrelated code file unrelated_file = tmpdir / "unrelated_module.py" unrelated_file.write_text(""" def unrelated_function(): 
return False """) - + # Create test file that imports target function relevant_test = tmpdir / "test_target.py" relevant_test.write_text(""" @@ -1255,7 +1273,7 @@ def unrelated_function(): def test_target(): assert target_function() is True """) - + # Create test file that imports unrelated function irrelevant_test = tmpdir / "test_unrelated.py" irrelevant_test.write_text(""" @@ -1264,26 +1282,19 @@ def test_target(): def test_unrelated(): assert unrelated_function() is False """) - + # Configure test discovery test_config = TestConfig( - tests_root=tmpdir, - project_root_path=tmpdir, - test_framework="pytest", - tests_project_rootdir=tmpdir.parent, + tests_root=tmpdir, project_root_path=tmpdir, test_framework="pytest", tests_project_rootdir=tmpdir.parent ) - + # Test without filtering - all_tests, _ = discover_unit_tests(test_config) + all_tests, _, _ = discover_unit_tests(test_config) assert len(all_tests) == 2 # Should find both functions - - fto = FunctionToOptimize( - function_name="target_function", - file_path=target_file, - parents=[], - ) - filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={target_file: [fto]}) + fto = FunctionToOptimize(function_name="target_function", file_path=target_file, parents=[]) + + filtered_tests, _, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={target_file: [fto]}) assert len(filtered_tests) == 1 assert "target_module.target_function" in filtered_tests assert "unrelated_module.unrelated_function" not in filtered_tests
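
A few usage sketches for the pieces introduced above. First, every call site in this change now unpacks three values from discover_unit_tests; a minimal consumption sketch, assuming a project laid out with a tests/ directory next to the code (the paths and the print wording are illustrative, not taken from the codebase):

from pathlib import Path

from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.verification.verification_utils import TestConfig

# Illustrative layout; any project root with a tests/ directory works.
project_root = Path.cwd().resolve()
test_config = TestConfig(
    tests_root=project_root / "tests",
    project_root_path=project_root,
    test_framework="pytest",
    tests_project_rootdir=project_root,
)

# The call now returns the function -> tests mapping, the total number of
# discovered tests, and a third counter (the replay-test count in this change).
function_to_tests, num_tests, num_replay_tests = discover_unit_tests(test_config)
print(f"{num_tests} tests discovered ({num_replay_tests} replay) "
      f"covering {len(function_to_tests)} functions")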
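
Second, filter_files_optimized moved into codeflash.tracing.tracing_utils. Despite its docstring talking about a count, it answers a per-file yes/no question; a small usage sketch with illustrative paths:

from pathlib import Path

from codeflash.tracing.tracing_utils import filter_files_optimized

# Illustrative layout: the current directory is the module root, with tests/ inside it.
module_root = Path.cwd().resolve()
tests_root = module_root / "tests"
candidate = module_root / "example.py"

# True means "trace/optimize calls from this file"; False means skip it
# (test files, ignored paths, site-packages, files outside module_root,
# and files inside git submodules are all filtered out).
if filter_files_optimized(candidate, tests_root, ignore_paths=[], module_root=module_root):
    print(f"{candidate} would be traced")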
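
Third, the new __main__ entry point reads its entire configuration as a JSON blob in the last argv element and re-hydrates module_root/tests_root into Path objects itself. A sketch of how a parent process might assemble that payload and spawn the tracer out of process; the trace_script helper, the default cap of 256, and the assumption that the entry point lives in codeflash.tracing.tracing_new_process (as the test import suggests) are all illustrative, not part of this change:

import json
import subprocess
import sys
from pathlib import Path

from codeflash.code_utils.config_parser import parse_config_file


def trace_script(script: Path, pyproject_toml: Path, trace_file: Path, result_pickle: Path) -> None:
    # parse_config_file returns (config_dict, config_path); only the dict is needed here.
    config, _config_path = parse_config_file(pyproject_toml)
    args_dict = {
        "module": False,                     # run a script, not `-m module`
        "progname": str(script),
        "config": {
            **config,
            "module_root": str(config["module_root"]),  # re-hydrated to Path by the entry point
            "tests_root": str(config["tests_root"]),
        },
        "output": str(trace_file),
        "functions": None,                   # None -> trace everything (assumed default)
        "max_function_count": 256,           # assumed per-function cap
        "timeout": None,
        "command": f"python {script}",
        "disable": False,
        "result_pickle_file_path": str(result_pickle),
        "project_root": str(script.resolve().parent),
    }
    # The entry point drops argv[0] and the trailing JSON blob before executing the
    # target, so the traced script sees only its own argv (starting with its path).
    subprocess.run(
        [sys.executable, "-m", "codeflash.tracing.tracing_new_process", str(script), json.dumps(args_dict)],
        check=True,
    )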