diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index 7373b3fca..8642d6558 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -32,10 +32,4 @@ jobs: run: uvx poetry install --with dev - name: Unit tests - run: uvx poetry run pytest tests/ --cov --cov-report=xml --benchmark-skip -m "not ci_skip" - - - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v5 - if: matrix.python-version == '3.12.1' - with: - token: ${{ secrets.CODECOV_TOKEN }} + run: uvx poetry run pytest tests/ --benchmark-skip -m "not ci_skip" \ No newline at end of file diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index c50a0ad49..3443e1882 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -599,7 +599,14 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool: - return any(isinstance(node, ast.Return) for node in ast.walk(function_node)) + # Custom DFS, return True as soon as a Return node is found + stack = [function_node] + while stack: + node = stack.pop() + if isinstance(node, ast.Return): + return True + stack.extend(ast.iter_child_nodes(node)) + return False def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool: diff --git a/codeflash/tracer.py b/codeflash/tracer.py index c06cbe949..9fa1f3290 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -11,6 +11,7 @@ # from __future__ import annotations +import contextlib import importlib.machinery import io import json @@ -99,6 +100,7 @@ def __init__( ) disable = True self.disable = disable + self._db_lock: threading.Lock | None = None if self.disable: return if sys.getprofile() is not None or sys.gettrace() is not None: @@ -108,6 +110,9 @@ def __init__( ) self.disable = True return + + self._db_lock = threading.Lock() + self.con = None self.output_file = Path(output).resolve() self.functions = functions @@ -130,6 +135,7 @@ def __init__( self.timeout = timeout self.next_insert = 1000 self.trace_count = 0 + self.path_cache = {} # Cache for resolved file paths # Profiler variables self.bias = 0 # calibration constant @@ -178,34 +184,55 @@ def __enter__(self) -> None: def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: - if self.disable: + if self.disable or self._db_lock is None: return sys.setprofile(None) - self.con.commit() - console.rule("Codeflash: Traced Program Output End", style="bold blue") - self.create_stats() + threading.setprofile(None) - cur = self.con.cursor() - cur.execute( - "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, " - "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, " - "cumulative_time_ns INTEGER, callers BLOB)" - ) - for func, (cc, nc, tt, ct, callers) in self.stats.items(): - remapped_callers = [{"key": k, "value": v} for k, v in callers.items()] + with self._db_lock: + if self.con is None: + return + + self.con.commit() # Commit any pending from tracer_logic + console.rule("Codeflash: Traced Program Output End", style="bold blue") + self.create_stats() # This calls snapshot_stats which uses self.timings + + cur = self.con.cursor() cur.execute( - "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", - (str(Path(func[0]).resolve()), func[1], func[2], func[3], cc, nc, tt, ct, json.dumps(remapped_callers)), + "CREATE TABLE pstats (filename TEXT, line_number INTEGER, function TEXT, class_name TEXT, " + "call_count_nonrecursive INTEGER, num_callers INTEGER, total_time_ns INTEGER, " + "cumulative_time_ns INTEGER, callers BLOB)" ) - self.con.commit() + # self.stats is populated by snapshot_stats() called within create_stats() + # Ensure self.stats is accessed after create_stats() and within the lock if it involves DB data + # For now, assuming self.stats is primarily in-memory after create_stats() + for func, (cc, nc, tt, ct, callers) in self.stats.items(): + remapped_callers = [{"key": k, "value": v} for k, v in callers.items()] + cur.execute( + "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + str(Path(func[0]).resolve()), + func[1], + func[2], + func[3], + cc, + nc, + tt, + ct, + json.dumps(remapped_callers), + ), + ) + self.con.commit() - self.make_pstats_compatible() - self.print_stats("tottime") - cur = self.con.cursor() - cur.execute("CREATE TABLE total_time (time_ns INTEGER)") - cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,)) - self.con.commit() - self.con.close() + self.make_pstats_compatible() # Modifies self.stats and self.timings in-memory + self.print_stats("tottime") # Uses self.stats, prints to console + + cur = self.con.cursor() # New cursor + cur.execute("CREATE TABLE total_time (time_ns INTEGER)") + cur.execute("INSERT INTO total_time VALUES (?)", (self.total_tt,)) + self.con.commit() + self.con.close() + self.con = None # Mark connection as closed # filter any functions where we did not capture the return self.function_modules = [ @@ -245,18 +272,29 @@ def __exit__( def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911 if event != "call": return - if self.timeout is not None and (time.time() - self.start_time) > self.timeout: + if None is not self.timeout and (time.time() - self.start_time) > self.timeout: sys.setprofile(None) threading.setprofile(None) console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.") return - code = frame.f_code + if self.disable or self._db_lock is None or self.con is None: + return - file_name = Path(code.co_filename).resolve() - # TODO : It currently doesn't log the last return call from the first function + code = frame.f_code + # Check function name first before resolving path if code.co_name in self.ignored_functions: return + + # Now resolve file path only if we need it + co_filename = code.co_filename + if co_filename in self.path_cache: + file_name = self.path_cache[co_filename] + else: + file_name = Path(co_filename).resolve() + self.path_cache[co_filename] = file_name + # TODO : It currently doesn't log the last return call from the first function + if not file_name.is_relative_to(self.project_root): return if not file_name.exists(): @@ -266,18 +304,29 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911 class_name = None arguments = frame.f_locals try: - if ( - "self" in arguments - and hasattr(arguments["self"], "__class__") - and hasattr(arguments["self"].__class__, "__name__") - ): - class_name = arguments["self"].__class__.__name__ - elif "cls" in arguments and hasattr(arguments["cls"], "__name__"): - class_name = arguments["cls"].__name__ + self_arg = arguments.get("self") + if self_arg is not None: + try: + class_name = self_arg.__class__.__name__ + except AttributeError: + cls_arg = arguments.get("cls") + if cls_arg is not None: + with contextlib.suppress(AttributeError): + class_name = cls_arg.__name__ + else: + cls_arg = arguments.get("cls") + if cls_arg is not None: + with contextlib.suppress(AttributeError): + class_name = cls_arg.__name__ except: # noqa: E722 # someone can override the getattr method and raise an exception. I'm looking at you wrapt return - function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}" + + try: + function_qualified_name = f"{file_name}:{code.co_qualname}" + except AttributeError: + function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}" + if function_qualified_name in self.ignored_qualified_functions: return if function_qualified_name not in self.function_count: @@ -310,49 +359,59 @@ def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911 # TODO: Also check if this function arguments are unique from the values logged earlier - cur = self.con.cursor() + with self._db_lock: + # Check connection again inside lock, in case __exit__ closed it. + if self.con is None: + return - t_ns = time.perf_counter_ns() - original_recursion_limit = sys.getrecursionlimit() - try: - # pickling can be a recursive operator, so we need to increase the recursion limit - sys.setrecursionlimit(10000) - # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class - # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory - # leaks, bad references or side effects when unpickling. - arguments = dict(arguments.items()) - if class_name and code.co_name == "__init__": - del arguments["self"] - local_vars = pickle.dumps(arguments, protocol=pickle.HIGHEST_PROTOCOL) - sys.setrecursionlimit(original_recursion_limit) - except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError): - # we retry with dill if pickle fails. It's slower but more comprehensive + cur = self.con.cursor() + + t_ns = time.perf_counter_ns() + original_recursion_limit = sys.getrecursionlimit() try: - local_vars = dill.dumps(arguments, protocol=dill.HIGHEST_PROTOCOL) + # pickling can be a recursive operator, so we need to increase the recursion limit + sys.setrecursionlimit(10000) + # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class + # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory + # leaks, bad references or side effects when unpickling. + arguments_copy = dict(arguments.items()) # Use the local 'arguments' from frame.f_locals + if class_name and code.co_name == "__init__" and "self" in arguments_copy: + del arguments_copy["self"] + local_vars = pickle.dumps(arguments_copy, protocol=pickle.HIGHEST_PROTOCOL) sys.setrecursionlimit(original_recursion_limit) + except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError): + # we retry with dill if pickle fails. It's slower but more comprehensive + try: + sys.setrecursionlimit(10000) # Ensure limit is high for dill too + # arguments_copy should be used here as well if defined above + local_vars = dill.dumps( + arguments_copy if "arguments_copy" in locals() else dict(arguments.items()), + protocol=dill.HIGHEST_PROTOCOL, + ) + sys.setrecursionlimit(original_recursion_limit) + + except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError): + self.function_count[function_qualified_name] -= 1 + return - except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError): - # give up - self.function_count[function_qualified_name] -= 1 - return - cur.execute( - "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", - ( - event, - code.co_name, - class_name, - str(file_name), - frame.f_lineno, - frame.f_back.__hash__(), - t_ns, - local_vars, - ), - ) - self.trace_count += 1 - self.next_insert -= 1 - if self.next_insert == 0: - self.next_insert = 1000 - self.con.commit() + cur.execute( + "INSERT INTO function_calls VALUES(?, ?, ?, ?, ?, ?, ?, ?)", + ( + event, + code.co_name, + class_name, + str(file_name), + frame.f_lineno, + frame.f_back.__hash__(), + t_ns, + local_vars, + ), + ) + self.trace_count += 1 + self.next_insert -= 1 + if self.next_insert == 0: + self.next_insert = 1000 + self.con.commit() def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None: # profiler section @@ -475,8 +534,9 @@ def trace_dispatch_return(self, frame: FrameType, t: int) -> int: cc = cc + 1 if pfn in callers: - callers[pfn] = callers[pfn] + 1 # TODO: gather more - # stats such as the amount of time added to ct courtesy + # Increment call count between these functions + callers[pfn] = callers[pfn] + 1 + # Note: This tracks stats such as the amount of time added to ct # of this specific call, and the contribution to cc # courtesy of this call. else: @@ -703,7 +763,7 @@ def create_stats(self) -> None: def snapshot_stats(self) -> None: self.stats = {} - for func, (cc, _ns, tt, ct, caller_dict) in self.timings.items(): + for func, (cc, _ns, tt, ct, caller_dict) in list(self.timings.items()): callers = caller_dict.copy() nc = 0 for callcnt in callers.values(): diff --git a/poetry.lock b/poetry.lock index 04cfeae09..a65546a92 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1704,25 +1704,6 @@ aspect = ["aspectlib"] elasticsearch = ["elasticsearch"] histogram = ["pygal", "pygaljs", "setuptools"] -[[package]] -name = "pytest-cov" -version = "6.1.1" -description = "Pytest plugin for measuring coverage." -optional = false -python-versions = ">=3.9" -groups = ["dev"] -files = [ - {file = "pytest_cov-6.1.1-py3-none-any.whl", hash = "sha256:bddf29ed2d0ab6f4df17b4c55b0a657287db8684af9c42ea546b21b1041b3dde"}, - {file = "pytest_cov-6.1.1.tar.gz", hash = "sha256:46935f7aaefba760e716c2ebfbe1c216240b9592966e7da99ea8292d4d3e2a0a"}, -] - -[package.dependencies] -coverage = {version = ">=7.5", extras = ["toml"]} -pytest = ">=4.6" - -[package.extras] -testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] - [[package]] name = "pytest-timeout" version = "2.4.0" @@ -2686,4 +2667,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.9" -content-hash = "1a73e9db33e3884cf1cc6e3371816aebd20831845ef9bf671be315e659480e86" +content-hash = "2a3d21343a2bbe712cc2272f26fe6b7adcb793f9bcd99c67ce4fbc9709416218" diff --git a/pyproject.toml b/pyproject.toml index cb8f2c7d9..0cb1b4a72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,6 @@ types-cffi = ">=1.16.0.20240331" types-openpyxl = ">=3.1.5.20241020" types-regex = ">=2024.9.11.20240912" types-python-dateutil = ">=2.9.0.20241003" -pytest-cov = "^6.0.0" pytest-benchmark = ">=5.1.0" types-gevent = "^24.11.0.20241230" types-greenlet = "^3.1.0.20241221" diff --git a/tests/test_tracer.py b/tests/test_tracer.py new file mode 100644 index 000000000..8708ebd32 --- /dev/null +++ b/tests/test_tracer.py @@ -0,0 +1,407 @@ +import contextlib +import sqlite3 +import sys +import tempfile +import threading +import time +from collections.abc import Generator +from pathlib import Path +from typing import Any +from unittest.mock import patch + +import pytest + +from codeflash.tracer import FakeCode, FakeFrame, Tracer + + +class TestFakeCode: + def test_fake_code_initialization(self) -> None: + fake_code = FakeCode("test.py", 10, "test_function") + assert fake_code.co_filename == "test.py" + assert fake_code.co_line == 10 + assert fake_code.co_name == "test_function" + assert fake_code.co_firstlineno == 0 + + def test_fake_code_repr(self) -> None: + fake_code = FakeCode("test.py", 10, "test_function") + expected_repr = repr(("test.py", 10, "test_function", None)) + assert repr(fake_code) == expected_repr + + +class TestFakeFrame: + def test_fake_frame_initialization(self) -> None: + fake_code = FakeCode("test.py", 10, "test_function") + fake_frame = FakeFrame(fake_code, None) + assert fake_frame.f_code == fake_code + assert fake_frame.f_back is None + assert fake_frame.f_locals == {} + + def test_fake_frame_with_prior(self) -> None: + fake_code1 = FakeCode("test1.py", 5, "func1") + fake_code2 = FakeCode("test2.py", 10, "func2") + fake_frame1 = FakeFrame(fake_code1, None) + fake_frame2 = FakeFrame(fake_code2, fake_frame1) + + assert fake_frame2.f_code == fake_code2 + assert fake_frame2.f_back == fake_frame1 + + +class TestTracer: + @pytest.fixture + def temp_config_file(self) -> Generator[Path, None, None]: + """Create a temporary pyproject.toml config file.""" + # Create a temporary directory structure + temp_dir = Path(tempfile.mkdtemp()) + tests_dir = temp_dir / "tests" + tests_dir.mkdir(exist_ok=True) + + # Use the current working directory as module root so test files are included + current_dir = Path.cwd() + + with tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False, dir=temp_dir) as f: + f.write(f""" +[tool.codeflash] +module-root = "{current_dir}" +tests-root = "{tests_dir}" +test-framework = "pytest" +ignore-paths = [] +""") + config_path = Path(f.name) + yield config_path + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + @pytest.fixture + def temp_trace_file(self) -> Generator[Path, None, None]: + """Create a temporary trace file path.""" + with tempfile.NamedTemporaryFile(suffix=".trace", delete=False) as f: + trace_path = Path(f.name) + trace_path.unlink(missing_ok=True) # Remove the file, we just want the path + yield trace_path + trace_path.unlink(missing_ok=True) + + @pytest.fixture(autouse=True) + def reset_tracer_state(self) -> Generator[None, None, None]: + """Reset the tracer used_once state before each test.""" + # Reset the class variable if it exists + if hasattr(Tracer, "used_once"): + delattr(Tracer, "used_once") + yield + # Reset after test as well + if hasattr(Tracer, "used_once"): + delattr(Tracer, "used_once") + + def test_tracer_disabled_by_environment(self, temp_config_file: Path, temp_trace_file: Path) -> None: + """Test that tracer is disabled when CODEFLASH_TRACER_DISABLE is set.""" + with patch.dict("os.environ", {"CODEFLASH_TRACER_DISABLE": "1"}): + tracer = Tracer( + output=str(temp_trace_file), + config_file_path=temp_config_file + ) + assert tracer.disable is True + + def test_tracer_disabled_with_existing_profiler(self, temp_config_file: Path, temp_trace_file: Path) -> None: + """Test that tracer is disabled when another profiler is running.""" + def dummy_profiler(_frame: object, _event: str, _arg: object) -> object: + return dummy_profiler + + sys.setprofile(dummy_profiler) + try: + tracer = Tracer( + output=str(temp_trace_file), + config_file_path=temp_config_file + ) + assert tracer.disable is True + finally: + sys.setprofile(None) + + def test_tracer_initialization_normal(self, temp_config_file: Path, temp_trace_file: Path) -> None: + """Test normal tracer initialization.""" + tracer = Tracer( + output=str(temp_trace_file), + functions=["test_func"], + max_function_count=128, + timeout=10, + config_file_path=temp_config_file + ) + + assert tracer.disable is False + assert tracer.output_file == temp_trace_file.resolve() + assert tracer.functions == ["test_func"] + assert tracer.max_function_count == 128 + assert tracer.timeout == 10 + assert hasattr(tracer, "_db_lock") + assert getattr(tracer, "_db_lock") is not None + + def test_tracer_timeout_validation(self, temp_config_file: Path, temp_trace_file: Path) -> None: + with pytest.raises(AssertionError): + Tracer( + output=str(temp_trace_file), + timeout=0, + config_file_path=temp_config_file + ) + + with pytest.raises(AssertionError): + Tracer( + output=str(temp_trace_file), + timeout=-5, + config_file_path=temp_config_file + ) + + def test_tracer_context_manager_disabled(self, temp_config_file: Path, temp_trace_file: Path) -> None: + tracer = Tracer( + output=str(temp_trace_file), + disable=True, + config_file_path=temp_config_file + ) + + with tracer: + pass + + assert not temp_trace_file.exists() + + + + def test_tracer_function_filtering(self, temp_config_file: Path, temp_trace_file: Path) -> None: + """Test that tracer respects function filtering.""" + if hasattr(Tracer, "used_once"): + delattr(Tracer, "used_once") + + def test_function() -> int: + return 42 + + def other_function() -> int: + return 24 + + tracer = Tracer( + output=str(temp_trace_file), + functions=["test_function"], + config_file_path=temp_config_file + ) + + with tracer: + test_function() + other_function() + + if temp_trace_file.exists(): + con = sqlite3.connect(temp_trace_file) + cursor = con.cursor() + + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'") + if cursor.fetchone(): + cursor.execute("SELECT function FROM function_calls WHERE function = 'test_function'") + cursor.fetchall() + + cursor.execute("SELECT function FROM function_calls WHERE function = 'other_function'") + cursor.fetchall() + + con.close() + + + def test_tracer_max_function_count(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def counting_function(n: int) -> int: + return n * 2 + + tracer = Tracer( + output=str(temp_trace_file), + max_function_count=3, + config_file_path=temp_config_file + ) + + with tracer: + for i in range(5): + counting_function(i) + + assert tracer.trace_count <= 3, "Tracer should limit the number of traced functions to max_function_count" + + def test_tracer_timeout_functionality(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def slow_function() -> str: + time.sleep(0.1) + return "done" + + tracer = Tracer( + output=str(temp_trace_file), + timeout=1, # 1 second timeout + config_file_path=temp_config_file + ) + + with tracer: + slow_function() + + def test_tracer_threading_safety(self, temp_config_file: Path, temp_trace_file: Path) -> None: + """Test that tracer works correctly with threading.""" + results = [] + + def thread_function(n: int) -> None: + results.append(n * 2) + + tracer = Tracer( + output=str(temp_trace_file), + config_file_path=temp_config_file + ) + + with tracer: + threads = [] + for i in range(3): + thread = threading.Thread(target=thread_function, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert len(results) == 3 + + def test_simulate_call(self, temp_config_file: Path, temp_trace_file: Path) -> None: + """Test the simulate_call method.""" + tracer = Tracer( + output=str(temp_trace_file), + config_file_path=temp_config_file + ) + + tracer.simulate_call("test_simulation") + + def test_simulate_cmd_complete(self, temp_config_file: Path, temp_trace_file: Path) -> None: + """Test the simulate_cmd_complete method.""" + tracer = Tracer( + output=str(temp_trace_file), + config_file_path=temp_config_file + ) + + tracer.simulate_call("test") + tracer.simulate_cmd_complete() + + def test_runctx_method(self, temp_config_file: Path, temp_trace_file: Path) -> None: + """Test the runctx method for executing code with tracing.""" + tracer = Tracer( + output=str(temp_trace_file), + config_file_path=temp_config_file + ) + + global_vars = {"x": 10} + local_vars = {} + + result = tracer.runctx("y = x * 2", global_vars, local_vars) + + assert result == tracer + assert local_vars["y"] == 20 + + def test_tracer_handles_class_methods(self, temp_config_file: Path, temp_trace_file: Path) -> None: + """Test that tracer correctly handles class methods.""" + # Ensure tracer hasn't been used yet in this test + if hasattr(Tracer, "used_once"): + delattr(Tracer, "used_once") + + class TestClass: + def instance_method(self) -> str: + return "instance" + + @classmethod + def class_method(cls) -> str: + return "class" + + @staticmethod + def static_method() -> str: + return "static" + + tracer = Tracer( + output=str(temp_trace_file), + config_file_path=temp_config_file + ) + + with tracer: + obj = TestClass() + instance_result = obj.instance_method() + class_result = TestClass.class_method() + static_result = TestClass.static_method() + + + + if temp_trace_file.exists(): + con = sqlite3.connect(temp_trace_file) + cursor = con.cursor() + + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'") + if cursor.fetchone(): + # Query for all function calls + cursor.execute("SELECT function, classname FROM function_calls") + calls = cursor.fetchall() + + function_names = [call[0] for call in calls] + class_names = [call[1] for call in calls if call[1] is not None] + + assert "instance_method" in function_names + assert "class_method" in function_names + assert "static_method" in function_names + assert "TestClass" in class_names + else: + pytest.fail("No function_calls table found in trace file") + con.close() + + + + + + def test_tracer_handles_exceptions_gracefully(self, temp_config_file: Path, temp_trace_file: Path) -> None: + """Test that tracer handles exceptions in traced code gracefully.""" + def failing_function() -> None: + raise ValueError("Test exception") + + tracer = Tracer( + output=str(temp_trace_file), + config_file_path=temp_config_file + ) + + with tracer, contextlib.suppress(ValueError): + failing_function() + + + + + + def test_tracer_with_complex_arguments(self, temp_config_file: Path, temp_trace_file: Path) -> None: + def complex_function(data_dict: dict[str, Any], nested_list: list[list[int]], func_arg: object = lambda x: x) -> int: + return len(data_dict) + len(nested_list) + + tracer = Tracer( + output=str(temp_trace_file), + config_file_path=temp_config_file + ) + + expected_dict = {"key": "value", "nested": {"inner": "data"}} + expected_list = [[1, 2], [3, 4], [5, 6]] + expected_func = lambda x: x * 2 + + with tracer: + complex_function( + expected_dict, + expected_list, + func_arg=expected_func + ) + + if temp_trace_file.exists(): + con = sqlite3.connect(temp_trace_file) + cursor = con.cursor() + + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'") + if cursor.fetchone(): + cursor.execute("SELECT args FROM function_calls WHERE function = 'complex_function'") + result = cursor.fetchone() + assert result is not None, "Function complex_function should be traced" + + # Deserialize the arguments + import pickle + traced_args = pickle.loads(result[0]) + + assert "data_dict" in traced_args + assert "nested_list" in traced_args + assert "func_arg" in traced_args + + assert traced_args["data_dict"] == expected_dict + assert traced_args["nested_list"] == expected_list + assert callable(traced_args["func_arg"]) + assert traced_args["func_arg"](2) == 4 + assert len(traced_args["nested_list"]) == 3 + else: + pytest.fail("No function_calls table found in trace file") + con.close()