diff --git a/.github/workflows/codeflash-optimize.yaml b/.github/workflows/codeflash-optimize.yaml index 6a08635bf..357269116 100644 --- a/.github/workflows/codeflash-optimize.yaml +++ b/.github/workflows/codeflash-optimize.yaml @@ -68,4 +68,4 @@ jobs: id: optimize_code run: | source .venv/bin/activate - poetry run codeflash + poetry run codeflash --benchmark diff --git a/.github/workflows/end-to-end-test-benchmark-bubblesort.yaml b/.github/workflows/end-to-end-test-benchmark-bubblesort.yaml new file mode 100644 index 000000000..53a59dac1 --- /dev/null +++ b/.github/workflows/end-to-end-test-benchmark-bubblesort.yaml @@ -0,0 +1,41 @@ +name: end-to-end-test + +on: + pull_request: + workflow_dispatch: + +jobs: + benchmark-bubble-sort-optimization: + runs-on: ubuntu-latest + env: + CODEFLASH_AIS_SERVER: prod + POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }} + CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }} + COLUMNS: 110 + MAX_RETRIES: 3 + RETRY_DELAY: 5 + EXPECTED_IMPROVEMENT_PCT: 5 + CODEFLASH_END_TO_END: 1 + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Set up Python 3.11 for CLI + uses: astral-sh/setup-uv@v5 + with: + python-version: 3.11.6 + + - name: Install dependencies (CLI) + run: | + uv tool install poetry + uv venv + source .venv/bin/activate + poetry install --with dev + + - name: Run Codeflash to optimize code + id: optimize_code_with_benchmarks + run: | + source .venv/bin/activate + poetry run python tests/scripts/end_to_end_test_benchmark_sort.py \ No newline at end of file diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index a1e7da8ea..f3b4ffca5 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -32,7 +32,7 @@ jobs: run: uvx poetry install --with dev - name: Unit tests - run: uvx poetry run pytest tests/ --cov --cov-report=xml + run: uvx poetry run pytest tests/ --cov --cov-report=xml --benchmark-skip -m "not ci_skip" - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 diff --git a/code_to_optimize/bubble_sort.py b/code_to_optimize/bubble_sort.py index 787cc4a90..9e97f63a0 100644 --- a/code_to_optimize/bubble_sort.py +++ b/code_to_optimize/bubble_sort.py @@ -7,4 +7,4 @@ def sorter(arr): arr[j] = arr[j + 1] arr[j + 1] = temp print(f"result: {arr}") - return arr \ No newline at end of file + return arr diff --git a/code_to_optimize/bubble_sort_codeflash_trace.py b/code_to_optimize/bubble_sort_codeflash_trace.py new file mode 100644 index 000000000..48e9a412b --- /dev/null +++ b/code_to_optimize/bubble_sort_codeflash_trace.py @@ -0,0 +1,64 @@ +from codeflash.benchmarking.codeflash_trace import codeflash_trace +@codeflash_trace +def sorter(arr): + for i in range(len(arr)): + for j in range(len(arr) - 1): + if arr[j] > arr[j + 1]: + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp + return arr + +@codeflash_trace +def recursive_bubble_sort(arr, n=None): + # Initialize n if not provided + if n is None: + n = len(arr) + + # Base case: if n is 1, the array is already sorted + if n == 1: + return arr + + # One pass of bubble sort - move the largest element to the end + for i in range(n - 1): + if arr[i] > arr[i + 1]: + arr[i], arr[i + 1] = arr[i + 1], arr[i] + + # Recursively sort the remaining n-1 elements + return recursive_bubble_sort(arr, n - 1) + +class Sorter: + @codeflash_trace + def __init__(self, arr): + self.arr = arr + @codeflash_trace + def sorter(self, multiplier): + for i in range(len(self.arr)): + 
for j in range(len(self.arr) - 1): + if self.arr[j] > self.arr[j + 1]: + temp = self.arr[j] + self.arr[j] = self.arr[j + 1] + self.arr[j + 1] = temp + return self.arr * multiplier + + @staticmethod + @codeflash_trace + def sort_static(arr): + for i in range(len(arr)): + for j in range(len(arr) - 1): + if arr[j] > arr[j + 1]: + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp + return arr + + @classmethod + @codeflash_trace + def sort_class(cls, arr): + for i in range(len(arr)): + for j in range(len(arr) - 1): + if arr[j] > arr[j + 1]: + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp + return arr diff --git a/code_to_optimize/bubble_sort_multithread.py b/code_to_optimize/bubble_sort_multithread.py new file mode 100644 index 000000000..e71be4816 --- /dev/null +++ b/code_to_optimize/bubble_sort_multithread.py @@ -0,0 +1,23 @@ +# from code_to_optimize.bubble_sort_codeflash_trace import sorter +from code_to_optimize.bubble_sort_codeflash_trace import sorter +import concurrent.futures + + +def multithreaded_sorter(unsorted_lists: list[list[int]]) -> list[list[int]]: + # Create a list to store results in the correct order + sorted_lists = [None] * len(unsorted_lists) + + # Use ThreadPoolExecutor to manage threads + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + # Submit all sorting tasks and map them to their original indices + future_to_index = { + executor.submit(sorter, unsorted_list): i + for i, unsorted_list in enumerate(unsorted_lists) + } + + # Collect results as they complete + for future in concurrent.futures.as_completed(future_to_index): + index = future_to_index[future] + sorted_lists[index] = future.result() + + return sorted_lists \ No newline at end of file diff --git a/code_to_optimize/bubble_sort_picklepatch_test_unused_socket.py b/code_to_optimize/bubble_sort_picklepatch_test_unused_socket.py new file mode 100644 index 000000000..2b75a8c34 --- /dev/null +++ b/code_to_optimize/bubble_sort_picklepatch_test_unused_socket.py @@ -0,0 +1,18 @@ + +from codeflash.benchmarking.codeflash_trace import codeflash_trace + + +@codeflash_trace +def bubble_sort_with_unused_socket(data_container): + # Extract the list to sort, leaving the socket untouched + numbers = data_container.get('numbers', []).copy() + + return sorted(numbers) + +@codeflash_trace +def bubble_sort_with_used_socket(data_container): + # Extract the list to sort, leaving the socket untouched + numbers = data_container.get('numbers', []).copy() + socket = data_container.get('socket') + socket.send("Hello from the optimized function!") + return sorted(numbers) diff --git a/code_to_optimize/bubble_sort_picklepatch_test_used_socket.py b/code_to_optimize/bubble_sort_picklepatch_test_used_socket.py new file mode 100644 index 000000000..390e090cd --- /dev/null +++ b/code_to_optimize/bubble_sort_picklepatch_test_used_socket.py @@ -0,0 +1,46 @@ +from codeflash.benchmarking.codeflash_trace import codeflash_trace + +@codeflash_trace +def bubble_sort_with_used_socket(data_container): + """ + Performs a bubble sort on a list within the data_container. The data container has the following schema: + - 'numbers' (list): The list to be sorted. 
+ - 'socket' (socket): A socket + + Args: + data_container: A dictionary with at least 'numbers' (list) and 'socket' keys + + Returns: + list: The sorted list of numbers + """ + # Extract the list to sort and socket + numbers = data_container.get('numbers', []).copy() + socket = data_container.get('socket') + + # Track swap count + swap_count = 0 + + # Classic bubble sort implementation + n = len(numbers) + for i in range(n): + # Flag to optimize by detecting if no swaps occurred + swapped = False + + # Last i elements are already in place + for j in range(0, n - i - 1): + # Swap if the element is greater than the next element + if numbers[j] > numbers[j + 1]: + # Perform the swap + numbers[j], numbers[j + 1] = numbers[j + 1], numbers[j] + swapped = True + swap_count += 1 + + # If no swapping occurred in this pass, the list is sorted + if not swapped: + break + + # Send final summary + summary = f"Bubble sort completed with {swap_count} swaps" + socket.send(summary.encode()) + + return numbers \ No newline at end of file diff --git a/code_to_optimize/process_and_bubble_sort.py b/code_to_optimize/process_and_bubble_sort.py new file mode 100644 index 000000000..94359e599 --- /dev/null +++ b/code_to_optimize/process_and_bubble_sort.py @@ -0,0 +1,28 @@ +from code_to_optimize.bubble_sort import sorter + + +def calculate_pairwise_products(arr): + """ + Calculate the average of all pairwise products in the array. + """ + sum_of_products = 0 + count = 0 + + for i in range(len(arr)): + for j in range(len(arr)): + if i != j: + sum_of_products += arr[i] * arr[j] + count += 1 + + # The average of all pairwise products + return sum_of_products / count if count > 0 else 0 + + +def compute_and_sort(arr): + # Compute pairwise sums average + pairwise_average = calculate_pairwise_products(arr) + + # Call sorter function + sorter(arr.copy()) + + return pairwise_average diff --git a/code_to_optimize/process_and_bubble_sort_codeflash_trace.py b/code_to_optimize/process_and_bubble_sort_codeflash_trace.py new file mode 100644 index 000000000..37c2abab8 --- /dev/null +++ b/code_to_optimize/process_and_bubble_sort_codeflash_trace.py @@ -0,0 +1,28 @@ +from code_to_optimize.bubble_sort import sorter +from codeflash.benchmarking.codeflash_trace import codeflash_trace + +def calculate_pairwise_products(arr): + """ + Calculate the average of all pairwise products in the array. 
+ """ + sum_of_products = 0 + count = 0 + + for i in range(len(arr)): + for j in range(len(arr)): + if i != j: + sum_of_products += arr[i] * arr[j] + count += 1 + + # The average of all pairwise products + return sum_of_products / count if count > 0 else 0 + +@codeflash_trace +def compute_and_sort(arr): + # Compute pairwise sums average + pairwise_average = calculate_pairwise_products(arr) + + # Call sorter function + sorter(arr.copy()) + + return pairwise_average diff --git a/code_to_optimize/tests/pytest/benchmarks/test_benchmark_bubble_sort.py b/code_to_optimize/tests/pytest/benchmarks/test_benchmark_bubble_sort.py new file mode 100644 index 000000000..3d7b24a6c --- /dev/null +++ b/code_to_optimize/tests/pytest/benchmarks/test_benchmark_bubble_sort.py @@ -0,0 +1,13 @@ +import pytest + +from code_to_optimize.bubble_sort import sorter + + +def test_sort(benchmark): + result = benchmark(sorter, list(reversed(range(500)))) + assert result == list(range(500)) + +# This should not be picked up as a benchmark test +def test_sort2(): + result = sorter(list(reversed(range(500)))) + assert result == list(range(500)) diff --git a/code_to_optimize/tests/pytest/benchmarks/test_process_and_sort.py b/code_to_optimize/tests/pytest/benchmarks/test_process_and_sort.py new file mode 100644 index 000000000..8d31c926a --- /dev/null +++ b/code_to_optimize/tests/pytest/benchmarks/test_process_and_sort.py @@ -0,0 +1,8 @@ +from code_to_optimize.process_and_bubble_sort import compute_and_sort +from code_to_optimize.bubble_sort import sorter +def test_compute_and_sort(benchmark): + result = benchmark(compute_and_sort, list(reversed(range(500)))) + assert result == 62208.5 + +def test_no_func(benchmark): + benchmark(sorter, list(reversed(range(500)))) \ No newline at end of file diff --git a/code_to_optimize/tests/pytest/benchmarks_multithread/test_multithread_sort.py b/code_to_optimize/tests/pytest/benchmarks_multithread/test_multithread_sort.py new file mode 100644 index 000000000..4a5c68a2b --- /dev/null +++ b/code_to_optimize/tests/pytest/benchmarks_multithread/test_multithread_sort.py @@ -0,0 +1,4 @@ +from code_to_optimize.bubble_sort_multithread import multithreaded_sorter + +def test_benchmark_sort(benchmark): + benchmark(multithreaded_sorter, [list(range(1000)) for i in range (10)]) \ No newline at end of file diff --git a/code_to_optimize/tests/pytest/benchmarks_socket_test/test_socket.py b/code_to_optimize/tests/pytest/benchmarks_socket_test/test_socket.py new file mode 100644 index 000000000..bd05af487 --- /dev/null +++ b/code_to_optimize/tests/pytest/benchmarks_socket_test/test_socket.py @@ -0,0 +1,20 @@ +import socket + +from code_to_optimize.bubble_sort_picklepatch_test_unused_socket import bubble_sort_with_unused_socket +from code_to_optimize.bubble_sort_picklepatch_test_used_socket import bubble_sort_with_used_socket + +def test_socket_picklepatch(benchmark): + s1, s2 = socket.socketpair() + data = { + "numbers": list(reversed(range(500))), + "socket": s1 + } + benchmark(bubble_sort_with_unused_socket, data) + +def test_used_socket_picklepatch(benchmark): + s1, s2 = socket.socketpair() + data = { + "numbers": list(reversed(range(500))), + "socket": s1 + } + benchmark(bubble_sort_with_used_socket, data) \ No newline at end of file diff --git a/code_to_optimize/tests/pytest/benchmarks_test/test_benchmark_bubble_sort_example.py b/code_to_optimize/tests/pytest/benchmarks_test/test_benchmark_bubble_sort_example.py new file mode 100644 index 000000000..21f9755a5 --- /dev/null +++ 
b/code_to_optimize/tests/pytest/benchmarks_test/test_benchmark_bubble_sort_example.py @@ -0,0 +1,26 @@ +import pytest + +from code_to_optimize.bubble_sort_codeflash_trace import sorter, Sorter + + +def test_sort(benchmark): + result = benchmark(sorter, list(reversed(range(500)))) + assert result == list(range(500)) + +# This should not be picked up as a benchmark test +def test_sort2(): + result = sorter(list(reversed(range(500)))) + assert result == list(range(500)) + +def test_class_sort(benchmark): + obj = Sorter(list(reversed(range(100)))) + result1 = benchmark(obj.sorter, 2) + +def test_class_sort2(benchmark): + result2 = benchmark(Sorter.sort_class, list(reversed(range(100)))) + +def test_class_sort3(benchmark): + result3 = benchmark(Sorter.sort_static, list(reversed(range(100)))) + +def test_class_sort4(benchmark): + result4 = benchmark(Sorter, [1,2,3]) \ No newline at end of file diff --git a/code_to_optimize/tests/pytest/benchmarks_test/test_process_and_sort_example.py b/code_to_optimize/tests/pytest/benchmarks_test/test_process_and_sort_example.py new file mode 100644 index 000000000..bcd42eab9 --- /dev/null +++ b/code_to_optimize/tests/pytest/benchmarks_test/test_process_and_sort_example.py @@ -0,0 +1,8 @@ +from code_to_optimize.process_and_bubble_sort_codeflash_trace import compute_and_sort +from code_to_optimize.bubble_sort_codeflash_trace import sorter +def test_compute_and_sort(benchmark): + result = benchmark(compute_and_sort, list(reversed(range(500)))) + assert result == 62208.5 + +def test_no_func(benchmark): + benchmark(sorter, list(reversed(range(500)))) \ No newline at end of file diff --git a/code_to_optimize/tests/pytest/benchmarks_test/test_recursive_example.py b/code_to_optimize/tests/pytest/benchmarks_test/test_recursive_example.py new file mode 100644 index 000000000..689b1f9ff --- /dev/null +++ b/code_to_optimize/tests/pytest/benchmarks_test/test_recursive_example.py @@ -0,0 +1,6 @@ +from code_to_optimize.bubble_sort_codeflash_trace import recursive_bubble_sort + + +def test_recursive_sort(benchmark): + result = benchmark(recursive_bubble_sort, list(reversed(range(500)))) + assert result == list(range(500)) \ No newline at end of file diff --git a/code_to_optimize/tests/pytest/benchmarks_test_decorator/test_benchmark_decorator.py b/code_to_optimize/tests/pytest/benchmarks_test_decorator/test_benchmark_decorator.py new file mode 100644 index 000000000..b924bee7f --- /dev/null +++ b/code_to_optimize/tests/pytest/benchmarks_test_decorator/test_benchmark_decorator.py @@ -0,0 +1,11 @@ +import pytest +from code_to_optimize.bubble_sort_codeflash_trace import sorter + +def test_benchmark_sort(benchmark): + @benchmark + def do_sort(): + sorter(list(reversed(range(500)))) + +@pytest.mark.benchmark(group="benchmark_decorator") +def test_pytest_mark(benchmark): + benchmark(sorter, list(reversed(range(500)))) \ No newline at end of file diff --git a/codeflash/benchmarking/__init__.py b/codeflash/benchmarking/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codeflash/benchmarking/codeflash_trace.py b/codeflash/benchmarking/codeflash_trace.py new file mode 100644 index 000000000..35232f954 --- /dev/null +++ b/codeflash/benchmarking/codeflash_trace.py @@ -0,0 +1,179 @@ +import functools +import os +import pickle +import sqlite3 +import threading +import time +from typing import Callable + +from codeflash.picklepatch.pickle_patcher import PicklePatcher + + +class CodeflashTrace: + """Decorator class that traces and profiles function execution.""" + + def 
__init__(self) -> None: + self.function_calls_data = [] + self.function_call_count = 0 + self.pickle_count_limit = 1000 + self._connection = None + self._trace_path = None + self._thread_local = threading.local() + self._thread_local.active_functions = set() + + def setup(self, trace_path: str) -> None: + """Set up the database connection for direct writing. + + Args: + trace_path: Path to the trace database file + + """ + try: + self._trace_path = trace_path + self._connection = sqlite3.connect(self._trace_path) + cur = self._connection.cursor() + cur.execute("PRAGMA synchronous = OFF") + cur.execute("PRAGMA journal_mode = MEMORY") + cur.execute( + "CREATE TABLE IF NOT EXISTS benchmark_function_timings(" + "function_name TEXT, class_name TEXT, module_name TEXT, file_path TEXT," + "benchmark_function_name TEXT, benchmark_module_path TEXT, benchmark_line_number INTEGER," + "function_time_ns INTEGER, overhead_time_ns INTEGER, args BLOB, kwargs BLOB)" + ) + self._connection.commit() + except Exception as e: + print(f"Database setup error: {e}") + if self._connection: + self._connection.close() + self._connection = None + raise + + def write_function_timings(self) -> None: + """Write function call data directly to the database. + + Args: + data: List of function call data tuples to write + + """ + if not self.function_calls_data: + return # No data to write + + if self._connection is None and self._trace_path is not None: + self._connection = sqlite3.connect(self._trace_path) + + try: + cur = self._connection.cursor() + # Insert data into the benchmark_function_timings table + cur.executemany( + "INSERT INTO benchmark_function_timings" + "(function_name, class_name, module_name, file_path, benchmark_function_name, " + "benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + self.function_calls_data + ) + self._connection.commit() + self.function_calls_data = [] + except Exception as e: + print(f"Error writing to function timings database: {e}") + if self._connection: + self._connection.rollback() + raise + + def open(self) -> None: + """Open the database connection.""" + if self._connection is None: + self._connection = sqlite3.connect(self._trace_path) + + def close(self) -> None: + """Close the database connection.""" + if self._connection: + self._connection.close() + self._connection = None + + def __call__(self, func: Callable) -> Callable: + """Use as a decorator to trace function execution. 
+ + Args: + func: The function to be decorated + + Returns: + The wrapped function + + """ + func_id = (func.__module__,func.__name__) + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Initialize thread-local active functions set if it doesn't exist + if not hasattr(self._thread_local, "active_functions"): + self._thread_local.active_functions = set() + # If it's in a recursive function, just return the result + if func_id in self._thread_local.active_functions: + return func(*args, **kwargs) + # Track active functions so we can detect recursive functions + self._thread_local.active_functions.add(func_id) + # Measure execution time + start_time = time.thread_time_ns() + result = func(*args, **kwargs) + end_time = time.thread_time_ns() + # Calculate execution time + execution_time = end_time - start_time + self.function_call_count += 1 + + # Check if currently in pytest benchmark fixture + if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False": + self._thread_local.active_functions.remove(func_id) + return result + # Get benchmark info from environment + benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "") + benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "") + benchmark_line_number = os.environ.get("CODEFLASH_BENCHMARK_LINE_NUMBER", "") + # Get class name + class_name = "" + qualname = func.__qualname__ + if "." in qualname: + class_name = qualname.split(".")[0] + + # Limit pickle count so memory does not explode + if self.function_call_count > self.pickle_count_limit: + print("Pickle limit reached") + self._thread_local.active_functions.remove(func_id) + overhead_time = time.thread_time_ns() - end_time + self.function_calls_data.append( + (func.__name__, class_name, func.__module__, func.__code__.co_filename, + benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time, + overhead_time, None, None) + ) + return result + + try: + # Pickle the arguments + pickled_args = PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL) + pickled_kwargs = PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL) + except Exception as e: + print(f"Error pickling arguments for function {func.__name__}: {e}") + # Add to the list of function calls without pickled args. 
Used for timing info only + self._thread_local.active_functions.remove(func_id) + overhead_time = time.thread_time_ns() - end_time + self.function_calls_data.append( + (func.__name__, class_name, func.__module__, func.__code__.co_filename, + benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time, + overhead_time, None, None) + ) + return result + # Flush to database every 100 calls + if len(self.function_calls_data) > 100: + self.write_function_timings() + + # Add to the list of function calls with pickled args, to be used for replay tests + self._thread_local.active_functions.remove(func_id) + overhead_time = time.thread_time_ns() - end_time + self.function_calls_data.append( + (func.__name__, class_name, func.__module__, func.__code__.co_filename, + benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time, + overhead_time, pickled_args, pickled_kwargs) + ) + return result + return wrapper + +# Create a singleton instance +codeflash_trace = CodeflashTrace() diff --git a/codeflash/benchmarking/instrument_codeflash_trace.py b/codeflash/benchmarking/instrument_codeflash_trace.py new file mode 100644 index 000000000..044b0b0a4 --- /dev/null +++ b/codeflash/benchmarking/instrument_codeflash_trace.py @@ -0,0 +1,117 @@ +from pathlib import Path + +import isort +import libcst as cst + +from codeflash.discovery.functions_to_optimize import FunctionToOptimize + + +class AddDecoratorTransformer(cst.CSTTransformer): + def __init__(self, target_functions: set[tuple[str, str]]) -> None: + super().__init__() + self.target_functions = target_functions + self.added_codeflash_trace = False + self.class_name = "" + self.function_name = "" + self.decorator = cst.Decorator( + decorator=cst.Name(value="codeflash_trace") + ) + + def leave_ClassDef(self, original_node, updated_node): + if self.class_name == original_node.name.value: + self.class_name = "" # Even if nested classes are not visited, this function is still called on them + return updated_node + + def visit_ClassDef(self, node): + if self.class_name: # Don't go into nested class + return False + self.class_name = node.name.value + + def visit_FunctionDef(self, node): + if self.function_name: # Don't go into nested function + return False + self.function_name = node.name.value + + def leave_FunctionDef(self, original_node, updated_node): + if self.function_name == original_node.name.value: + self.function_name = "" + if (self.class_name, original_node.name.value) in self.target_functions: + # Add the new decorator after any existing decorators, so it gets executed first + updated_decorators = list(updated_node.decorators) + [self.decorator] + self.added_codeflash_trace = True + return updated_node.with_changes( + decorators=updated_decorators + ) + + return updated_node + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + # Create import statement for codeflash_trace + if not self.added_codeflash_trace: + return updated_node + import_stmt = cst.SimpleStatementLine( + body=[ + cst.ImportFrom( + module=cst.Attribute( + value=cst.Attribute( + value=cst.Name(value="codeflash"), + attr=cst.Name(value="benchmarking") + ), + attr=cst.Name(value="codeflash_trace") + ), + names=[ + cst.ImportAlias( + name=cst.Name(value="codeflash_trace") + ) + ] + ) + ] + ) + + # Insert at the beginning of the file. We'll use isort later to sort the imports. 
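        # Illustrative rendering of the import node built above (an annotation, not part of this diff):
        #   from codeflash.benchmarking.codeflash_trace import codeflash_trace
        # The isort pass (float_to_top=True) applied later in instrument_codeflash_trace_decorator
        # sorts this statement in with the file's existing imports.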
+ new_body = [import_stmt, *list(updated_node.body)] + + return updated_node.with_changes(body=new_body) + +def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[FunctionToOptimize]) -> str: + """Add codeflash_trace to a function. + + Args: + code: The source code as a string + function_to_optimize: The FunctionToOptimize instance containing function details + + Returns: + The modified source code as a string + + """ + target_functions = set() + for function_to_optimize in functions_to_optimize: + class_name = "" + if len(function_to_optimize.parents) == 1 and function_to_optimize.parents[0].type == "ClassDef": + class_name = function_to_optimize.parents[0].name + target_functions.add((class_name, function_to_optimize.function_name)) + + transformer = AddDecoratorTransformer( + target_functions = target_functions, + ) + + module = cst.parse_module(code) + modified_module = module.visit(transformer) + return modified_module.code + + +def instrument_codeflash_trace_decorator( + file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] +) -> None: + """Instrument codeflash_trace decorator to functions to optimize.""" + for file_path, functions_to_optimize in file_to_funcs_to_optimize.items(): + original_code = file_path.read_text(encoding="utf-8") + new_code = add_codeflash_decorator_to_code( + original_code, + functions_to_optimize + ) + # Modify the code + modified_code = isort.code(code=new_code, float_to_top=True) + + # Write the modified code back to the file + file_path.write_text(modified_code, encoding="utf-8") diff --git a/codeflash/benchmarking/plugin/__init__.py b/codeflash/benchmarking/plugin/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codeflash/benchmarking/plugin/plugin.py b/codeflash/benchmarking/plugin/plugin.py new file mode 100644 index 000000000..313817041 --- /dev/null +++ b/codeflash/benchmarking/plugin/plugin.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import os +import sqlite3 +import sys +import time +from pathlib import Path + +import pytest + +from codeflash.benchmarking.codeflash_trace import codeflash_trace +from codeflash.code_utils.code_utils import module_name_from_file_path +from codeflash.models.models import BenchmarkKey + + +class CodeFlashBenchmarkPlugin: + def __init__(self) -> None: + self._trace_path = None + self._connection = None + self.project_root = None + self.benchmark_timings = [] + + def setup(self, trace_path:str, project_root:str) -> None: + try: + # Open connection + self.project_root = project_root + self._trace_path = trace_path + self._connection = sqlite3.connect(self._trace_path) + cur = self._connection.cursor() + cur.execute("PRAGMA synchronous = OFF") + cur.execute("PRAGMA journal_mode = MEMORY") + cur.execute( + "CREATE TABLE IF NOT EXISTS benchmark_timings(" + "benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER," + "benchmark_time_ns INTEGER)" + ) + self._connection.commit() + self.close() # Reopen only at the end of pytest session + except Exception as e: + print(f"Database setup error: {e}") + if self._connection: + self._connection.close() + self._connection = None + raise + + def write_benchmark_timings(self) -> None: + if not self.benchmark_timings: + return # No data to write + + if self._connection is None: + self._connection = sqlite3.connect(self._trace_path) + + try: + cur = self._connection.cursor() + # Insert data into the benchmark_timings table + cur.executemany( + "INSERT INTO benchmark_timings 
(benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)", + self.benchmark_timings + ) + self._connection.commit() + self.benchmark_timings = [] # Clear the benchmark timings list + except Exception as e: + print(f"Error writing to benchmark timings database: {e}") + self._connection.rollback() + raise + def close(self) -> None: + if self._connection: + self._connection.close() + self._connection = None + + @staticmethod + def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]: + """Process the trace file and extract timing data for all functions. + + Args: + trace_path: Path to the trace file + + Returns: + A nested dictionary where: + - Outer keys are module_name.qualified_name (module.class.function) + - Inner keys are of type BenchmarkKey + - Values are function timing in milliseconds + + """ + # Initialize the result dictionary + result = {} + + # Connect to the SQLite database + connection = sqlite3.connect(trace_path) + cursor = connection.cursor() + + try: + # Query the function_calls table for all function calls + cursor.execute( + "SELECT module_name, class_name, function_name, " + "benchmark_module_path, benchmark_function_name, benchmark_line_number, function_time_ns " + "FROM benchmark_function_timings" + ) + + # Process each row + for row in cursor.fetchall(): + module_name, class_name, function_name, benchmark_file, benchmark_func, benchmark_line, time_ns = row + + # Create the function key (module_name.class_name.function_name) + if class_name: + qualified_name = f"{module_name}.{class_name}.{function_name}" + else: + qualified_name = f"{module_name}.{function_name}" + + # Create the benchmark key (file::function::line) + benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func) + # Initialize the inner dictionary if needed + if qualified_name not in result: + result[qualified_name] = {} + + # If multiple calls to the same function in the same benchmark, + # add the times together + if benchmark_key in result[qualified_name]: + result[qualified_name][benchmark_key] += time_ns + else: + result[qualified_name][benchmark_key] = time_ns + + finally: + # Close the connection + connection.close() + + return result + + @staticmethod + def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]: + """Extract total benchmark timings from trace files. 
+ + Args: + trace_path: Path to the trace file + + Returns: + A dictionary mapping where: + - Keys are of type BenchmarkKey + - Values are total benchmark timing in milliseconds (with overhead subtracted) + + """ + # Initialize the result dictionary + result = {} + overhead_by_benchmark = {} + + # Connect to the SQLite database + connection = sqlite3.connect(trace_path) + cursor = connection.cursor() + + try: + # Query the benchmark_function_timings table to get total overhead for each benchmark + cursor.execute( + "SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, SUM(overhead_time_ns) " + "FROM benchmark_function_timings " + "GROUP BY benchmark_module_path, benchmark_function_name, benchmark_line_number" + ) + + # Process overhead information + for row in cursor.fetchall(): + benchmark_file, benchmark_func, benchmark_line, total_overhead_ns = row + benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func) + overhead_by_benchmark[benchmark_key] = total_overhead_ns or 0 # Handle NULL sum case + + # Query the benchmark_timings table for total times + cursor.execute( + "SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns " + "FROM benchmark_timings" + ) + + # Process each row and subtract overhead + for row in cursor.fetchall(): + benchmark_file, benchmark_func, benchmark_line, time_ns = row + + # Create the benchmark key (file::function::line) + benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func) + # Subtract overhead from total time + overhead = overhead_by_benchmark.get(benchmark_key, 0) + result[benchmark_key] = time_ns - overhead + + finally: + # Close the connection + connection.close() + + return result + + # Pytest hooks + @pytest.hookimpl + def pytest_sessionfinish(self, session, exitstatus): + """Execute after whole test run is completed.""" + # Write any remaining benchmark timings to the database + codeflash_trace.close() + if self.benchmark_timings: + self.write_benchmark_timings() + # Close the database connection + self.close() + + @staticmethod + def pytest_addoption(parser): + parser.addoption( + "--codeflash-trace", + action="store_true", + default=False, + help="Enable CodeFlash tracing" + ) + + @staticmethod + def pytest_plugin_registered(plugin, manager): + # Not necessary since run with -p no:benchmark, but just in case + if hasattr(plugin, "name") and plugin.name == "pytest-benchmark": + manager.unregister(plugin) + + @staticmethod + def pytest_configure(config): + """Register the benchmark marker.""" + config.addinivalue_line( + "markers", + "benchmark: mark test as a benchmark that should be run with codeflash tracing" + ) + @staticmethod + def pytest_collection_modifyitems(config, items): + # Skip tests that don't have the benchmark fixture + if not config.getoption("--codeflash-trace"): + return + + skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture") + for item in items: + # Check for direct benchmark fixture usage + has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames + + # Check for @pytest.mark.benchmark marker + has_marker = False + if hasattr(item, "get_closest_marker"): + marker = item.get_closest_marker("benchmark") + if marker is not None: + has_marker = True + + # Skip if neither fixture nor marker is present + if not (has_fixture or has_marker): + item.add_marker(skip_no_benchmark) + + # Benchmark fixture + class Benchmark: + def __init__(self, request): + 
self.request = request + + def __call__(self, func, *args, **kwargs): + """Handle both direct function calls and decorator usage.""" + if args or kwargs: + # Used as benchmark(func, *args, **kwargs) + return self._run_benchmark(func, *args, **kwargs) + # Used as @benchmark decorator + def wrapped_func(*args, **kwargs): + return func(*args, **kwargs) + result = self._run_benchmark(func) + return wrapped_func + + def _run_benchmark(self, func, *args, **kwargs): + """Actual benchmark implementation.""" + benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), + Path(codeflash_benchmark_plugin.project_root)) + benchmark_function_name = self.request.node.name + line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack + # Set env vars + os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name + os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path + os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number) + os.environ["CODEFLASH_BENCHMARKING"] = "True" + # Run the function + start = time.time_ns() + result = func(*args, **kwargs) + end = time.time_ns() + # Reset the environment variable + os.environ["CODEFLASH_BENCHMARKING"] = "False" + + # Write function calls + codeflash_trace.write_function_timings() + # Reset function call count + codeflash_trace.function_call_count = 0 + # Add to the benchmark timings buffer + codeflash_benchmark_plugin.benchmark_timings.append( + (benchmark_module_path, benchmark_function_name, line_number, end - start)) + + return result + + @staticmethod + @pytest.fixture + def benchmark(request): + if not request.config.getoption("--codeflash-trace"): + return None + + return CodeFlashBenchmarkPlugin.Benchmark(request) + +codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin() diff --git a/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py b/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py new file mode 100644 index 000000000..232c39fa7 --- /dev/null +++ b/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py @@ -0,0 +1,25 @@ +import sys +from pathlib import Path + +from codeflash.benchmarking.codeflash_trace import codeflash_trace +from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin + +benchmarks_root = sys.argv[1] +tests_root = sys.argv[2] +trace_file = sys.argv[3] +# current working directory +project_root = Path.cwd() +if __name__ == "__main__": + import pytest + + try: + codeflash_benchmark_plugin.setup(trace_file, project_root) + codeflash_trace.setup(trace_file) + exitcode = pytest.main( + [benchmarks_root, "--codeflash-trace", "-p", "no:benchmark","-p", "no:codspeed","-p", "no:cov-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin] + ) # Errors will be printed to stdout, not stderr + + except Exception as e: + print(f"Failed to collect tests: {e!s}", file=sys.stderr) + exitcode = -1 + sys.exit(exitcode) diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py new file mode 100644 index 000000000..ee1107241 --- /dev/null +++ b/codeflash/benchmarking/replay_test.py @@ -0,0 +1,286 @@ +from __future__ import annotations + +import sqlite3 +import textwrap +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import isort + +from codeflash.cli_cmds.console import logger +from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods +from codeflash.verification.verification_utils import get_test_file_path + +if TYPE_CHECKING: + from 
collections.abc import Generator + + +def get_next_arg_and_return( + trace_file: str, benchmark_function_name:str, function_name: str, file_path: str, class_name: str | None = None, num_to_get: int = 256 +) -> Generator[Any]: + db = sqlite3.connect(trace_file) + cur = db.cursor() + limit = num_to_get + + if class_name is not None: + cursor = cur.execute( + "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?", + (benchmark_function_name, function_name, file_path, class_name, limit), + ) + else: + cursor = cur.execute( + "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?", + (benchmark_function_name, function_name, file_path, limit), + ) + + while (val := cursor.fetchone()) is not None: + yield val[9], val[10] # pickled_args, pickled_kwargs + + +def get_function_alias(module: str, function_name: str) -> str: + return "_".join(module.split(".")) + "_" + function_name + + +def create_trace_replay_test_code( + trace_file: str, + functions_data: list[dict[str, Any]], + test_framework: str = "pytest", + max_run_count=256 +) -> str: + """Create a replay test for functions based on trace data. + + Args: + trace_file: Path to the SQLite database file + functions_data: List of dictionaries with function info extracted from DB + test_framework: 'pytest' or 'unittest' + max_run_count: Maximum number of runs to include in the test + + Returns: + A string containing the test code + + """ + assert test_framework in ["pytest", "unittest"] + + # Create Imports + imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle +{"import unittest" if test_framework == "unittest" else ""} +from codeflash.benchmarking.replay_test import get_next_arg_and_return +""" + + function_imports = [] + for func in functions_data: + module_name = func.get("module_name") + function_name = func.get("function_name") + class_name = func.get("class_name", "") + if class_name: + function_imports.append( + f"from {module_name} import {class_name} as {get_function_alias(module_name, class_name)}" + ) + else: + function_imports.append( + f"from {module_name} import {function_name} as {get_function_alias(module_name, function_name)}" + ) + + imports += "\n".join(function_imports) + + functions_to_optimize = sorted({func.get("function_name") for func in functions_data + if func.get("function_name") != "__init__"}) + metadata = f"""functions = {functions_to_optimize} +trace_file_path = r"{trace_file}" +""" + # Templates for different types of tests + test_function_body = textwrap.dedent( + """\ + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl) + ret = {function_name}(*args, **kwargs) + """ + ) + + test_method_body = textwrap.dedent( + """\ + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl){filter_variables} + function_name = "{orig_function_name}" + if not args: + raise ValueError("No arguments provided for the method.") + 
if function_name == "__init__": + ret = {class_name_alias}(*args[1:], **kwargs) + else: + instance = args[0] # self + ret = instance{method_name}(*args[1:], **kwargs) + """) + + test_class_method_body = textwrap.dedent( + """\ + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl){filter_variables} + if not args: + raise ValueError("No arguments provided for the method.") + ret = {class_name_alias}{method_name}(*args[1:], **kwargs) + """ + ) + test_static_method_body = textwrap.dedent( + """\ + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl){filter_variables} + ret = {class_name_alias}{method_name}(*args, **kwargs) + """ + ) + + # Create main body + + if test_framework == "unittest": + self = "self" + test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n" + else: + test_template = "" + self = "" + + for func in functions_data: + + module_name = func.get("module_name") + function_name = func.get("function_name") + class_name = func.get("class_name") + file_path = func.get("file_path") + benchmark_function_name = func.get("benchmark_function_name") + function_properties = func.get("function_properties") + if not class_name: + alias = get_function_alias(module_name, function_name) + test_body = test_function_body.format( + benchmark_function_name=benchmark_function_name, + orig_function_name=function_name, + function_name=alias, + file_path=file_path, + max_run_count=max_run_count, + ) + else: + class_name_alias = get_function_alias(module_name, class_name) + alias = get_function_alias(module_name, class_name + "_" + function_name) + + filter_variables = "" + # filter_variables = '\n args.pop("cls", None)' + method_name = "." 
+ function_name if function_name != "__init__" else "" + if function_properties.is_classmethod: + test_body = test_class_method_body.format( + benchmark_function_name=benchmark_function_name, + orig_function_name=function_name, + file_path=file_path, + class_name_alias=class_name_alias, + class_name=class_name, + method_name=method_name, + max_run_count=max_run_count, + filter_variables=filter_variables, + ) + elif function_properties.is_staticmethod: + test_body = test_static_method_body.format( + benchmark_function_name=benchmark_function_name, + orig_function_name=function_name, + file_path=file_path, + class_name_alias=class_name_alias, + class_name=class_name, + method_name=method_name, + max_run_count=max_run_count, + filter_variables=filter_variables, + ) + else: + test_body = test_method_body.format( + benchmark_function_name=benchmark_function_name, + orig_function_name=function_name, + file_path=file_path, + class_name_alias=class_name_alias, + class_name=class_name, + method_name=method_name, + max_run_count=max_run_count, + filter_variables=filter_variables, + ) + + formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ") + + test_template += " " if test_framework == "unittest" else "" + test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n" + + return imports + "\n" + metadata + "\n" + test_template + +def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100) -> int: + """Generate multiple replay tests from the traced function calls, grouped by benchmark. + + Args: + trace_file_path: Path to the SQLite database file + output_dir: Directory to write the generated tests (if None, only returns the code) + test_framework: 'pytest' or 'unittest' + max_run_count: Maximum number of runs to include per function + + Returns: + Dictionary mapping benchmark names to generated test code + + """ + count = 0 + try: + # Connect to the database + conn = sqlite3.connect(trace_file_path.as_posix()) + cursor = conn.cursor() + + # Get distinct benchmark file paths + cursor.execute( + "SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings" + ) + benchmark_files = cursor.fetchall() + + # Generate a test for each benchmark file + for benchmark_file in benchmark_files: + benchmark_module_path = benchmark_file[0] + # Get all benchmarks and functions associated with this file path + cursor.execute( + "SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings " + "WHERE benchmark_module_path = ?", + (benchmark_module_path,) + ) + + functions_data = [] + for row in cursor.fetchall(): + benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number = row + # Add this function to our list + functions_data.append({ + "function_name": function_name, + "class_name": class_name, + "file_path": file_path, + "module_name": module_name, + "benchmark_function_name": benchmark_function_name, + "benchmark_module_path": benchmark_module_path, + "benchmark_line_number": benchmark_line_number, + "function_properties": inspect_top_level_functions_or_methods( + file_name=Path(file_path), + function_or_method_name=function_name, + class_name=class_name, + ) + }) + + if not functions_data: + logger.info(f"No benchmark test functions found in {benchmark_module_path}") + continue + # Generate the test code for this benchmark + test_code = create_trace_replay_test_code( + 
trace_file=trace_file_path.as_posix(), + functions_data=functions_data, + test_framework=test_framework, + max_run_count=max_run_count, + ) + test_code = isort.code(test_code) + output_file = get_test_file_path( + test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay" + ) + # Write test code to file, parents = true + output_dir.mkdir(parents=True, exist_ok=True) + output_file.write_text(test_code, "utf-8") + count += 1 + + conn.close() + except Exception as e: + logger.info(f"Error generating replay tests: {e}") + + return count diff --git a/codeflash/benchmarking/trace_benchmarks.py b/codeflash/benchmarking/trace_benchmarks.py new file mode 100644 index 000000000..8f68030cb --- /dev/null +++ b/codeflash/benchmarking/trace_benchmarks.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import re +import subprocess +from pathlib import Path + +from codeflash.cli_cmds.console import logger +from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE + + +def trace_benchmarks_pytest(benchmarks_root: Path, tests_root:Path, project_root: Path, trace_file: Path, timeout:int = 300) -> None: + result = subprocess.run( + [ + SAFE_SYS_EXECUTABLE, + Path(__file__).parent / "pytest_new_process_trace_benchmarks.py", + benchmarks_root, + tests_root, + trace_file, + ], + cwd=project_root, + check=False, + capture_output=True, + text=True, + env={"PYTHONPATH": str(project_root)}, + timeout=timeout, + ) + if result.returncode != 0: + if "ERROR collecting" in result.stdout: + # Pattern matches "===== ERRORS =====" (any number of =) and captures everything after + error_pattern = r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)" + match = re.search(error_pattern, result.stdout) + error_section = match.group(1) if match else result.stdout + elif "FAILURES" in result.stdout: + # Pattern matches "===== FAILURES =====" (any number of =) and captures everything after + error_pattern = r"={3,}\s*FAILURES\s*={3,}\n([\s\S]*?)(?:={3,}|$)" + match = re.search(error_pattern, result.stdout) + error_section = match.group(1) if match else result.stdout + else: + error_section = result.stdout + logger.warning( + f"Error collecting benchmarks - Pytest Exit code: {result.returncode}, {error_section}" + ) diff --git a/codeflash/benchmarking/utils.py b/codeflash/benchmarking/utils.py new file mode 100644 index 000000000..da09cd57a --- /dev/null +++ b/codeflash/benchmarking/utils.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import shutil +from typing import TYPE_CHECKING, Optional + +from rich.console import Console +from rich.table import Table + +from codeflash.cli_cmds.console import logger +from codeflash.code_utils.time_utils import humanize_runtime +from codeflash.models.models import BenchmarkDetail, ProcessedBenchmarkInfo +from codeflash.result.critic import performance_gain + +if TYPE_CHECKING: + from codeflash.models.models import BenchmarkKey + + +def validate_and_format_benchmark_table(function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], + total_benchmark_timings: dict[BenchmarkKey, int]) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]: + function_to_result = {} + # Process each function's benchmark data + for func_path, test_times in function_benchmark_timings.items(): + # Sort by percentage (highest first) + sorted_tests = [] + for benchmark_key, func_time in test_times.items(): + total_time = total_benchmark_timings.get(benchmark_key, 0) + if func_time > total_time: + logger.debug(f"Skipping test {benchmark_key} due to func_time {func_time} > 
total_time {total_time}") + # If the function time is greater than total time, likely to have multithreading / multiprocessing issues. + # Do not try to project the optimization impact for this function. + sorted_tests.append((benchmark_key, 0.0, 0.0, 0.0)) + elif total_time > 0: + percentage = (func_time / total_time) * 100 + # Convert nanoseconds to milliseconds + func_time_ms = func_time / 1_000_000 + total_time_ms = total_time / 1_000_000 + sorted_tests.append((benchmark_key, total_time_ms, func_time_ms, percentage)) + sorted_tests.sort(key=lambda x: x[3], reverse=True) + function_to_result[func_path] = sorted_tests + return function_to_result + + +def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey, float, float, float]]]) -> None: + + try: + terminal_width = int(shutil.get_terminal_size().columns * 0.9) + except Exception: + terminal_width = 120 # Fallback width + console = Console(width = terminal_width) + for func_path, sorted_tests in function_to_results.items(): + console.print() + function_name = func_path.split(":")[-1] + + # Create a table for this function + table = Table(title=f"Function: {function_name}", width=terminal_width, border_style="blue", show_lines=True) + benchmark_col_width = max(int(terminal_width * 0.4), 40) + # Add columns - split the benchmark test into two columns + table.add_column("Benchmark Module Path", width=benchmark_col_width, style="cyan", overflow="fold") + table.add_column("Test Function", style="magenta", overflow="fold") + table.add_column("Total Time (ms)", justify="right", style="green") + table.add_column("Function Time (ms)", justify="right", style="yellow") + table.add_column("Percentage (%)", justify="right", style="red") + + for benchmark_key, total_time, func_time, percentage in sorted_tests: + # Split the benchmark test into module path and function name + module_path = benchmark_key.module_path + test_function = benchmark_key.function_name + + if total_time == 0.0: + table.add_row( + module_path, + test_function, + "N/A", + "N/A", + "N/A" + ) + else: + table.add_row( + module_path, + test_function, + f"{total_time:.3f}", + f"{func_time:.3f}", + f"{percentage:.2f}" + ) + + # Print the table + console.print(table) + + +def process_benchmark_data( + replay_performance_gain: dict[BenchmarkKey, float], + fto_benchmark_timings: dict[BenchmarkKey, int], + total_benchmark_timings: dict[BenchmarkKey, int] +) -> Optional[ProcessedBenchmarkInfo]: + """Process benchmark data and generate detailed benchmark information. 
+ + Args: + replay_performance_gain: The performance gain from replay + fto_benchmark_timings: Function to optimize benchmark timings + total_benchmark_timings: Total benchmark timings + + Returns: + ProcessedBenchmarkInfo containing processed benchmark details + + """ + if not replay_performance_gain or not fto_benchmark_timings or not total_benchmark_timings: + return None + + benchmark_details = [] + + for benchmark_key, og_benchmark_timing in fto_benchmark_timings.items(): + + total_benchmark_timing = total_benchmark_timings.get(benchmark_key, 0) + + if total_benchmark_timing == 0: + continue # Skip benchmarks with zero timing + + # Calculate expected new benchmark timing + expected_new_benchmark_timing = total_benchmark_timing - og_benchmark_timing + ( + 1 / (replay_performance_gain[benchmark_key] + 1) + ) * og_benchmark_timing + + # Calculate speedup + benchmark_speedup_percent = performance_gain(original_runtime_ns=total_benchmark_timing, optimized_runtime_ns=int(expected_new_benchmark_timing)) * 100 + + benchmark_details.append( + BenchmarkDetail( + benchmark_name=benchmark_key.module_path, + test_function=benchmark_key.function_name, + original_timing=humanize_runtime(int(total_benchmark_timing)), + expected_new_timing=humanize_runtime(int(expected_new_benchmark_timing)), + speedup_percent=benchmark_speedup_percent + ) + ) + + return ProcessedBenchmarkInfo(benchmark_details=benchmark_details) diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 6ac4db420..ed0dbd760 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -62,6 +62,10 @@ def parse_args() -> Namespace: ) parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose debug logs") parser.add_argument("--version", action="store_true", help="Print the version of codeflash") + parser.add_argument("--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact on benchmarks") + parser.add_argument( + "--benchmarks-root", type=str, help="Path to the directory of the project, where all the pytest-benchmark tests are located." + ) args: Namespace = parser.parse_args() return process_and_validate_cmd_args(args) @@ -109,6 +113,7 @@ def process_pyproject_config(args: Namespace) -> Namespace: supported_keys = [ "module_root", "tests_root", + "benchmarks_root", "test_framework", "ignore_paths", "pytest_cmd", @@ -127,7 +132,12 @@ def process_pyproject_config(args: Namespace) -> Namespace: assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory" assert args.tests_root is not None, "--tests-root must be specified" assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory" - + if args.benchmark: + assert args.benchmarks_root is not None, "--benchmarks-root must be specified when running with --benchmark" + assert Path(args.benchmarks_root).is_dir(), f"--benchmarks-root {args.benchmarks_root} must be a valid directory" + assert Path(args.benchmarks_root).resolve().is_relative_to(Path(args.tests_root).resolve()), ( + f"--benchmarks-root {args.benchmarks_root} must be a subdirectory of --tests-root {args.tests_root}" + ) if env_utils.get_pr_number() is not None: assert env_utils.ensure_codeflash_api_key(), ( "Codeflash API key not found. 
When running in a Github Actions Context, provide the " diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 3ef7e2eec..70920eda0 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -49,6 +49,7 @@ class SetupInfo: module_root: str tests_root: str + benchmarks_root: str | None test_framework: str ignore_paths: list[str] formatter: str @@ -125,8 +126,7 @@ def ask_run_end_to_end_test(args: Namespace) -> None: run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path) def should_modify_pyproject_toml() -> bool: - """ - Check if the current directory contains a valid pyproject.toml file with codeflash config + """Check if the current directory contains a valid pyproject.toml file with codeflash config If it does, ask the user if they want to re-configure it. """ from rich.prompt import Confirm @@ -135,7 +135,7 @@ def should_modify_pyproject_toml() -> bool: return True try: config, config_file_path = parse_config_file(pyproject_toml_path) - except Exception as e: + except Exception: return True if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir(): @@ -144,7 +144,7 @@ def should_modify_pyproject_toml() -> bool: return True create_toml = Confirm.ask( - f"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, show_default=True + "✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, show_default=True ) return create_toml @@ -244,6 +244,66 @@ def collect_setup_info() -> SetupInfo: ph("cli-test-framework-provided", {"test_framework": test_framework}) + # Get benchmarks root directory + default_benchmarks_subdir = "benchmarks" + create_benchmarks_option = f"okay, create a {default_benchmarks_subdir}{os.path.sep} directory for me!" + no_benchmarks_option = "I don't need benchmarks" + + # Check if benchmarks directory exists inside tests directory + tests_subdirs = [] + if tests_root.exists(): + tests_subdirs = [d.name for d in tests_root.iterdir() if d.is_dir() and not d.name.startswith(".")] + + benchmarks_options = [] + if default_benchmarks_subdir in tests_subdirs: + benchmarks_options.append(default_benchmarks_subdir) + benchmarks_options.extend([d for d in tests_subdirs if d != default_benchmarks_subdir]) + benchmarks_options.append(create_benchmarks_option) + benchmarks_options.append(custom_dir_option) + benchmarks_options.append(no_benchmarks_option) + + benchmarks_answer = inquirer_wrapper( + inquirer.list_input, + message="Where are your benchmarks located? 
(benchmarks must be a sub directory of your tests root directory)", + choices=benchmarks_options, + default=( + default_benchmarks_subdir if default_benchmarks_subdir in benchmarks_options else benchmarks_options[0]), + ) + + if benchmarks_answer == create_benchmarks_option: + benchmarks_root = tests_root / default_benchmarks_subdir + benchmarks_root.mkdir(exist_ok=True) + click.echo(f"✅ Created directory {benchmarks_root}{os.path.sep}{LF}") + elif benchmarks_answer == custom_dir_option: + custom_benchmarks_answer = inquirer_wrapper_path( + "path", + message=f"Enter the path to your benchmarks directory inside {tests_root}{os.path.sep} ", + path_type=inquirer.Path.DIRECTORY, + ) + if custom_benchmarks_answer: + benchmarks_root = tests_root / Path(custom_benchmarks_answer["path"]) + else: + apologize_and_exit() + elif benchmarks_answer == no_benchmarks_option: + benchmarks_root = None + else: + benchmarks_root = tests_root / Path(cast(str, benchmarks_answer)) + + # TODO: Implement other benchmark framework options + # if benchmarks_root: + # benchmarks_root = benchmarks_root.relative_to(curdir) + # + # # Ask about benchmark framework + # benchmark_framework_options = ["pytest-benchmark", "asv (Airspeed Velocity)", "custom/other"] + # benchmark_framework = inquirer_wrapper( + # inquirer.list_input, + # message="Which benchmark framework do you use?", + # choices=benchmark_framework_options, + # default=benchmark_framework_options[0], + # carousel=True, + # ) + + formatter = inquirer_wrapper( inquirer.list_input, message="Which code formatter do you use?", @@ -279,6 +339,7 @@ def collect_setup_info() -> SetupInfo: return SetupInfo( module_root=str(module_root), tests_root=str(tests_root), + benchmarks_root = str(benchmarks_root) if benchmarks_root else None, test_framework=cast(str, test_framework), ignore_paths=ignore_paths, formatter=cast(str, formatter), @@ -437,11 +498,19 @@ def install_github_actions(override_formatter_check: bool = False) -> None: return workflows_path.mkdir(parents=True, exist_ok=True) from importlib.resources import files + benchmark_mode = False + if "benchmarks_root" in config: + benchmark_mode = inquirer_wrapper( + inquirer.confirm, + message="⚡️It looks like you've configured a benchmarks_root in your config. Would you like to run the Github action in benchmark mode? 
" + " This will show the impact of Codeflash's suggested optimizations on your benchmarks", + default=True, + ) optimize_yml_content = ( files("codeflash").joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml").read_text(encoding="utf-8") ) - materialized_optimize_yml_content = customize_codeflash_yaml_content(optimize_yml_content, config, git_root) + materialized_optimize_yml_content = customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode) with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file: optimize_yml_file.write(materialized_optimize_yml_content) click.echo(f"{LF}✅ Created GitHub action workflow at {optimize_yaml_path}{LF}") @@ -556,7 +625,7 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str: def customize_codeflash_yaml_content( - optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path + optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False ) -> str: module_path = str(Path(config["module_root"]).relative_to(git_root) / "**") optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path) @@ -587,6 +656,9 @@ def customize_codeflash_yaml_content( # Add codeflash command codeflash_cmd = get_codeflash_github_action_command(dep_manager) + + if benchmark_mode: + codeflash_cmd += " --benchmark" return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd) @@ -608,6 +680,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: codeflash_section["module-root"] = setup_info.module_root codeflash_section["tests-root"] = setup_info.tests_root codeflash_section["test-framework"] = setup_info.test_framework + codeflash_section["benchmarks-root"] = setup_info.benchmarks_root if setup_info.benchmarks_root else "" codeflash_section["ignore-paths"] = setup_info.ignore_paths if setup_info.git_remote not in ["", "origin"]: codeflash_section["git-remote"] = setup_info.git_remote diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index e9ff9a735..79a39168b 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -52,10 +52,10 @@ def parse_config_file( assert isinstance(config, dict) # default values: - path_keys = {"module-root", "tests-root"} - path_list_keys = {"ignore-paths", } + path_keys = ["module-root", "tests-root", "benchmarks-root"] + path_list_keys = ["ignore-paths"] str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"} - bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False} + bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False, "benchmark": False} list_str_keys = {"formatter-cmds": ["black $file"]} for key in str_keys: diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index a234a2827..6e4f744d7 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -106,7 +106,7 @@ def generic_visit(self, node: ast.AST) -> None: @dataclass(frozen=True, config={"arbitrary_types_allowed": True}) class FunctionToOptimize: - """Represents a function that is a candidate for optimization. + """Represent a function that is a candidate for optimization. 
Attributes ---------- @@ -145,7 +145,6 @@ def qualified_name(self) -> str: def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str: return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}" - def get_functions_to_optimize( optimize_all: str | None, replay_test: str | None, @@ -359,9 +358,15 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: for decorator in body_node.decorator_list ): self.is_classmethod = True + elif any( + isinstance(decorator, ast.Name) and decorator.id == "staticmethod" + for decorator in body_node.decorator_list + ): + self.is_staticmethod = True return - else: - # search if the class has a staticmethod with the same name and on the same line number + elif self.line_no: + # If we have line number info, check if class has a static method with the same line number + # This way, if we don't have the class name, we can still find the static method for body_node in node.body: if ( isinstance(body_node, ast.FunctionDef) diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py index d7c12d962..1e66c5608 100644 --- a/codeflash/github/PrComment.py +++ b/codeflash/github/PrComment.py @@ -1,9 +1,11 @@ -from typing import Union +from __future__ import annotations +from typing import Union, Optional from pydantic import BaseModel from pydantic.dataclasses import dataclass from codeflash.code_utils.time_utils import humanize_runtime +from codeflash.models.models import BenchmarkDetail from codeflash.models.models import TestResults @@ -18,8 +20,9 @@ class PrComment: speedup_pct: str winning_behavioral_test_results: TestResults winning_benchmarking_test_results: TestResults + benchmark_details: Optional[list[BenchmarkDetail]] = None - def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]: + def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str, Optional[list[BenchmarkDetail]]]]: report_table = { test_type.to_name(): result for test_type, result in self.winning_behavioral_test_results.get_test_pass_fail_report_by_type().items() @@ -36,6 +39,7 @@ def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]: "speedup_pct": self.speedup_pct, "loop_count": self.winning_benchmarking_test_results.number_of_loops(), "report_table": report_table, + "benchmark_details": self.benchmark_details if self.benchmark_details else None, } diff --git a/codeflash/models/models.py b/codeflash/models/models.py index a00834cdd..ddaccd16e 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import defaultdict from typing import TYPE_CHECKING from rich.tree import Tree @@ -11,7 +12,7 @@ import enum import re import sys -from collections.abc import Collection, Iterator +from collections.abc import Collection from enum import Enum, IntEnum from pathlib import Path from re import Pattern @@ -22,7 +23,7 @@ from pydantic.dataclasses import dataclass from codeflash.cli_cmds.console import console, logger -from codeflash.code_utils.code_utils import validate_python_code +from codeflash.code_utils.code_utils import module_name_from_file_path, validate_python_code from codeflash.code_utils.env_utils import is_end_to_end from codeflash.verification.comparator import comparator @@ -58,28 +59,74 @@ class FunctionSource: def __eq__(self, other: object) -> bool: if not isinstance(other, FunctionSource): return False - return ( - self.file_path == other.file_path - and self.qualified_name == 
other.qualified_name - and self.fully_qualified_name == other.fully_qualified_name - and self.only_function_name == other.only_function_name - and self.source_code == other.source_code - ) + return (self.file_path == other.file_path and + self.qualified_name == other.qualified_name and + self.fully_qualified_name == other.fully_qualified_name and + self.only_function_name == other.only_function_name and + self.source_code == other.source_code) def __hash__(self) -> int: - return hash( - (self.file_path, self.qualified_name, self.fully_qualified_name, self.only_function_name, self.source_code) - ) - + return hash((self.file_path, self.qualified_name, self.fully_qualified_name, + self.only_function_name, self.source_code)) class BestOptimization(BaseModel): candidate: OptimizedCandidate helper_functions: list[FunctionSource] runtime: int + replay_performance_gain: Optional[dict[BenchmarkKey,float]] = None winning_behavioral_test_results: TestResults winning_benchmarking_test_results: TestResults + winning_replay_benchmarking_test_results : Optional[TestResults] = None + +@dataclass(frozen=True) +class BenchmarkKey: + module_path: str + function_name: str + + def __str__(self) -> str: + return f"{self.module_path}::{self.function_name}" + +@dataclass +class BenchmarkDetail: + benchmark_name: str + test_function: str + original_timing: str + expected_new_timing: str + speedup_percent: float + + def to_string(self) -> str: + return ( + f"Original timing for {self.benchmark_name}::{self.test_function}: {self.original_timing}\n" + f"Expected new timing for {self.benchmark_name}::{self.test_function}: {self.expected_new_timing}\n" + f"Benchmark speedup for {self.benchmark_name}::{self.test_function}: {self.speedup_percent:.2f}%\n" + ) + def to_dict(self) -> dict[str, any]: + return { + "benchmark_name": self.benchmark_name, + "test_function": self.test_function, + "original_timing": self.original_timing, + "expected_new_timing": self.expected_new_timing, + "speedup_percent": self.speedup_percent + } + +@dataclass +class ProcessedBenchmarkInfo: + benchmark_details: list[BenchmarkDetail] + + def to_string(self) -> str: + if not self.benchmark_details: + return "" + + result = "Benchmark Performance Details:\n" + for detail in self.benchmark_details: + result += detail.to_string() + "\n" + return result + def to_dict(self) -> dict[str, list[dict[str, any]]]: + return { + "benchmark_details": [detail.to_dict() for detail in self.benchmark_details] + } class CodeString(BaseModel): code: Annotated[str, AfterValidator(validate_python_code)] file_path: Optional[Path] = None @@ -104,8 +151,7 @@ class CodeOptimizationContext(BaseModel): read_writable_code: str = Field(min_length=1) read_only_context_code: str = "" helper_functions: list[FunctionSource] - preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] - + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] class CodeContextType(str, Enum): READ_WRITABLE = "READ_WRITABLE" @@ -118,6 +164,7 @@ class OptimizedCandidateResult(BaseModel): best_test_runtime: int behavior_test_results: TestResults benchmarking_test_results: TestResults + replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None optimization_candidate_index: int total_candidate_timing: int @@ -222,6 +269,7 @@ class FunctionParent: class OriginalCodeBaseline(BaseModel): behavioral_test_results: TestResults benchmarking_test_results: TestResults + replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None 
line_profile_results: dict runtime: int coverage_results: Optional[CoverageData] @@ -299,7 +347,6 @@ def create_empty(cls, file_path: Path, function_name: str, code_context: CodeOpt status=CoverageStatus.NOT_FOUND, ) - @dataclass class FunctionCoverage: """Represents the coverage data for a specific function in a source file.""" @@ -426,6 +473,20 @@ def merge(self, other: TestResults) -> None: raise ValueError(msg) self.test_result_idx[k] = v + original_len + def group_by_benchmarks(self, benchmark_keys:list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path) -> dict[BenchmarkKey, TestResults]: + """Group TestResults by benchmark for calculating improvements for each benchmark.""" + test_results_by_benchmark = defaultdict(TestResults) + benchmark_module_path = {} + for benchmark_key in benchmark_keys: + benchmark_module_path[benchmark_key] = module_name_from_file_path(benchmark_replay_test_dir.resolve() / f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_", project_root) + for test_result in self.test_results: + if (test_result.test_type == TestType.REPLAY_TEST): + for benchmark_key, module_path in benchmark_module_path.items(): + if test_result.id.test_module_path.startswith(module_path): + test_results_by_benchmark[benchmark_key].add(test_result) + + return test_results_by_benchmark + def get_by_unique_invocation_loop_id(self, unique_invocation_loop_id: str) -> FunctionTestInvocation | None: try: return self.test_results[self.test_result_idx[unique_invocation_loop_id]] diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 7fa8805c9..49958dc96 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -19,6 +19,7 @@ from rich.tree import Tree from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient +from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils from codeflash.code_utils.code_replacer import replace_function_definitions_in_module @@ -42,7 +43,6 @@ from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast from codeflash.code_utils.time_utils import humanize_runtime from codeflash.context import code_context_extractor -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import Failure, Success, is_successful from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( @@ -76,8 +76,9 @@ if TYPE_CHECKING: from argparse import Namespace + from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import Result - from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate + from codeflash.models.models import BenchmarkKey, CoverageData, FunctionSource, OptimizedCandidate from codeflash.verification.verification_utils import TestConfig @@ -90,7 +91,10 @@ def __init__( function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None, function_to_optimize_ast: ast.FunctionDef | None = None, aiservice_client: AiServiceClient | None = None, + function_benchmark_timings: dict[BenchmarkKey, int] | None = None, + total_benchmark_timings: dict[BenchmarkKey, int] | None = None, args: Namespace | None = None, + replay_tests_dir: Path|None = None ) -> None: self.project_root = test_cfg.project_root_path self.test_cfg = test_cfg 
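[Editorial note, not part of the patch] The expected-new-timing arithmetic in process_benchmark_data above is easiest to see with concrete numbers. The sketch below mirrors that formula; it assumes performance_gain computes (original - optimized) / optimized, which is consistent with the "Speedup ratio: {perf_gain + 1}X" lines elsewhere in this diff, and the timing values are invented purely for illustration.

def expected_new_benchmark_timing(total_ns: float, og_ns: float, replay_gain: float) -> float:
    # Only the portion spent in the function being optimized shrinks; the rest of
    # the benchmark is assumed unchanged. The optimized portion is expected to
    # take og_ns / (replay_gain + 1) nanoseconds.
    return total_ns - og_ns + (1 / (replay_gain + 1)) * og_ns

total_ns = 10_000_000   # whole benchmark originally takes 10 ms
og_ns = 4_000_000       # 4 ms of that is spent inside the function to optimize
replay_gain = 1.0       # replay tests ran 2x faster, i.e. a 100% gain

new_ns = expected_new_benchmark_timing(total_ns, og_ns, replay_gain)  # 8_000_000 ns
speedup_pct = ((total_ns - new_ns) / new_ns) * 100                    # 25.0%

In other words, a 2x speedup of a function that accounts for 40% of a benchmark's runtime is reported as a 25% improvement of that benchmark, which is the value that ends up in BenchmarkDetail.speedup_percent.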
@@ -113,11 +117,14 @@ def __init__( self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None) self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None self.test_files = TestFiles(test_files=[]) - self.args = args # Check defaults for these self.function_trace_id: str = str(uuid.uuid4()) self.original_module_path = module_name_from_file_path(self.function_to_optimize.file_path, self.project_root) + self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {} + self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {} + self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None + def optimize_function(self) -> Result[BestOptimization, str]: should_run_experiment = self.experiment_id is not None logger.debug(f"Function Trace ID: {self.function_trace_id}") @@ -136,8 +143,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: original_helper_code[helper_function_path] = helper_code if has_any_async_functions(code_context.read_writable_code): return Failure("Codeflash does not support async functions in the code to optimize.") - code_print(code_context.read_writable_code) + code_print(code_context.read_writable_code) generated_test_paths = [ get_test_file_path( self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit" @@ -261,6 +268,13 @@ def optimize_function(self) -> Result[BestOptimization, str]: best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue" ) ) + processed_benchmark_info = None + if self.args.benchmark: + processed_benchmark_info = process_benchmark_data( + replay_performance_gain=best_optimization.replay_performance_gain, + fto_benchmark_timings=self.function_benchmark_timings, + total_benchmark_timings=self.total_benchmark_timings + ) explanation = Explanation( raw_explanation_message=best_optimization.candidate.explanation, winning_behavioral_test_results=best_optimization.winning_behavioral_test_results, @@ -269,6 +283,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: best_runtime_ns=best_optimization.runtime, function_name=function_to_optimize_qualified_name, file_path=self.function_to_optimize.file_path, + benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None ) self.log_successful_optimization(explanation, generated_tests) @@ -362,7 +377,7 @@ def determine_best_candidate( candidates = deque(candidates) # Start a new thread for AI service request, start loop in main thread # check if aiservice request is complete, when it is complete, append result to the candidates list - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: future_line_profile_results = executor.submit( self.aiservice_client.optimize_python_code_line_profiler, source_code=code_context.read_writable_code, @@ -382,8 +397,8 @@ def determine_best_candidate( if done and (future_line_profile_results is not None): line_profile_results = future_line_profile_results.result() candidates.extend(line_profile_results) - original_len+= len(line_profile_results) - logger.info(f"Added {len(line_profile_results)} results from line profiler to candidates, total candidates now: {original_len}") + original_len+= len(candidates) + logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}") future_line_profile_results = None 
candidate_index += 1 candidate = candidates.popleft() @@ -410,7 +425,7 @@ def determine_best_candidate( ) continue - # Instrument codeflash capture + run_results = self.run_optimized_candidate( optimization_candidate_index=candidate_index, baseline_results=original_code_baseline, @@ -434,6 +449,7 @@ def determine_best_candidate( speedup_ratios[candidate.optimization_id] = perf_gain tree = Tree(f"Candidate #{candidate_index} - Runtime Information") + benchmark_tree = None if speedup_critic( candidate_result, original_code_baseline.runtime, best_runtime_until_now ) and quantity_of_tests_critic(candidate_result): @@ -446,13 +462,29 @@ def determine_best_candidate( ) tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X") + replay_perf_gain = {} + if self.args.benchmark: + test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root) + if len(test_results_by_benchmark) > 0: + benchmark_tree = Tree("Speedup percentage on benchmarks:") + for benchmark_key, candidate_test_results in test_results_by_benchmark.items(): + + original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[benchmark_key].total_passed_runtime() + candidate_replay_runtime = candidate_test_results.total_passed_runtime() + replay_perf_gain[benchmark_key] = performance_gain( + original_runtime_ns=original_code_replay_runtime, + optimized_runtime_ns=candidate_replay_runtime, + ) + benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%") best_optimization = BestOptimization( candidate=candidate, helper_functions=code_context.helper_functions, runtime=best_test_runtime, winning_behavioral_test_results=candidate_result.behavior_test_results, + replay_performance_gain=replay_perf_gain if self.args.benchmark else None, winning_benchmarking_test_results=candidate_result.benchmarking_test_results, + winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results, ) best_runtime_until_now = best_test_runtime else: @@ -464,6 +496,8 @@ def determine_best_candidate( tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") console.print(tree) + if self.args.benchmark and benchmark_tree: + console.print(benchmark_tree) console.rule() self.write_code_and_helpers( @@ -507,7 +541,8 @@ def log_successful_optimization(self, explanation: Explanation, generated_tests: ) console.print(Group(explanation_panel, tests_panel)) - console.print(explanation_panel) + else: + console.print(explanation_panel) ph( "cli-optimize-success", @@ -664,6 +699,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi unique_instrumented_test_files.add(new_behavioral_test_path) unique_instrumented_test_files.add(new_perf_test_path) + if not self.test_files.get_by_original_file_path(path_obj_test_file): self.test_files.add( TestFile( @@ -675,6 +711,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi tests_in_file=[t.tests_in_file for t in tests_in_file_list], ) ) + logger.info( f"Discovered {existing_test_files_count} existing unit test file" f"{'s' if existing_test_files_count != 1 else ''}, {replay_test_files_count} replay test file" @@ -865,7 +902,6 @@ def establish_original_code_baseline( enable_coverage=False, code_context=code_context, ) - else: benchmarking_results = TestResults() start_time: float = time.time() @@ -920,11 
+956,15 @@ def establish_original_code_baseline( ) console.rule() logger.debug(f"Total original code runtime (ns): {total_timing}") + + if self.args.benchmark: + replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root) return Success( ( OriginalCodeBaseline( behavioral_test_results=behavioral_results, benchmarking_test_results=benchmarking_results, + replay_benchmarking_test_results = replay_benchmarking_test_results if self.args.benchmark else None, runtime=total_timing, coverage_results=coverage_results, line_profile_results=line_profile_results, @@ -954,8 +994,6 @@ def run_optimized_candidate( test_env["PYTHONPATH"] += os.pathsep + str(self.project_root) get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True) - get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True) - # Instrument codeflash capture candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8") candidate_helper_code = {} @@ -986,7 +1024,6 @@ def run_optimized_candidate( ) ) console.rule() - if compare_test_results(baseline_results.behavioral_test_results, candidate_behavior_results): logger.info("Test results matched!") console.rule() @@ -1039,12 +1076,17 @@ def run_optimized_candidate( console.rule() logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}") + if self.args.benchmark: + candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root) + for benchmark_name, benchmark_results in candidate_replay_benchmarking_results.items(): + logger.debug(f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}") return Success( OptimizedCandidateResult( max_loop_count=loop_count, best_test_runtime=total_candidate_timing, behavior_test_results=candidate_behavior_results, benchmarking_test_results=candidate_benchmarking_results, + replay_benchmarking_test_results = candidate_replay_benchmarking_results if self.args.benchmark else None, optimization_candidate_index=optimization_candidate_index, total_candidate_timing=total_candidate_timing, ) @@ -1086,8 +1128,8 @@ def run_and_parse_tests( pytest_cmd=self.test_cfg.pytest_cmd, pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT, pytest_target_runtime_seconds=testing_time, - pytest_min_loops=pytest_min_loops, - pytest_max_loops=pytest_min_loops, + pytest_min_loops=1, + pytest_max_loops=1, test_framework=self.test_cfg.test_framework, line_profiler_output_file=line_profiler_output_file, ) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 67fada646..7f42b58c4 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -5,10 +5,16 @@ import time import shutil import tempfile +from collections import defaultdict from pathlib import Path from typing import TYPE_CHECKING from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient +from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator +from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin +from codeflash.benchmarking.replay_test import generate_replay_test +from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest +from codeflash.benchmarking.utils import 
print_benchmark_table, validate_and_format_benchmark_table from codeflash.cli_cmds.console import console, logger, progress_bar from codeflash.code_utils import env_utils from codeflash.code_utils.code_replacer import normalize_code, normalize_node @@ -17,7 +23,7 @@ from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.discovery.functions_to_optimize import get_functions_to_optimize from codeflash.either import is_successful -from codeflash.models.models import TestType, ValidCode +from codeflash.models.models import BenchmarkKey, TestType, ValidCode from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.telemetry.posthog_cf import ph from codeflash.verification.verification_utils import TestConfig @@ -39,18 +45,21 @@ def __init__(self, args: Namespace) -> None: project_root_path=args.project_root, test_framework=args.test_framework, pytest_cmd=args.pytest_cmd, + benchmark_tests_root=args.benchmarks_root if "benchmark" in args and "benchmarks_root" in args else None, ) self.aiservice_client = AiServiceClient() self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None) self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None - + self.replay_tests_dir = None def create_function_optimizer( self, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.FunctionDef | None = None, function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None, function_to_optimize_source_code: str | None = "", + function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None, + total_benchmark_timings: dict[BenchmarkKey, float] | None = None, ) -> FunctionOptimizer: return FunctionOptimizer( function_to_optimize=function_to_optimize, @@ -60,6 +69,9 @@ def create_function_optimizer( function_to_optimize_ast=function_to_optimize_ast, aiservice_client=self.aiservice_client, args=self.args, + function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None, + total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None, + replay_tests_dir = self.replay_tests_dir ) def run(self) -> None: @@ -71,6 +83,8 @@ def run(self) -> None: function_optimizer = None file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] num_optimizable_functions: int + + # discover functions (file_to_funcs_to_optimize, num_optimizable_functions) = get_functions_to_optimize( optimize_all=self.args.all, replay_test=self.args.replay_test, @@ -81,7 +95,42 @@ def run(self) -> None: project_root=self.args.project_root, module_root=self.args.module_root, ) + function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {} + total_benchmark_timings: dict[BenchmarkKey, int] = {} + if self.args.benchmark: + with progress_bar( + f"Running benchmarks in {self.args.benchmarks_root}", + transient=True, + ): + # Insert decorator + file_path_to_source_code = defaultdict(str) + for file in file_to_funcs_to_optimize: + with file.open("r", encoding="utf8") as f: + file_path_to_source_code[file] = f.read() + try: + instrument_codeflash_trace_decorator(file_to_funcs_to_optimize) + trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace" + if trace_file.exists(): + trace_file.unlink() + self.replay_tests_dir = Path(tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root)) + trace_benchmarks_pytest(self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file) # Run all tests that use pytest-benchmark + replay_count 
= generate_replay_test(trace_file, self.replay_tests_dir) + if replay_count == 0: + logger.info(f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization") + else: + function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file) + total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file) + function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings) + print_benchmark_table(function_to_results) + except Exception as e: + logger.info(f"Error while tracing existing benchmarks: {e}") + logger.info("Information on existing benchmarks will not be available for this run.") + finally: + # Restore original source code + for file in file_path_to_source_code: + with file.open("w", encoding="utf8") as f: + f.write(file_path_to_source_code[file]) optimizations_found: int = 0 function_iterator_count: int = 0 if self.args.test_framework == "pytest": @@ -103,6 +152,7 @@ def run(self) -> None: console.rule() ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests}) + for original_module_path in file_to_funcs_to_optimize: logger.info(f"Examining file {original_module_path!s}…") console.rule() @@ -159,12 +209,19 @@ def run(self) -> None: f"Skipping optimization." ) continue - function_optimizer = self.create_function_optimizer( - function_to_optimize, - function_to_optimize_ast, - function_to_tests, - validated_original_code[original_module_path].source_code, + qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root( + self.args.project_root ) + if self.args.benchmark and function_benchmark_timings and qualified_name_w_module in function_benchmark_timings and total_benchmark_timings: + function_optimizer = self.create_function_optimizer( + function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings[qualified_name_w_module], total_benchmark_timings + ) + else: + function_optimizer = self.create_function_optimizer( + function_to_optimize, function_to_optimize_ast, function_to_tests, + validated_original_code[original_module_path].source_code + ) + best_optimization = function_optimizer.optimize_function() if is_successful(best_optimization): optimizations_found += 1 @@ -189,6 +246,10 @@ def run(self) -> None: test_file.instrumented_behavior_file_path.unlink(missing_ok=True) if function_optimizer.test_cfg.concolic_test_root_dir: shutil.rmtree(function_optimizer.test_cfg.concolic_test_root_dir, ignore_errors=True) + if self.args.benchmark: + if self.replay_tests_dir.exists(): + shutil.rmtree(self.replay_tests_dir, ignore_errors=True) + trace_file.unlink(missing_ok=True) if hasattr(get_run_tmp_file, "tmpdir"): get_run_tmp_file.tmpdir.cleanup() diff --git a/codeflash/picklepatch/__init__.py b/codeflash/picklepatch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codeflash/picklepatch/pickle_patcher.py b/codeflash/picklepatch/pickle_patcher.py new file mode 100644 index 000000000..cfedd28fd --- /dev/null +++ b/codeflash/picklepatch/pickle_patcher.py @@ -0,0 +1,346 @@ +"""PicklePatcher - A utility for safely pickling objects with unpicklable components. + +This module provides functions to recursively pickle objects, replacing unpicklable +components with placeholders that provide informative errors when accessed. 
+""" + +import pickle +import types + +import dill + +from .pickle_placeholder import PicklePlaceholder + + +class PicklePatcher: + """A utility class for safely pickling objects with unpicklable components. + + This class provides methods to recursively pickle objects, replacing any + components that can't be pickled with placeholder objects. + """ + + # Class-level cache of unpicklable types + _unpicklable_types = set() + + @staticmethod + def dumps(obj, protocol=None, max_depth=100, **kwargs): + """Safely pickle an object, replacing unpicklable parts with placeholders. + + Args: + obj: The object to pickle + protocol: The pickle protocol version to use + max_depth: Maximum recursion depth + **kwargs: Additional arguments for pickle/dill.dumps + + Returns: + bytes: Pickled data with placeholders for unpicklable objects + """ + return PicklePatcher._recursive_pickle(obj, max_depth, path=[], protocol=protocol, **kwargs) + + @staticmethod + def loads(pickled_data): + """Unpickle data that may contain placeholders. + + Args: + pickled_data: Pickled data with possible placeholders + + Returns: + The unpickled object with placeholders for unpicklable parts + """ + try: + # We use dill for loading since it can handle everything pickle can + return dill.loads(pickled_data) + except Exception as e: + raise + + @staticmethod + def _create_placeholder(obj, error_msg, path): + """Create a placeholder for an unpicklable object. + + Args: + obj: The original unpicklable object + error_msg: Error message explaining why it couldn't be pickled + path: Path to this object in the object graph + + Returns: + PicklePlaceholder: A placeholder object + """ + obj_type = type(obj) + try: + obj_str = str(obj)[:100] if hasattr(obj, "__str__") else f"" + except: + obj_str = f"" + + print(f"Creating placeholder for {obj_type.__name__} at path {'->'.join(path) or 'root'}: {error_msg}") + + placeholder = PicklePlaceholder( + obj_type.__name__, + obj_str, + error_msg, + path + ) + + # Add this type to our known unpicklable types cache + PicklePatcher._unpicklable_types.add(obj_type) + return placeholder + + @staticmethod + def _pickle(obj, path=None, protocol=None, **kwargs): + """Try to pickle an object using pickle first, then dill. If both fail, create a placeholder. + + Args: + obj: The object to pickle + path: Path to this object in the object graph + protocol: The pickle protocol version to use + **kwargs: Additional arguments for pickle/dill.dumps + + Returns: + tuple: (success, result) where success is a boolean and result is either: + - Pickled bytes if successful + - Error message if not successful + """ + # Try standard pickle first + try: + return True, pickle.dumps(obj, protocol=protocol, **kwargs) + except (pickle.PickleError, TypeError, AttributeError, ValueError) as e: + # Then try dill (which is more powerful) + try: + return True, dill.dumps(obj, protocol=protocol, **kwargs) + except (dill.PicklingError, TypeError, AttributeError, ValueError) as e: + return False, str(e) + + @staticmethod + def _recursive_pickle(obj, max_depth, path=None, protocol=None, **kwargs): + """Recursively try to pickle an object, replacing unpicklable parts with placeholders. 
+ + Args: + obj: The object to pickle + max_depth: Maximum recursion depth + path: Current path in the object graph + protocol: The pickle protocol version to use + **kwargs: Additional arguments for pickle/dill.dumps + + Returns: + bytes: Pickled data with placeholders for unpicklable objects + """ + if path is None: + path = [] + + obj_type = type(obj) + + # Check if this type is known to be unpicklable + if obj_type in PicklePatcher._unpicklable_types: + placeholder = PicklePatcher._create_placeholder( + obj, + "Known unpicklable type", + path + ) + return dill.dumps(placeholder, protocol=protocol, **kwargs) + + # Check for max depth + if max_depth <= 0: + placeholder = PicklePatcher._create_placeholder( + obj, + "Max recursion depth exceeded", + path + ) + return dill.dumps(placeholder, protocol=protocol, **kwargs) + + # Try standard pickling + success, result = PicklePatcher._pickle(obj, path, protocol, **kwargs) + if success: + return result + + error_msg = result # Error message from pickling attempt + + # Handle different container types + if isinstance(obj, dict): + return PicklePatcher._handle_dict(obj, max_depth, error_msg, path, protocol=protocol, **kwargs) + elif isinstance(obj, (list, tuple, set)): + return PicklePatcher._handle_sequence(obj, max_depth, error_msg, path, protocol=protocol, **kwargs) + elif hasattr(obj, "__dict__"): + result = PicklePatcher._handle_object(obj, max_depth, error_msg, path, protocol=protocol, **kwargs) + + # If this was a failure, add the type to the cache + unpickled = dill.loads(result) + if isinstance(unpickled, PicklePlaceholder): + PicklePatcher._unpicklable_types.add(obj_type) + return result + + # For other unpicklable objects, use a placeholder + placeholder = PicklePatcher._create_placeholder(obj, error_msg, path) + return dill.dumps(placeholder, protocol=protocol, **kwargs) + + @staticmethod + def _handle_dict(obj_dict, max_depth, error_msg, path, protocol=None, **kwargs): + """Handle pickling for dictionary objects. 
+ + Args: + obj_dict: The dictionary to pickle + max_depth: Maximum recursion depth + error_msg: Error message from the original pickling attempt + path: Current path in the object graph + protocol: The pickle protocol version to use + **kwargs: Additional arguments for pickle/dill.dumps + + Returns: + bytes: Pickled data with placeholders for unpicklable objects + """ + if not isinstance(obj_dict, dict): + placeholder = PicklePatcher._create_placeholder( + obj_dict, + f"Expected a dictionary, got {type(obj_dict).__name__}", + path + ) + return dill.dumps(placeholder, protocol=protocol, **kwargs) + + result = {} + + for key, value in obj_dict.items(): + # Process the key + key_success, key_result = PicklePatcher._pickle(key, path, protocol, **kwargs) + if key_success: + key_result = key + else: + # If the key can't be pickled, use a string representation + try: + key_str = str(key)[:50] + except: + key_str = f"" + key_result = f"" + + # Process the value + value_path = path + [f"[{repr(key)[:20]}]"] + value_success, value_bytes = PicklePatcher._pickle(value, value_path, protocol, **kwargs) + + if value_success: + value_result = value + else: + # Try recursive pickling for the value + try: + value_bytes = PicklePatcher._recursive_pickle( + value, max_depth - 1, value_path, protocol=protocol, **kwargs + ) + value_result = dill.loads(value_bytes) + except Exception as inner_e: + value_result = PicklePatcher._create_placeholder( + value, + str(inner_e), + value_path + ) + + result[key_result] = value_result + + return dill.dumps(result, protocol=protocol, **kwargs) + + @staticmethod + def _handle_sequence(obj_seq, max_depth, error_msg, path, protocol=None, **kwargs): + """Handle pickling for sequence types (list, tuple, set). + + Args: + obj_seq: The sequence to pickle + max_depth: Maximum recursion depth + error_msg: Error message from the original pickling attempt + path: Current path in the object graph + protocol: The pickle protocol version to use + **kwargs: Additional arguments for pickle/dill.dumps + + Returns: + bytes: Pickled data with placeholders for unpicklable objects + """ + result = [] + + for i, item in enumerate(obj_seq): + item_path = path + [f"[{i}]"] + + # Try to pickle the item directly + success, _ = PicklePatcher._pickle(item, item_path, protocol, **kwargs) + if success: + result.append(item) + continue + + # If we couldn't pickle directly, try recursively + try: + item_bytes = PicklePatcher._recursive_pickle( + item, max_depth - 1, item_path, protocol=protocol, **kwargs + ) + result.append(dill.loads(item_bytes)) + except Exception as inner_e: + # If recursive pickling fails, use a placeholder + placeholder = PicklePatcher._create_placeholder( + item, + str(inner_e), + item_path + ) + result.append(placeholder) + + # Convert back to the original type + if isinstance(obj_seq, tuple): + result = tuple(result) + elif isinstance(obj_seq, set): + # Try to create a set from the result + try: + result = set(result) + except Exception: + # If we can't create a set (unhashable items), keep it as a list + pass + + return dill.dumps(result, protocol=protocol, **kwargs) + + @staticmethod + def _handle_object(obj, max_depth, error_msg, path, protocol=None, **kwargs): + """Handle pickling for custom objects with __dict__. 
+ + Args: + obj: The object to pickle + max_depth: Maximum recursion depth + error_msg: Error message from the original pickling attempt + path: Current path in the object graph + protocol: The pickle protocol version to use + **kwargs: Additional arguments for pickle/dill.dumps + + Returns: + bytes: Pickled data with placeholders for unpicklable objects + """ + # Try to create a new instance of the same class + try: + # First try to create an empty instance + new_obj = object.__new__(type(obj)) + + # Handle __dict__ attributes if they exist + if hasattr(obj, "__dict__"): + for attr_name, attr_value in obj.__dict__.items(): + attr_path = path + [attr_name] + + # Try to pickle directly first + success, _ = PicklePatcher._pickle(attr_value, attr_path, protocol, **kwargs) + if success: + setattr(new_obj, attr_name, attr_value) + continue + + # If direct pickling fails, try recursive pickling + try: + attr_bytes = PicklePatcher._recursive_pickle( + attr_value, max_depth - 1, attr_path, protocol=protocol, **kwargs + ) + setattr(new_obj, attr_name, dill.loads(attr_bytes)) + except Exception as inner_e: + # Use placeholder for unpicklable attribute + placeholder = PicklePatcher._create_placeholder( + attr_value, + str(inner_e), + attr_path + ) + setattr(new_obj, attr_name, placeholder) + + # Try to pickle the patched object + success, result = PicklePatcher._pickle(new_obj, path, protocol, **kwargs) + if success: + return result + # Fall through to placeholder creation + except Exception: + pass # Fall through to placeholder creation + + # If we get here, just use a placeholder + placeholder = PicklePatcher._create_placeholder(obj, error_msg, path) + return dill.dumps(placeholder, protocol=protocol, **kwargs) \ No newline at end of file diff --git a/codeflash/picklepatch/pickle_placeholder.py b/codeflash/picklepatch/pickle_placeholder.py new file mode 100644 index 000000000..0d730dabb --- /dev/null +++ b/codeflash/picklepatch/pickle_placeholder.py @@ -0,0 +1,71 @@ +class PicklePlaceholderAccessError(Exception): + """Custom exception raised when attempting to access an unpicklable object.""" + + + +class PicklePlaceholder: + """A placeholder for an object that couldn't be pickled. + + When unpickled, any attempt to access attributes or call methods on this + placeholder will raise a PicklePlaceholderAccessError. + """ + + def __init__(self, obj_type, obj_str, error_msg, path=None): + """Initialize a placeholder for an unpicklable object. + + Args: + obj_type (str): The type name of the original object + obj_str (str): String representation of the original object + error_msg (str): The error message that occurred during pickling + path (list, optional): Path to this object in the original object graph + + """ + # Store these directly in __dict__ to avoid __getattr__ recursion + self.__dict__["obj_type"] = obj_type + self.__dict__["obj_str"] = obj_str + self.__dict__["error_msg"] = error_msg + self.__dict__["path"] = path if path is not None else [] + + def __getattr__(self, name): + """Raise a custom error when any attribute is accessed.""" + path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object" + raise PicklePlaceholderAccessError( + f"Attempt to access unpickleable object: Cannot access attribute '{name}' on unpicklable object at {path_str}. " + f"Original type: {self.__dict__['obj_type']}. 
Error: {self.__dict__['error_msg']}" + ) + + def __setattr__(self, name, value): + """Prevent setting attributes.""" + self.__getattr__(name) # This will raise our custom error + + def __call__(self, *args, **kwargs): + """Raise a custom error when the object is called.""" + path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object" + raise PicklePlaceholderAccessError( + f"Attempt to access unpickleable object: Cannot call unpicklable object at {path_str}. " + f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}" + ) + + def __repr__(self): + """Return a string representation of the placeholder.""" + try: + path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root" + return f"" + except: + return "" + + def __str__(self): + """Return a string representation of the placeholder.""" + return self.__repr__() + + def __reduce__(self): + """Make sure pickling of the placeholder itself works correctly.""" + return ( + PicklePlaceholder, + ( + self.__dict__["obj_type"], + self.__dict__["obj_str"], + self.__dict__["error_msg"], + self.__dict__["path"] + ) + ) diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index e2d4da13c..da0c61961 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -77,6 +77,7 @@ def check_create_pr( speedup_pct=explanation.speedup_pct, winning_behavioral_test_results=explanation.winning_behavioral_test_results, winning_benchmarking_test_results=explanation.winning_benchmarking_test_results, + benchmark_details=explanation.benchmark_details ), existing_tests=existing_tests_source, generated_tests=generated_original_test_source, @@ -123,6 +124,7 @@ def check_create_pr( speedup_pct=explanation.speedup_pct, winning_behavioral_test_results=explanation.winning_behavioral_test_results, winning_benchmarking_test_results=explanation.winning_benchmarking_test_results, + benchmark_details=explanation.benchmark_details ), existing_tests=existing_tests_source, generated_tests=generated_original_test_source, diff --git a/codeflash/result/explanation.py b/codeflash/result/explanation.py index 8a2f8f81d..c6e1fb9dc 100644 --- a/codeflash/result/explanation.py +++ b/codeflash/result/explanation.py @@ -1,9 +1,16 @@ +from __future__ import annotations + +import shutil +from io import StringIO from pathlib import Path +from typing import Optional, cast from pydantic.dataclasses import dataclass +from rich.console import Console +from rich.table import Table from codeflash.code_utils.time_utils import humanize_runtime -from codeflash.models.models import TestResults +from codeflash.models.models import BenchmarkDetail, TestResults @dataclass(frozen=True, config={"arbitrary_types_allowed": True}) @@ -15,6 +22,7 @@ class Explanation: best_runtime_ns: int function_name: str file_path: Path + benchmark_details: Optional[list[BenchmarkDetail]] = None @property def perf_improvement_line(self) -> str: @@ -37,16 +45,55 @@ def to_console_string(self) -> str: # TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts original_runtime_human = humanize_runtime(self.original_runtime_ns) best_runtime_human = humanize_runtime(self.best_runtime_ns) + benchmark_info = "" + + if self.benchmark_details: + # Get terminal width (or use a reasonable default if detection fails) + try: + terminal_width = int(shutil.get_terminal_size().columns * 0.9) + except Exception: + terminal_width = 200 # Fallback width + + # Create a rich table 
for better formatting + table = Table(title="Benchmark Performance Details", width=terminal_width, show_lines=True) + + # Add columns - split Benchmark File and Function into separate columns + # Using proportional width for benchmark file column (40% of terminal width) + benchmark_col_width = max(int(terminal_width * 0.4), 40) + table.add_column("Benchmark Module Path", style="cyan", width=benchmark_col_width, overflow="fold") + table.add_column("Test Function", style="cyan", overflow="fold") + table.add_column("Original Runtime", style="magenta", justify="right") + table.add_column("Expected New Runtime", style="green", justify="right") + table.add_column("Speedup", style="red", justify="right") + + # Add rows with split data + for detail in self.benchmark_details: + # Split the benchmark name and test function + benchmark_name = detail.benchmark_name + test_function = detail.test_function + + table.add_row( + benchmark_name, + test_function, + f"{detail.original_timing}", + f"{detail.expected_new_timing}", + f"{detail.speedup_percent:.2f}%" + ) + # Convert table to string + string_buffer = StringIO() + console = Console(file=string_buffer, width=terminal_width) + console.print(table) + benchmark_info = cast(StringIO, console.file).getvalue() + "\n" # Cast for mypy return ( - f"Optimized {self.function_name} in {self.file_path}\n" - f"{self.perf_improvement_line}\n" - f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n" - + "Explanation:\n" - + self.raw_explanation_message - + " \n\n" - + "The new optimized code was tested for correctness. The results are listed below.\n" - + f"{TestResults.report_to_string(self.winning_behavioral_test_results.get_test_pass_fail_report_by_type())}\n" + f"Optimized {self.function_name} in {self.file_path}\n" + f"{self.perf_improvement_line}\n" + f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n" + + (benchmark_info if benchmark_info else "") + + self.raw_explanation_message + + " \n\n" + + "The new optimized code was tested for correctness. The results are listed below.\n" + + f"{TestResults.report_to_string(self.winning_behavioral_test_results.get_test_pass_fail_report_by_type())}\n" ) def explanation_message(self) -> str: diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index f047d5b3c..0ebd2cc7d 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -10,6 +10,7 @@ import sentry_sdk from codeflash.cli_cmds.console import logger +from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError try: import numpy as np @@ -90,6 +91,11 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: return True return math.isclose(orig, new) if isinstance(orig, BaseException): + if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError): + # If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object. + # The test results should be rejected as the behavior of the unpickleable object is unknown. + logger.debug("Unable to verify behavior of unpickleable object in replay test") + return False # if str(orig) != str(new): # return False # compare the attributes of the two exception objects to determine if they are equivalent. 
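[Editorial note, not part of the patch] A minimal sketch of how the new picklepatch pieces and the comparator change above are meant to interact. It assumes a plain socket as the unpicklable value (neither pickle nor dill can serialize it, so it should come back as a PicklePlaceholder); the variable names are invented for illustration.

import socket

from codeflash.picklepatch.pickle_patcher import PicklePatcher
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError

payload = {"numbers": [3, 1, 2], "socket": socket.socket()}

data = PicklePatcher.dumps(payload)      # unpicklable socket becomes a placeholder
restored = PicklePatcher.loads(data)

restored["numbers"].sort()               # picklable parts round-trip normally

try:
    restored["socket"].send(b"hello")    # any attribute access on the placeholder raises
except PicklePlaceholderAccessError:
    # In a replay test this exception signals that the original object's behavior
    # is unknown, and comparator() now rejects such a result instead of treating
    # it as equivalent.
    pass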
diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py index 91ed31757..43cb78770 100644 --- a/codeflash/verification/verification_utils.py +++ b/codeflash/verification/verification_utils.py @@ -75,3 +75,4 @@ class TestConfig: # or for unittest - project_root_from_module_root(args.tests_root, pyproject_file_path) concolic_test_root_dir: Optional[Path] = None pytest_cmd: str = "pytest" + benchmark_tests_root: Optional[Path] = None diff --git a/pyproject.toml b/pyproject.toml index 8a3d0d523..bf1718d33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,6 +116,7 @@ types-openpyxl = ">=3.1.5.20241020" types-regex = ">=2024.9.11.20240912" types-python-dateutil = ">=2.9.0.20241003" pytest-cov = "^6.0.0" +pytest-benchmark = ">=5.1.0" types-gevent = "^24.11.0.20241230" types-greenlet = "^3.1.0.20241221" types-pexpect = "^4.9.0.20241208" @@ -219,6 +220,7 @@ initial-content = """ [tool.codeflash] module-root = "codeflash" tests-root = "tests" +benchmarks-root = "tests/benchmarks" test-framework = "pytest" formatter-cmds = [ "uvx ruff check --exit-zero --fix $file", diff --git a/tests/benchmarks/test_benchmark_code_extract_code_context.py b/tests/benchmarks/test_benchmark_code_extract_code_context.py new file mode 100644 index 000000000..122276408 --- /dev/null +++ b/tests/benchmarks/test_benchmark_code_extract_code_context.py @@ -0,0 +1,31 @@ +from argparse import Namespace +from pathlib import Path + +from codeflash.context.code_context_extractor import get_code_optimization_context +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import FunctionParent +from codeflash.optimization.optimizer import Optimizer + + +def test_benchmark_extract(benchmark)->None: + file_path = Path(__file__).parent.parent.parent.resolve() / "codeflash" + opt = Optimizer( + Namespace( + project_root=file_path.resolve(), + disable_telemetry=True, + tests_root=(file_path / "tests").resolve(), + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path.cwd(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="replace_function_and_helpers_with_optimized_code", + file_path=file_path / "optimization" / "function_optimizer.py", + parents=[FunctionParent(name="FunctionOptimizer", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + benchmark(get_code_optimization_context,function_to_optimize, opt.args.project_root) diff --git a/tests/benchmarks/test_benchmark_discover_unit_tests.py b/tests/benchmarks/test_benchmark_discover_unit_tests.py new file mode 100644 index 000000000..4b05f663b --- /dev/null +++ b/tests/benchmarks/test_benchmark_discover_unit_tests.py @@ -0,0 +1,26 @@ +from pathlib import Path + +from codeflash.discovery.discover_unit_tests import discover_unit_tests +from codeflash.verification.verification_utils import TestConfig + + +def test_benchmark_code_to_optimize_test_discovery(benchmark) -> None: + project_path = Path(__file__).parent.parent.parent.resolve() / "code_to_optimize" + tests_path = project_path / "tests" / "pytest" + test_config = TestConfig( + tests_root=tests_path, + project_root_path=project_path, + test_framework="pytest", + tests_project_rootdir=tests_path.parent, + ) + benchmark(discover_unit_tests, test_config) +def test_benchmark_codeflash_test_discovery(benchmark) -> None: + project_path = Path(__file__).parent.parent.parent.resolve() / "codeflash" + tests_path = project_path / "tests" + test_config = TestConfig( 
+ tests_root=tests_path, + project_root_path=project_path, + test_framework="pytest", + tests_project_rootdir=tests_path.parent, + ) + benchmark(discover_unit_tests, test_config) diff --git a/tests/benchmarks/test_benchmark_merge_test_results.py b/tests/benchmarks/test_benchmark_merge_test_results.py new file mode 100644 index 000000000..f0c126f75 --- /dev/null +++ b/tests/benchmarks/test_benchmark_merge_test_results.py @@ -0,0 +1,71 @@ +from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType +from codeflash.verification.parse_test_output import merge_test_results + + +def generate_test_invocations(count=100): + """Generate a set number of test invocations for benchmarking.""" + test_results_xml = TestResults() + test_results_bin = TestResults() + + # Generate test invocations in a loop + for i in range(count): + iteration_id = str(i * 3 + 5) # Generate unique iteration IDs + + # XML results - some with None runtime + test_results_xml.add( + FunctionTestInvocation( + id=InvocationId( + test_module_path="code_to_optimize.tests.unittest.test_bubble_sort", + test_class_name="TestPigLatin", + test_function_name="test_sort", + function_getting_tested="sorter", + iteration_id=iteration_id, + ), + file_name="/tmp/tests/unittest/test_bubble_sort__perfinstrumented.py", + did_pass=True, + runtime=None if i % 3 == 0 else i * 100, # Vary runtime values + test_framework="unittest", + test_type=TestType.EXISTING_UNIT_TEST, + return_value=None, + timed_out=False, + loop_index=i, + ) + ) + + # Binary results - with actual runtime values + test_results_bin.add( + FunctionTestInvocation( + id=InvocationId( + test_module_path="code_to_optimize.tests.unittest.test_bubble_sort", + test_class_name="TestPigLatin", + test_function_name="test_sort", + function_getting_tested="sorter", + iteration_id=iteration_id, + ), + file_name="/tmp/tests/unittest/test_bubble_sort__perfinstrumented.py", + did_pass=True, + runtime=500 + i * 20, # Generate varying runtime values + test_framework="unittest", + test_type=TestType.EXISTING_UNIT_TEST, + return_value=None, + timed_out=False, + loop_index=i, + ) + ) + + return test_results_xml, test_results_bin + + +def run_merge_benchmark(count=100): + test_results_xml, test_results_bin = generate_test_invocations(count) + + # Perform the merge operation that will be benchmarked + merge_test_results( + xml_test_results=test_results_xml, + bin_test_results=test_results_bin, + test_framework="unittest" + ) + + +def test_benchmark_merge_test_results(benchmark): + benchmark(run_merge_benchmark, 1000) # Default to 100 test invocations diff --git a/tests/scripts/end_to_end_test_benchmark_sort.py b/tests/scripts/end_to_end_test_benchmark_sort.py new file mode 100644 index 000000000..64aabe384 --- /dev/null +++ b/tests/scripts/end_to_end_test_benchmark_sort.py @@ -0,0 +1,26 @@ +import os +import pathlib + +from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries + + +def run_test(expected_improvement_pct: int) -> bool: + cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve() + config = TestConfig( + file_path=pathlib.Path("bubble_sort.py"), + function_name="sorter", + benchmarks_root=cwd / "tests" / "pytest" / "benchmarks", + test_framework="pytest", + min_improvement_x=1.0, + coverage_expectations=[ + CoverageExpectation( + function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10] + ) + ], + ) + + return run_codeflash_command(cwd, 
config, expected_improvement_pct) + + +if __name__ == "__main__": + exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 5)))) diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py index fda917020..d050f50e9 100644 --- a/tests/scripts/end_to_end_test_utilities.py +++ b/tests/scripts/end_to_end_test_utilities.py @@ -26,6 +26,7 @@ class TestConfig: min_improvement_x: float = 0.1 trace_mode: bool = False coverage_expectations: list[CoverageExpectation] = field(default_factory=list) + benchmarks_root: Optional[pathlib.Path] = None def clear_directory(directory_path: str | pathlib.Path) -> None: @@ -85,8 +86,8 @@ def run_codeflash_command( path_to_file = cwd / config.file_path file_contents = path_to_file.read_text("utf-8") test_root = cwd / "tests" / (config.test_framework or "") - command = build_command(cwd, config, test_root) + command = build_command(cwd, config, test_root, config.benchmarks_root if config.benchmarks_root else None) process = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy() ) @@ -116,7 +117,7 @@ def run_codeflash_command( return validated -def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path) -> list[str]: +def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path, benchmarks_root:pathlib.Path|None = None) -> list[str]: python_path = "../../../codeflash/main.py" if "code_directories" in str(cwd) else "../codeflash/main.py" base_command = ["python", python_path, "--file", config.file_path, "--no-pr"] @@ -127,7 +128,8 @@ def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path base_command.extend( ["--test-framework", config.test_framework, "--tests-root", str(test_root), "--module-root", str(cwd)] ) - + if benchmarks_root: + base_command.extend(["--benchmark", "--benchmarks-root", str(benchmarks_root)]) return base_command diff --git a/tests/test_codeflash_trace_decorator.py b/tests/test_codeflash_trace_decorator.py new file mode 100644 index 000000000..37234d85a --- /dev/null +++ b/tests/test_codeflash_trace_decorator.py @@ -0,0 +1,15 @@ +from codeflash.benchmarking.codeflash_trace import codeflash_trace +from pathlib import Path +from codeflash.code_utils.code_utils import get_run_tmp_file + +@codeflash_trace +def example_function(arr): + arr.sort() + return arr + + +def test_codeflash_trace_decorator(): + arr = [3, 1, 2] + result = example_function(arr) + # cleanup test trace file using Path + assert result == [1, 2, 3] diff --git a/tests/test_instrument_codeflash_trace.py b/tests/test_instrument_codeflash_trace.py new file mode 100644 index 000000000..38a6381e2 --- /dev/null +++ b/tests/test_instrument_codeflash_trace.py @@ -0,0 +1,547 @@ +from __future__ import annotations + +import tempfile +from pathlib import Path + +from codeflash.benchmarking.instrument_codeflash_trace import add_codeflash_decorator_to_code, \ + instrument_codeflash_trace_decorator +from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize + + +def test_add_decorator_to_normal_function() -> None: + """Test adding decorator to a normal function.""" + code = """ +def normal_function(): + return "Hello, World!" 
+""" + + fto = FunctionToOptimize( + function_name="normal_function", + file_path=Path("dummy_path.py"), + parents=[] + ) + + modified_code = add_codeflash_decorator_to_code( + code=code, + functions_to_optimize=[fto] + ) + + expected_code = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace +@codeflash_trace +def normal_function(): + return "Hello, World!" +""" + + assert modified_code.strip() == expected_code.strip() + + +def test_add_decorator_to_normal_method() -> None: + """Test adding decorator to a normal method.""" + code = """ +class TestClass: + def normal_method(self): + return "Hello from method" +""" + + fto = FunctionToOptimize( + function_name="normal_method", + file_path=Path("dummy_path.py"), + parents=[FunctionParent(name="TestClass", type="ClassDef")] + ) + + modified_code = add_codeflash_decorator_to_code( + code=code, + functions_to_optimize=[fto] + ) + + expected_code = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace +class TestClass: + @codeflash_trace + def normal_method(self): + return "Hello from method" +""" + + assert modified_code.strip() == expected_code.strip() + + +def test_add_decorator_to_classmethod() -> None: + """Test adding decorator to a classmethod.""" + code = """ +class TestClass: + @classmethod + def class_method(cls): + return "Hello from classmethod" +""" + + fto = FunctionToOptimize( + function_name="class_method", + file_path=Path("dummy_path.py"), + parents=[FunctionParent(name="TestClass", type="ClassDef")] + ) + + modified_code = add_codeflash_decorator_to_code( + code=code, + functions_to_optimize=[fto] + ) + + expected_code = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace +class TestClass: + @classmethod + @codeflash_trace + def class_method(cls): + return "Hello from classmethod" +""" + + assert modified_code.strip() == expected_code.strip() + + +def test_add_decorator_to_staticmethod() -> None: + """Test adding decorator to a staticmethod.""" + code = """ +class TestClass: + @staticmethod + def static_method(): + return "Hello from staticmethod" +""" + + fto = FunctionToOptimize( + function_name="static_method", + file_path=Path("dummy_path.py"), + parents=[FunctionParent(name="TestClass", type="ClassDef")] + ) + + modified_code = add_codeflash_decorator_to_code( + code=code, + functions_to_optimize=[fto] + ) + + expected_code = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace +class TestClass: + @staticmethod + @codeflash_trace + def static_method(): + return "Hello from staticmethod" +""" + + assert modified_code.strip() == expected_code.strip() + + +def test_add_decorator_to_init_function() -> None: + """Test adding decorator to an __init__ function.""" + code = """ +class TestClass: + def __init__(self, value): + self.value = value +""" + + fto = FunctionToOptimize( + function_name="__init__", + file_path=Path("dummy_path.py"), + parents=[FunctionParent(name="TestClass", type="ClassDef")] + ) + + modified_code = add_codeflash_decorator_to_code( + code=code, + functions_to_optimize=[fto] + ) + + expected_code = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace +class TestClass: + @codeflash_trace + def __init__(self, value): + self.value = value +""" + + assert modified_code.strip() == expected_code.strip() + + +def test_add_decorator_with_multiple_decorators() -> None: + """Test adding decorator to a function with multiple existing decorators.""" + code = """ +class TestClass: + @property + @other_decorator + def 
property_method(self): + return self._value +""" + + fto = FunctionToOptimize( + function_name="property_method", + file_path=Path("dummy_path.py"), + parents=[FunctionParent(name="TestClass", type="ClassDef")] + ) + + modified_code = add_codeflash_decorator_to_code( + code=code, + functions_to_optimize=[fto] + ) + + expected_code = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace +class TestClass: + @property + @other_decorator + @codeflash_trace + def property_method(self): + return self._value +""" + + assert modified_code.strip() == expected_code.strip() + + +def test_add_decorator_to_function_in_multiple_classes() -> None: + """Test that only the right class's method gets the decorator.""" + code = """ +class TestClass: + def test_method(self): + return "This should get decorated" + +class OtherClass: + def test_method(self): + return "This should NOT get decorated" +""" + + fto = FunctionToOptimize( + function_name="test_method", + file_path=Path("dummy_path.py"), + parents=[FunctionParent(name="TestClass", type="ClassDef")] + ) + + modified_code = add_codeflash_decorator_to_code( + code=code, + functions_to_optimize=[fto] + ) + + expected_code = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace +class TestClass: + @codeflash_trace + def test_method(self): + return "This should get decorated" + +class OtherClass: + def test_method(self): + return "This should NOT get decorated" +""" + + assert modified_code.strip() == expected_code.strip() + + +def test_add_decorator_to_nonexistent_function() -> None: + """Test that code remains unchanged when function doesn't exist.""" + code = """ +def existing_function(): + return "This exists" +""" + + fto = FunctionToOptimize( + function_name="nonexistent_function", + file_path=Path("dummy_path.py"), + parents=[] + ) + + modified_code = add_codeflash_decorator_to_code( + code=code, + functions_to_optimize=[fto] + ) + + # Code should remain unchanged + assert modified_code.strip() == code.strip() + + +def test_add_decorator_to_multiple_functions() -> None: + """Test adding decorator to multiple functions.""" + code = """ +def function_one(): + return "First function" + +class TestClass: + def method_one(self): + return "First method" + + def method_two(self): + return "Second method" + +def function_two(): + return "Second function" +""" + + functions_to_optimize = [ + FunctionToOptimize( + function_name="function_one", + file_path=Path("dummy_path.py"), + parents=[] + ), + FunctionToOptimize( + function_name="method_two", + file_path=Path("dummy_path.py"), + parents=[FunctionParent(name="TestClass", type="ClassDef")] + ), + FunctionToOptimize( + function_name="function_two", + file_path=Path("dummy_path.py"), + parents=[] + ) + ] + + modified_code = add_codeflash_decorator_to_code( + code=code, + functions_to_optimize=functions_to_optimize + ) + + expected_code = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace +@codeflash_trace +def function_one(): + return "First function" + +class TestClass: + def method_one(self): + return "First method" + + @codeflash_trace + def method_two(self): + return "Second method" + +@codeflash_trace +def function_two(): + return "Second function" +""" + + assert modified_code.strip() == expected_code.strip() + + +def test_instrument_codeflash_trace_decorator_single_file() -> None: + """Test instrumenting codeflash trace decorator on a single file.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create a test Python file + test_file_path = 
Path(temp_dir) / "test_module.py" + test_file_content = """ +def function_one(): + return "First function" + +class TestClass: + def method_one(self): + return "First method" + + def method_two(self): + return "Second method" + +def function_two(): + return "Second function" +""" + test_file_path.write_text(test_file_content, encoding="utf-8") + + # Define functions to optimize + functions_to_optimize = [ + FunctionToOptimize( + function_name="function_one", + file_path=test_file_path, + parents=[] + ), + FunctionToOptimize( + function_name="method_two", + file_path=test_file_path, + parents=[FunctionParent(name="TestClass", type="ClassDef")] + ) + ] + + # Execute the function being tested + instrument_codeflash_trace_decorator({test_file_path: functions_to_optimize}) + + # Read the modified file + modified_content = test_file_path.read_text(encoding="utf-8") + + # Define expected content (with isort applied) + expected_content = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace + + +@codeflash_trace +def function_one(): + return "First function" + +class TestClass: + def method_one(self): + return "First method" + + @codeflash_trace + def method_two(self): + return "Second method" + +def function_two(): + return "Second function" +""" + + # Compare the modified content with expected content + assert modified_content.strip() == expected_content.strip() + + +def test_instrument_codeflash_trace_decorator_multiple_files() -> None: + """Test instrumenting codeflash trace decorator on multiple files.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create first test Python file + test_file_1_path = Path(temp_dir) / "module_a.py" + test_file_1_content = """ +def function_a(): + return "Function in module A" + +class ClassA: + def method_a(self): + return "Method in ClassA" +""" + test_file_1_path.write_text(test_file_1_content, encoding="utf-8") + + # Create second test Python file + test_file_2_path = Path(temp_dir) / "module_b.py" + test_file_2_content =""" +def function_b(): + return "Function in module B" + +class ClassB: + @staticmethod + def static_method_b(): + return "Static method in ClassB" +""" + test_file_2_path.write_text(test_file_2_content, encoding="utf-8") + + # Define functions to optimize + file_to_funcs_to_optimize = { + test_file_1_path: [ + FunctionToOptimize( + function_name="function_a", + file_path=test_file_1_path, + parents=[] + ) + ], + test_file_2_path: [ + FunctionToOptimize( + function_name="static_method_b", + file_path=test_file_2_path, + parents=[FunctionParent(name="ClassB", type="ClassDef")] + ) + ] + } + + # Execute the function being tested + instrument_codeflash_trace_decorator(file_to_funcs_to_optimize) + + # Read the modified files + modified_content_1 = test_file_1_path.read_text(encoding="utf-8") + modified_content_2 = test_file_2_path.read_text(encoding="utf-8") + + # Define expected content for first file (with isort applied) + expected_content_1 = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace + + +@codeflash_trace +def function_a(): + return "Function in module A" + +class ClassA: + def method_a(self): + return "Method in ClassA" +""" + + # Define expected content for second file (with isort applied) + expected_content_2 = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace + + +def function_b(): + return "Function in module B" + +class ClassB: + @staticmethod + @codeflash_trace + def static_method_b(): + return "Static method in ClassB" +""" + + # Compare the modified 
content with expected content + assert modified_content_1.strip() == expected_content_1.strip() + assert modified_content_2.strip() == expected_content_2.strip() + + +def test_add_decorator_to_method_after_nested_class() -> None: + """Test adding decorator to a method that appears after a nested class definition.""" + code = """ +class OuterClass: + class NestedClass: + def nested_method(self): + return "Hello from nested class method" + + def target_method(self): + return "Hello from target method after nested class" +""" + + fto = FunctionToOptimize( + function_name="target_method", + file_path=Path("dummy_path.py"), + parents=[FunctionParent(name="OuterClass", type="ClassDef")] + ) + + modified_code = add_codeflash_decorator_to_code( + code=code, + functions_to_optimize=[fto] + ) + + expected_code = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace +class OuterClass: + class NestedClass: + def nested_method(self): + return "Hello from nested class method" + + @codeflash_trace + def target_method(self): + return "Hello from target method after nested class" +""" + + assert modified_code.strip() == expected_code.strip() + + +def test_add_decorator_to_function_after_nested_function() -> None: + """Test adding decorator to a function that appears after a function with a nested function.""" + code = """ +def function_with_nested(): + def inner_function(): + return "Hello from inner function" + + return inner_function() + +def target_function(): + return "Hello from target function after nested function" +""" + + fto = FunctionToOptimize( + function_name="target_function", + file_path=Path("dummy_path.py"), + parents=[] + ) + + modified_code = add_codeflash_decorator_to_code( + code=code, + functions_to_optimize=[fto] + ) + + expected_code = """ +from codeflash.benchmarking.codeflash_trace import codeflash_trace +def function_with_nested(): + def inner_function(): + return "Hello from inner function" + + return inner_function() + +@codeflash_trace +def target_function(): + return "Hello from target function after nested function" +""" + + assert modified_code.strip() == expected_code.strip() \ No newline at end of file diff --git a/tests/test_pickle_patcher.py b/tests/test_pickle_patcher.py new file mode 100644 index 000000000..3d2f21b66 --- /dev/null +++ b/tests/test_pickle_patcher.py @@ -0,0 +1,513 @@ +import os +import pickle +import shutil +import socket +import sqlite3 +from argparse import Namespace +from pathlib import Path + +import dill +import pytest + +from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin +from codeflash.benchmarking.replay_test import generate_replay_test +from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest +from codeflash.benchmarking.utils import validate_and_format_benchmark_table +from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType +from codeflash.optimization.optimizer import Optimizer +from codeflash.verification.equivalence import compare_test_results + +try: + import sqlalchemy + from sqlalchemy import Column, Integer, String, create_engine + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.orm import Session + + HAS_SQLALCHEMY = True +except ImportError: + HAS_SQLALCHEMY = False + +from codeflash.picklepatch.pickle_patcher import PicklePatcher 
+from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder, PicklePlaceholderAccessError + + +def test_picklepatch_simple_nested(): + """Test that a simple nested data structure pickles and unpickles correctly. + """ + original_data = { + "numbers": [1, 2, 3], + "nested_dict": {"key": "value", "another": 42}, + } + + dumped = PicklePatcher.dumps(original_data) + reloaded = PicklePatcher.loads(dumped) + + assert reloaded == original_data + # Everything was pickleable, so no placeholders should appear. + + +def test_picklepatch_with_socket(): + """Test that a data structure containing a raw socket is replaced by + PicklePlaceholder rather than raising an error. + """ + # Create a pair of connected sockets instead of a single socket + sock1, sock2 = socket.socketpair() + + data_with_socket = { + "safe_value": 123, + "raw_socket": sock1, + } + + # Send a message through sock1, which can be received by sock2 + sock1.send(b"Hello, world!") + received = sock2.recv(1024) + assert received == b"Hello, world!" + # Pickle the data structure containing the socket + dumped = PicklePatcher.dumps(data_with_socket) + reloaded = PicklePatcher.loads(dumped) + + # We expect "raw_socket" to be replaced by a placeholder + assert isinstance(reloaded, dict) + assert reloaded["safe_value"] == 123 + assert isinstance(reloaded["raw_socket"], PicklePlaceholder) + + # Attempting to use or access attributes => AttributeError + # (not RuntimeError as in original tests, our implementation uses AttributeError) + with pytest.raises(PicklePlaceholderAccessError): + reloaded["raw_socket"].recv(1024) + + # Clean up by closing both sockets + sock1.close() + sock2.close() + + +def test_picklepatch_deeply_nested(): + """Test that deep nesting with unpicklable objects works correctly. + """ + # Create a deeply nested structure with an unpicklable object + deep_nested = { + "level1": { + "level2": { + "level3": { + "normal": "value", + "socket": socket.socket(socket.AF_INET, socket.SOCK_STREAM) + } + } + } + } + + dumped = PicklePatcher.dumps(deep_nested) + reloaded = PicklePatcher.loads(dumped) + + # We should be able to access the normal value + assert reloaded["level1"]["level2"]["level3"]["normal"] == "value" + + # The socket should be replaced with a placeholder + assert isinstance(reloaded["level1"]["level2"]["level3"]["socket"], PicklePlaceholder) + +def test_picklepatch_class_with_unpicklable_attr(): + """Test that a class with an unpicklable attribute works correctly. + """ + class TestClass: + def __init__(self): + self.normal = "normal value" + self.unpicklable = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + obj = TestClass() + + dumped = PicklePatcher.dumps(obj) + reloaded = PicklePatcher.loads(dumped) + + # Normal attribute should be preserved + assert reloaded.normal == "normal value" + + # Unpicklable attribute should be replaced with a placeholder + assert isinstance(reloaded.unpicklable, PicklePlaceholder) + + + + +def test_picklepatch_with_database_connection(): + """Test that a data structure containing a database connection is replaced + by PicklePlaceholder rather than raising an error. 
+ """ + # SQLite connection - not pickleable + conn = sqlite3.connect(":memory:") + cursor = conn.cursor() + + data_with_db = { + "description": "Database connection", + "connection": conn, + "cursor": cursor, + } + + dumped = PicklePatcher.dumps(data_with_db) + reloaded = PicklePatcher.loads(dumped) + + # Both connection and cursor should become placeholders + assert isinstance(reloaded, dict) + assert reloaded["description"] == "Database connection" + assert isinstance(reloaded["connection"], PicklePlaceholder) + assert isinstance(reloaded["cursor"], PicklePlaceholder) + + # Attempting to use attributes => AttributeError + with pytest.raises(PicklePlaceholderAccessError): + reloaded["connection"].execute("SELECT 1") + + +def test_picklepatch_with_generator(): + """Test that a data structure containing a generator is replaced by + PicklePlaceholder rather than raising an error. + """ + + def simple_generator(): + yield 1 + yield 2 + yield 3 + + # Create a generator + gen = simple_generator() + + # Put it in a data structure + data_with_generator = { + "description": "Contains a generator", + "generator": gen, + "normal_list": [1, 2, 3] + } + + dumped = PicklePatcher.dumps(data_with_generator) + reloaded = PicklePatcher.loads(dumped) + + # Generator should be replaced with a placeholder + assert isinstance(reloaded, dict) + assert reloaded["description"] == "Contains a generator" + assert reloaded["normal_list"] == [1, 2, 3] + assert isinstance(reloaded["generator"], PicklePlaceholder) + + # Attempting to use the generator => AttributeError + with pytest.raises(TypeError): + next(reloaded["generator"]) + + # Attempting to call methods on the generator => AttributeError + with pytest.raises(PicklePlaceholderAccessError): + reloaded["generator"].send(None) + + +def test_picklepatch_loads_standard_pickle(): + """Test that PicklePatcher.loads can correctly load data that was pickled + using the standard pickle module. + """ + # Create a simple data structure + original_data = { + "numbers": [1, 2, 3], + "nested_dict": {"key": "value", "another": 42}, + "tuple": (1, "two", 3.0), + } + + # Pickle it with standard pickle + pickled_data = pickle.dumps(original_data) + + # Load with PicklePatcher + reloaded = PicklePatcher.loads(pickled_data) + + # Verify the data is correctly loaded + assert reloaded == original_data + assert isinstance(reloaded, dict) + assert reloaded["numbers"] == [1, 2, 3] + assert reloaded["nested_dict"]["key"] == "value" + assert reloaded["tuple"] == (1, "two", 3.0) + + +def test_picklepatch_loads_dill_pickle(): + """Test that PicklePatcher.loads can correctly load data that was pickled + using the dill module, which can pickle more complex objects than the + standard pickle module. 
+    """
+    # Create a more complex data structure that includes a lambda function
+    # which dill can handle but standard pickle cannot
+    original_data = {
+        "numbers": [1, 2, 3],
+        "function": lambda x: x * 2,
+        "nested": {
+            "another_function": lambda y: y ** 2
+        }
+    }
+
+    # Pickle it with dill
+    dilled_data = dill.dumps(original_data)
+
+    # Load with PicklePatcher
+    reloaded = PicklePatcher.loads(dilled_data)
+
+    # Verify the data structure
+    assert isinstance(reloaded, dict)
+    assert reloaded["numbers"] == [1, 2, 3]
+
+    # Test that the functions actually work
+    assert reloaded["function"](5) == 10
+    assert reloaded["nested"]["another_function"](4) == 16
+
+def test_run_and_parse_picklepatch() -> None:
+    """Test the end-to-end functionality of picklepatch, from tracing benchmarks to running the replay tests.
+
+    The first example has an argument (an object containing a socket) that is not pickleable. However, the socket attribute is not used, so we are able to compare the test results with the optimized test results.
+    Here, we are simply 'ignoring' the unused unpickleable object.
+
+    The second example also has an argument (an object containing a socket) that is not pickleable. The socket attribute is used, which results in an error thrown by the PicklePlaceholder object.
+    Both the original and optimized results should error out in this case, but this should be flagged as incorrect behavior when comparing test results,
+    since we were not able to reuse the unpickleable object in the replay test.
+    """
+    # Init paths
+    project_root = Path(__file__).parent.parent.resolve()
+    tests_root = project_root / "code_to_optimize" / "tests" / "pytest"
+    benchmarks_root = project_root / "code_to_optimize" / "tests" / "pytest" / "benchmarks_socket_test"
+    replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
+    output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
+    fto_unused_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py").resolve()
+    fto_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").resolve()
+    original_fto_unused_socket_code = fto_unused_socket_path.read_text("utf-8")
+    original_fto_used_socket_code = fto_used_socket_path.read_text("utf-8")
+    # Trace benchmarks
+    trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
+    assert output_file.exists()
+    try:
+        # Check contents
+        conn = sqlite3.connect(output_file.as_posix())
+        cursor = conn.cursor()
+
+        cursor.execute(
+            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
+        function_calls = cursor.fetchall()
+
+        # Assert the length of function calls
+        assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
+        function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
+        total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
+        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
+        assert "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" in function_to_results
+
+        test_name, total_time, function_time, percent = 
function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0] + assert total_time > 0.0 + assert function_time > 0.0 + assert percent > 0.0 + + test_name, total_time, function_time, percent = \ + function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0] + assert total_time > 0.0 + assert function_time > 0.0 + assert percent > 0.0 + + bubble_sort_unused_socket_path = (project_root / "code_to_optimize"/ "bubble_sort_picklepatch_test_unused_socket.py").as_posix() + bubble_sort_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").as_posix() + # Expected function calls + expected_calls = [ + ("bubble_sort_with_unused_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_unused_socket", + f"{bubble_sort_unused_socket_path}", + "test_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 12), + ("bubble_sort_with_used_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_used_socket", + f"{bubble_sort_used_socket_path}", + "test_used_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 20) + ] + for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)): + assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name" + assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name" + assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name" + assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path" + assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name" + assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path" + assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number" + conn.close() + + # Generate replay test + generate_replay_test(output_file, replay_tests_dir) + replay_test_path = replay_tests_dir / Path( + "test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0.py") + replay_test_perf_path = replay_tests_dir / Path( + "test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0_perf.py") + assert replay_test_path.exists() + original_replay_test_code = replay_test_path.read_text("utf-8") + + # Instrument the replay test + func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_unused_socket_path)) + original_cwd = Path.cwd() + run_cwd = project_root + os.chdir(run_cwd) + success, new_test = inject_profiling_into_existing_test( + replay_test_path, + [CodePosition(17, 15)], + func, + project_root, + "pytest", + mode=TestingMode.BEHAVIOR, + ) + os.chdir(original_cwd) + assert success + assert new_test is not None + replay_test_path.write_text(new_test) + + opt = Optimizer( + Namespace( + project_root=project_root, + disable_telemetry=True, + tests_root=tests_root, + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=project_root, + ) + ) + + # Run the replay test for the original code that does not use the socket + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_type = TestType.REPLAY_TEST + replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket" + func_optimizer = 
opt.create_function_optimizer(func) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=replay_test_path, + test_type=test_type, + original_file_path=replay_test_path, + benchmarking_file_path=replay_test_perf_path, + tests_in_file=[TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function, test_type=test_type)], + ) + ] + ) + test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=1.0, + ) + assert len(test_results_unused_socket) == 1 + assert test_results_unused_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0" + assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket" + assert test_results_unused_socket.test_results[0].did_pass == True + + # Replace with optimized candidate + fto_unused_socket_path.write_text(""" +from codeflash.benchmarking.codeflash_trace import codeflash_trace + +@codeflash_trace +def bubble_sort_with_unused_socket(data_container): + # Extract the list to sort, leaving the socket untouched + numbers = data_container.get('numbers', []).copy() + return sorted(numbers) +""") + # Run optimized code for unused socket + optimized_test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=1.0, + ) + assert len(optimized_test_results_unused_socket) == 1 + verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket) + assert verification_result is True + + # Remove the previous instrumentation + replay_test_path.write_text(original_replay_test_code) + # Instrument the replay test + func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path)) + success, new_test = inject_profiling_into_existing_test( + replay_test_path, + [CodePosition(23,15)], + func, + project_root, + "pytest", + mode=TestingMode.BEHAVIOR, + ) + os.chdir(original_cwd) + assert success + assert new_test is not None + replay_test_path.write_text(new_test) + + # Run test for original function code that uses the socket. This should fail, as the PicklePlaceholder is accessed. 
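+        # The traced 'socket' argument is restored by the replay test as a PicklePlaceholder, so calling
+        # socket.send() inside bubble_sort_with_used_socket raises PicklePlaceholderAccessError and the test fails.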
+        test_env = os.environ.copy()
+        test_env["CODEFLASH_TEST_ITERATION"] = "0"
+        test_env["CODEFLASH_LOOP_INDEX"] = "1"
+        test_type = TestType.REPLAY_TEST
+        func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[],
+                                  file_path=Path(fto_used_socket_path))
+        replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
+        func_optimizer = opt.create_function_optimizer(func)
+        func_optimizer.test_files = TestFiles(
+            test_files=[
+                TestFile(
+                    instrumented_behavior_file_path=replay_test_path,
+                    test_type=test_type,
+                    original_file_path=replay_test_path,
+                    benchmarking_file_path=replay_test_perf_path,
+                    tests_in_file=[
+                        TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function,
+                                    test_type=test_type)],
+                )
+            ]
+        )
+        test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests(
+            testing_type=TestingMode.BEHAVIOR,
+            test_env=test_env,
+            test_files=func_optimizer.test_files,
+            optimization_iteration=0,
+            pytest_min_loops=1,
+            pytest_max_loops=1,
+            testing_time=1.0,
+        )
+        assert len(test_results_used_socket) == 1
+        assert test_results_used_socket.test_results[
+                   0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
+        assert test_results_used_socket.test_results[
+                   0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
+        assert test_results_used_socket.test_results[0].did_pass is False
+        print("test results used socket")
+        print(test_results_used_socket)
+        # Replace with optimized candidate
+        fto_used_socket_path.write_text("""
+from codeflash.benchmarking.codeflash_trace import codeflash_trace
+
+@codeflash_trace
+def bubble_sort_with_used_socket(data_container):
+    # Extract the list to sort, then send a message through the socket
+    numbers = data_container.get('numbers', []).copy()
+    socket = data_container.get('socket')
+    socket.send("Hello from the optimized function!")
+    return sorted(numbers)
+    """)
+
+        # Run test for optimized function code that uses the socket. This should fail, as the PicklePlaceholder is accessed.
+        optimized_test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests(
+            testing_type=TestingMode.BEHAVIOR,
+            test_env=test_env,
+            test_files=func_optimizer.test_files,
+            optimization_iteration=0,
+            pytest_min_loops=1,
+            pytest_max_loops=1,
+            testing_time=1.0,
+        )
+        assert len(optimized_test_results_used_socket) == 1
+        assert optimized_test_results_used_socket.test_results[
+                   0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
+        assert optimized_test_results_used_socket.test_results[
+                   0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
+        assert optimized_test_results_used_socket.test_results[0].did_pass is False
+
+        # Even though both tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined.
+ assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False + + finally: + # cleanup + output_file.unlink(missing_ok=True) + shutil.rmtree(replay_tests_dir, ignore_errors=True) + fto_unused_socket_path.write_text(original_fto_unused_socket_code) + fto_used_socket_path.write_text(original_fto_used_socket_code) + diff --git a/tests/test_trace_benchmarks.py b/tests/test_trace_benchmarks.py new file mode 100644 index 000000000..72c0267a8 --- /dev/null +++ b/tests/test_trace_benchmarks.py @@ -0,0 +1,288 @@ +import multiprocessing +import shutil +import sqlite3 +from pathlib import Path + +import pytest + +from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin +from codeflash.benchmarking.replay_test import generate_replay_test +from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest +from codeflash.benchmarking.utils import validate_and_format_benchmark_table + + +def test_trace_benchmarks() -> None: + # Test the trace_benchmarks function + project_root = Path(__file__).parent.parent / "code_to_optimize" + benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test" + replay_tests_dir = benchmarks_root / "codeflash_replay_tests" + tests_root = project_root / "tests" + output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve() + trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file) + assert output_file.exists() + try: + # check contents of trace file + # connect to database + conn = sqlite3.connect(output_file.as_posix()) + cursor = conn.cursor() + + # Get the count of records + # Get all records + cursor.execute( + "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name") + function_calls = cursor.fetchall() + + # Assert the length of function calls + assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}" + + bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix() + process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix() + # Expected function calls + expected_calls = [ + ("sorter", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace", + f"{bubble_sort_path}", + "test_class_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 17), + + ("sort_class", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace", + f"{bubble_sort_path}", + "test_class_sort2", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 20), + + ("sort_static", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace", + f"{bubble_sort_path}", + "test_class_sort3", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 23), + + ("__init__", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace", + f"{bubble_sort_path}", + "test_class_sort4", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 26), + + ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace", + f"{bubble_sort_path}", + "test_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 7), + + ("compute_and_sort", "", "code_to_optimize.process_and_bubble_sort_codeflash_trace", + f"{process_and_bubble_sort_path}", + "test_compute_and_sort", "tests.pytest.benchmarks_test.test_process_and_sort_example", 4), + + ("sorter", "", 
"code_to_optimize.bubble_sort_codeflash_trace", + f"{bubble_sort_path}", + "test_no_func", "tests.pytest.benchmarks_test.test_process_and_sort_example", 8), + + ("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace", + f"{bubble_sort_path}", + "test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5), + ] + for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)): + assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name" + assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name" + assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name" + assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path" + assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name" + assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path" + assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number" + # Close connection + conn.close() + generate_replay_test(output_file, replay_tests_dir) + test_class_sort_path = replay_tests_dir/ Path("test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py") + assert test_class_sort_path.exists() + test_class_sort_code = f""" +from code_to_optimize.bubble_sort_codeflash_trace import \\ + Sorter as code_to_optimize_bubble_sort_codeflash_trace_Sorter +from code_to_optimize.bubble_sort_codeflash_trace import \\ + sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter +from codeflash.benchmarking.replay_test import get_next_arg_and_return +from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle + +functions = ['sort_class', 'sort_static', 'sorter'] +trace_file_path = r"{output_file.as_posix()}" + +def test_code_to_optimize_bubble_sort_codeflash_trace_sorter(): + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_sort", function_name="sorter", file_path=r"{bubble_sort_path}", num_to_get=100): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl) + ret = code_to_optimize_bubble_sort_codeflash_trace_sorter(*args, **kwargs) + +def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sorter(): + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort", function_name="sorter", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl) + function_name = "sorter" + if not args: + raise ValueError("No arguments provided for the method.") + if function_name == "__init__": + ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter(*args[1:], **kwargs) + else: + instance = args[0] # self + ret = instance.sorter(*args[1:], **kwargs) + +def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_class(): + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort2", function_name="sort_class", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl) + if not args: + raise ValueError("No arguments provided for the method.") + ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sort_class(*args[1:], **kwargs) + +def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_static(): + for args_pkl, kwargs_pkl in 
get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort3", function_name="sort_static", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl) + ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sort_static(*args, **kwargs) + +def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init__(): + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort4", function_name="__init__", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl) + function_name = "__init__" + if not args: + raise ValueError("No arguments provided for the method.") + if function_name == "__init__": + ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter(*args[1:], **kwargs) + else: + instance = args[0] # self + ret = instance(*args[1:], **kwargs) + +""" + assert test_class_sort_path.read_text("utf-8").strip()==test_class_sort_code.strip() + + test_sort_path = replay_tests_dir / Path("test_tests_pytest_benchmarks_test_test_process_and_sort_example__replay_test_0.py") + assert test_sort_path.exists() + test_sort_code = f""" +from code_to_optimize.bubble_sort_codeflash_trace import \\ + sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter +from code_to_optimize.process_and_bubble_sort_codeflash_trace import \\ + compute_and_sort as \\ + code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort +from codeflash.benchmarking.replay_test import get_next_arg_and_return +from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle + +functions = ['compute_and_sort', 'sorter'] +trace_file_path = r"{output_file}" + +def test_code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort(): + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_compute_and_sort", function_name="compute_and_sort", file_path=r"{process_and_bubble_sort_path}", num_to_get=100): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl) + ret = code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort(*args, **kwargs) + +def test_code_to_optimize_bubble_sort_codeflash_trace_sorter(): + for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_no_func", function_name="sorter", file_path=r"{bubble_sort_path}", num_to_get=100): + args = pickle.loads(args_pkl) + kwargs = pickle.loads(kwargs_pkl) + ret = code_to_optimize_bubble_sort_codeflash_trace_sorter(*args, **kwargs) + +""" + assert test_sort_path.read_text("utf-8").strip()==test_sort_code.strip() + finally: + # cleanup + output_file.unlink(missing_ok=True) + shutil.rmtree(replay_tests_dir) + +# Skip the test in CI as the machine may not be multithreaded +@pytest.mark.ci_skip +def test_trace_multithreaded_benchmark() -> None: + project_root = Path(__file__).parent.parent / "code_to_optimize" + benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_multithread" + tests_root = project_root / "tests" + output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve() + trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file) + assert output_file.exists() + try: + # check contents of trace file + # connect to database + conn = sqlite3.connect(output_file.as_posix()) + cursor = conn.cursor() + + # Get 
the count of records + # Get all records + cursor.execute( + "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name") + function_calls = cursor.fetchall() + + # Assert the length of function calls + assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}" + function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file) + total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file) + function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings) + assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results + + test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0] + assert total_time > 0.0 + assert function_time > 0.0 + assert percent > 0.0 + + bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix() + # Expected function calls + expected_calls = [ + ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace", + f"{bubble_sort_path}", + "test_benchmark_sort", "tests.pytest.benchmarks_multithread.test_multithread_sort", 4), + ] + for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)): + assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name" + assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name" + assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name" + assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path" + assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name" + assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path" + assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number" + # Close connection + conn.close() + + finally: + # cleanup + output_file.unlink(missing_ok=True) + +def test_trace_benchmark_decorator() -> None: + project_root = Path(__file__).parent.parent / "code_to_optimize" + benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test_decorator" + tests_root = project_root / "tests" + output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve() + trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file) + assert output_file.exists() + try: + # check contents of trace file + # connect to database + conn = sqlite3.connect(output_file.as_posix()) + cursor = conn.cursor() + + # Get the count of records + # Get all records + cursor.execute( + "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name") + function_calls = cursor.fetchall() + + # Assert the length of function calls + assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}" + function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file) + total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file) + function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings) + assert 
"code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results + + test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0] + assert total_time > 0.0 + assert function_time > 0.0 + assert percent > 0.0 + + bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix() + # Expected function calls + expected_calls = [ + ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace", + f"{bubble_sort_path}", + "test_benchmark_sort", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 5), + ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace", + f"{bubble_sort_path}", + "test_pytest_mark", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 11), + ] + for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)): + assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name" + assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name" + assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name" + assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path" + assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name" + assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path" + # Close connection + conn.close() + + finally: + # cleanup + output_file.unlink(missing_ok=True) diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index 73556928e..8c3bc35c8 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -17,7 +17,19 @@ def test_unit_test_discovery_pytest(): ) tests = discover_unit_tests(test_config) assert len(tests) > 0 - # print(tests) + + +def test_benchmark_test_discovery_pytest(): + project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize" + tests_path = project_path / "tests" / "pytest" / "benchmarks" + test_config = TestConfig( + tests_root=tests_path, + project_root_path=project_path, + test_framework="pytest", + tests_project_rootdir=tests_path.parent, + ) + tests = discover_unit_tests(test_config) + assert len(tests) == 1 # Should not discover benchmark tests def test_unit_test_discovery_unittest():