diff --git a/.github/workflows/e2e-async.yaml b/.github/workflows/e2e-async.yaml new file mode 100644 index 000000000..e7d08091c --- /dev/null +++ b/.github/workflows/e2e-async.yaml @@ -0,0 +1,69 @@ +name: E2E - Async + +on: + pull_request: + paths: + - '**' # Trigger for all paths + + workflow_dispatch: + +jobs: + async-optimization: + # Dynamically determine if environment is needed only when workflow files change and contributor is external + environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }} + + runs-on: ubuntu-latest + env: + CODEFLASH_AIS_SERVER: prod + POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }} + CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }} + COLUMNS: 110 + MAX_RETRIES: 3 + RETRY_DELAY: 5 + EXPECTED_IMPROVEMENT_PCT: 10 + CODEFLASH_END_TO_END: 1 + steps: + - name: 🛎️ Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.ref }} + repository: ${{ github.event.pull_request.head.repo.full_name }} + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Validate PR + run: | + # Check for any workflow changes + if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then + echo "⚠️ Workflow changes detected." + + # Get the PR author + AUTHOR="${{ github.event.pull_request.user.login }}" + echo "PR Author: $AUTHOR" + + # Allowlist check + if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then + echo "✅ Authorized user ($AUTHOR). Proceeding." + elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then + echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding." + else + echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting." + exit 1 + fi + else + echo "✅ No workflow file changes detected. Proceeding." + fi + + - name: Set up Python 3.11 for CLI + uses: astral-sh/setup-uv@v5 + with: + python-version: 3.11.6 + + - name: Install dependencies (CLI) + run: | + uv sync + + - name: Run Codeflash to optimize async code + id: optimize_async_code + run: | + uv run python tests/scripts/end_to_end_test_async.py \ No newline at end of file diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml deleted file mode 100644 index bc0a20ae8..000000000 --- a/.github/workflows/pre-commit.yaml +++ /dev/null @@ -1,19 +0,0 @@ -name: Lint -on: - pull_request: - push: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - lint: - name: Run pre-commit hooks - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - - uses: pre-commit/action@v3.0.1 \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9a9af7cd8..8d08a77f3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,8 +1,9 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.7 + rev: v0.13.1 hooks: # Run the linter. - id: ruff-check + args: [ --config=pyproject.toml ] # Run the formatter. 
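The `environment:` expression in the new e2e-async workflow above is dense; read as plain logic it behaves roughly like the sketch below (the function and variable names here are illustrative, not part of the workflow itself):

def needs_protected_environment(event_name: str, changed_files: list[str], author: str) -> str:
    # Require the 'external-trusted-contributors' environment when the run was
    # manually dispatched, or when workflow files changed and the author is not
    # on the allowlist; otherwise run with no environment gate.
    allowlist = {"misrasaurabh1", "KRRT7"}
    touches_workflows = any(f.startswith(".github/workflows/") for f in changed_files)
    if event_name == "workflow_dispatch" or (touches_workflows and author not in allowlist):
        return "external-trusted-contributors"
    return ""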
- id: ruff-format \ No newline at end of file diff --git a/code_to_optimize/code_directories/async_e2e/main.py b/code_to_optimize/code_directories/async_e2e/main.py index 4470cc969..317068a1c 100644 --- a/code_to_optimize/code_directories/async_e2e/main.py +++ b/code_to_optimize/code_directories/async_e2e/main.py @@ -1,4 +1,16 @@ import time -async def fake_api_call(delay, data): - time.sleep(0.0001) - return f"Processed: {data}" \ No newline at end of file +import asyncio + + +async def retry_with_backoff(func, max_retries=3): + if max_retries < 1: + raise ValueError("max_retries must be at least 1") + last_exception = None + for attempt in range(max_retries): + try: + return await func() + except Exception as e: + last_exception = e + if attempt < max_retries - 1: + time.sleep(0.0001 * attempt) + raise last_exception diff --git a/codeflash.code-workspace b/codeflash.code-workspace index 5d86915cc..a07674eb7 100644 --- a/codeflash.code-workspace +++ b/codeflash.code-workspace @@ -70,7 +70,11 @@ "request": "launch", "program": "${workspaceFolder:codeflash}/codeflash/main.py", "args": [ - "--file", "src/async_examples/shocker.py", "--verbose" + "--file", + "src/async_examples/concurrency.py", + "--function", + "task", + "--verbose" ], "cwd": "${input:chooseCwd}", "console": "integratedTerminal", diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 3e24d5bac..79f4d5300 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -298,6 +298,9 @@ def get_new_explanation( # noqa: D417 annotated_tests: str, optimization_id: str, original_explanation: str, + original_throughput: str | None = None, + optimized_throughput: str | None = None, + throughput_improvement: str | None = None, ) -> str: """Optimize the given python code for performance by making a request to the Django endpoint. 
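The retry_with_backoff coroutine added above in code_to_optimize/code_directories/async_e2e/main.py is the new async e2e optimization target; a quick way to exercise it looks like the following (the flaky coroutine is made up for the example, and the import assumes the async_e2e directory is the working directory):

import asyncio

from main import retry_with_backoff  # module shown above; assumes async_e2e/ is the CWD


async def _demo() -> None:
    attempts = 0

    async def flaky() -> str:
        nonlocal attempts
        attempts += 1
        if attempts < 3:
            raise RuntimeError("transient failure")
        return "ok"

    # Fails twice, then succeeds on the third and final allowed attempt.
    assert await retry_with_backoff(flaky, max_retries=3) == "ok"
    assert attempts == 3


asyncio.run(_demo())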
@@ -314,6 +317,9 @@ def get_new_explanation( # noqa: D417 - annotated_tests: str - test functions annotated with runtime - optimization_id: str - unique id of opt candidate - original_explanation: str - original_explanation generated for the opt candidate + - original_throughput: str | None - throughput for the baseline code (operations per second) + - optimized_throughput: str | None - throughput for the optimized code (operations per second) + - throughput_improvement: str | None - throughput improvement percentage Returns ------- @@ -333,6 +339,9 @@ def get_new_explanation( # noqa: D417 "optimization_id": optimization_id, "original_explanation": original_explanation, "dependency_code": dependency_code, + "original_throughput": original_throughput, + "optimized_throughput": optimized_throughput, + "throughput_improvement": throughput_improvement, } logger.info("Generating explanation") console.rule() diff --git a/codeflash/code_utils/codeflash_wrap_decorator.py b/codeflash/code_utils/codeflash_wrap_decorator.py index cb4da64a0..5dda746de 100644 --- a/codeflash/code_utils/codeflash_wrap_decorator.py +++ b/codeflash/code_utils/codeflash_wrap_decorator.py @@ -1,23 +1,17 @@ from __future__ import annotations import asyncio -import contextlib import gc -import inspect import os import sqlite3 -import time from enum import Enum from functools import wraps from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Callable, TypeVar +from typing import Any, Callable, TypeVar import dill as pickle -if TYPE_CHECKING: - from types import FrameType - class VerificationType(str, Enum): # moved from codeflash/verification/codeflash_capture.py FUNCTION_CALL = ( @@ -36,175 +30,17 @@ def get_run_tmp_file(file_path: Path) -> Path: # moved from codeflash/code_util return Path(get_run_tmp_file.tmpdir.name) / file_path -def _extract_class_name_tracer(frame_locals: dict[str, Any]) -> str | None: - try: - self_arg = frame_locals.get("self") - if self_arg is not None: - try: - return self_arg.__class__.__name__ - except (AttributeError, Exception): - cls_arg = frame_locals.get("cls") - if cls_arg is not None: - with contextlib.suppress(AttributeError, Exception): - return cls_arg.__name__ - else: - cls_arg = frame_locals.get("cls") - if cls_arg is not None: - with contextlib.suppress(AttributeError, Exception): - return cls_arg.__name__ - except Exception: - return None - return None - - -def _get_module_name_cf_tracer(frame: FrameType | None) -> str: - try: - test_module = inspect.getmodule(frame) - except Exception: - test_module = None - - if test_module is not None: - module_name = getattr(test_module, "__name__", None) - if module_name is not None: - return module_name - - if frame is not None: - return frame.f_globals.get("__name__", "unknown_module") - return "unknown_module" - - -def extract_test_context_from_frame() -> tuple[str, str | None, str]: - frame = inspect.currentframe() - # optimize? 
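The frame-walking helpers being deleted here are replaced later in this diff by extract_test_context_from_env, which simply reads context exported by the pytest plugin. A minimal sketch of the new contract, assuming the environment variables have already been set:

import os

from codeflash.code_utils.codeflash_wrap_decorator import extract_test_context_from_env

# Normally the plugin's runtest_setup hook (added further down) sets these.
os.environ["CODEFLASH_TEST_MODULE"] = "tests.test_example"
os.environ["CODEFLASH_TEST_CLASS"] = ""  # empty string is treated as "no class"
os.environ["CODEFLASH_TEST_FUNCTION"] = "test_example"

module, klass, function = extract_test_context_from_env()
assert (module, klass, function) == ("tests.test_example", None, "test_example")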
- try: - frames_info = [] - potential_tests = [] - - # First pass: collect all frame information - if frame is not None: - frame = frame.f_back - - while frame is not None: - try: - function_name = frame.f_code.co_name - filename = frame.f_code.co_filename - filename_path = Path(filename) - frame_locals = frame.f_locals - test_module_name = _get_module_name_cf_tracer(frame) - class_name = _extract_class_name_tracer(frame_locals) - - frames_info.append( - { - "function_name": function_name, - "filename_path": filename_path, - "frame_locals": frame_locals, - "test_module_name": test_module_name, - "class_name": class_name, - "frame": frame, - } - ) - - except Exception: # noqa: S112 - continue - - frame = frame.f_back - - # Second pass: analyze frames with full context - test_class_candidates = [] - for frame_info in frames_info: - function_name = frame_info["function_name"] - filename_path = frame_info["filename_path"] - frame_locals = frame_info["frame_locals"] - test_module_name = frame_info["test_module_name"] - class_name = frame_info["class_name"] - frame_obj = frame_info["frame"] - - # Keep track of test classes - if class_name and ( - class_name.startswith("Test") or class_name.endswith("Test") or "test" in class_name.lower() - ): - test_class_candidates.append((class_name, test_module_name)) - - # Now process frames again looking for test functions with full candidates list - # Collect all test functions to prioritize outer ones over nested ones - test_functions = [] - for frame_info in frames_info: - function_name = frame_info["function_name"] - filename_path = frame_info["filename_path"] - frame_locals = frame_info["frame_locals"] - test_module_name = frame_info["test_module_name"] - class_name = frame_info["class_name"] - frame_obj = frame_info["frame"] - - # Collect test functions - if function_name.startswith("test_"): - test_class_name = class_name - - # If no class found in current frame, check if we have any test class candidates - # Prefer the innermost (first) test class candidate which is more specific - if test_class_name is None and test_class_candidates: - test_class_name = test_class_candidates[0][0] - - test_functions.append((test_module_name, test_class_name, function_name)) - - # Prioritize test functions with class context, then innermost - if test_functions: - # First prefer test functions with class context - for test_func in test_functions: - if test_func[1] is not None: # has class_name - return test_func - # If no test function has class context, return the outermost (most likely the actual test method) - return test_functions[-1] - - # If no direct test functions found, look for other test patterns - for frame_info in frames_info: - function_name = frame_info["function_name"] - filename_path = frame_info["filename_path"] - frame_locals = frame_info["frame_locals"] - test_module_name = frame_info["test_module_name"] - class_name = frame_info["class_name"] - frame_obj = frame_info["frame"] - - # Test file/module detection - if ( - frame_obj.f_globals.get("__name__", "").startswith("test_") - or filename_path.stem.startswith("test_") - or "test" in filename_path.parts - ): - if class_name and ( - class_name.startswith("Test") or class_name.endswith("Test") or "test" in class_name.lower() - ): - potential_tests.append((test_module_name, class_name, function_name)) - elif "test" in test_module_name or filename_path.stem.startswith("test_"): - # For functions without class context, try to find the most recent test class - best_class = test_class_candidates[0][0] 
if test_class_candidates else None - potential_tests.append((test_module_name, best_class, function_name)) - - # Framework integration detection - if ( - ( - function_name in ["runTest", "_runTest", "run", "_testMethodName"] - or "pytest" in str(frame_obj.f_globals.get("__file__", "")) - or "unittest" in str(frame_obj.f_globals.get("__file__", "")) - ) - and class_name - and (class_name.startswith("Test") or "test" in class_name.lower()) - ): - test_method = function_name - if "self" in frame_locals: - with contextlib.suppress(AttributeError, TypeError): - test_method = getattr(frame_locals["self"], "_testMethodName", function_name) - potential_tests.append((test_module_name, class_name, test_method)) - - if potential_tests: - for test_module, test_class, test_func in potential_tests: - if test_func.startswith("test_"): - return test_module, test_class, test_func - return potential_tests[0] - - raise RuntimeError("No test function found in call stack") - finally: - del frame +def extract_test_context_from_env() -> tuple[str, str | None, str]: + test_module = os.environ["CODEFLASH_TEST_MODULE"] + test_class = os.environ.get("CODEFLASH_TEST_CLASS", None) + test_function = os.environ["CODEFLASH_TEST_FUNCTION"] + + if test_module and test_function: + return (test_module, test_class if test_class else None, test_function) + + raise RuntimeError( + "Test context environment variables not set - ensure tests are run through codeflash test runner" + ) def codeflash_behavior_async(func: F) -> F: @@ -212,9 +48,9 @@ def codeflash_behavior_async(func: F) -> F: async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 loop = asyncio.get_running_loop() function_name = func.__name__ - line_id = f"{func.__name__}_{func.__code__.co_firstlineno}" + line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"] loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) - test_module_name, test_class_name, test_name = extract_test_context_from_frame() + test_module_name, test_class_name, test_name = extract_test_context_from_env() test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}" @@ -288,10 +124,10 @@ def codeflash_performance_async(func: F) -> F: async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 loop = asyncio.get_running_loop() function_name = func.__name__ - line_id = f"{func.__name__}_{func.__code__.co_firstlineno}" + line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"] loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) - test_module_name, test_class_name, test_name = extract_test_context_from_frame() + test_module_name, test_class_name, test_name = extract_test_context_from_env() test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}" diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index 50b4bce16..a8fc74733 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -11,3 +11,4 @@ MIN_TESTCASE_PASSED_THRESHOLD = 6 REPEAT_OPTIMIZATION_PROBABILITY = 0.1 DEFAULT_IMPORTANCE_THRESHOLD = 0.001 +MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index be75eac85..8eb671540 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -291,6 +291,139 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = return 
node +class AsyncCallInstrumenter(ast.NodeTransformer): + def __init__( + self, + function: FunctionToOptimize, + module_path: str, + test_framework: str, + call_positions: list[CodePosition], + mode: TestingMode = TestingMode.BEHAVIOR, + ) -> None: + self.mode = mode + self.function_object = function + self.class_name = None + self.only_function_name = function.function_name + self.module_path = module_path + self.test_framework = test_framework + self.call_positions = call_positions + self.did_instrument = False + # Track function call count per test function + self.async_call_counter: dict[str, int] = {} + if len(function.parents) == 1 and function.parents[0].type == "ClassDef": + self.class_name = function.top_level_parent_name + + def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: + # Add timeout decorator for unittest test classes if needed + if self.test_framework == "unittest": + timeout_decorator = ast.Call( + func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()), + args=[ast.Constant(value=15)], + keywords=[], + ) + for item in node.body: + if ( + isinstance(item, ast.FunctionDef) + and item.name.startswith("test_") + and not any( + isinstance(d, ast.Call) + and isinstance(d.func, ast.Name) + and d.func.id == "timeout_decorator.timeout" + for d in item.decorator_list + ) + ): + item.decorator_list.append(timeout_decorator) + return self.generic_visit(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef: + if not node.name.startswith("test_"): + return node + + return self._process_test_function(node) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: + # Only process test functions + if not node.name.startswith("test_"): + return node + + return self._process_test_function(node) + + def _process_test_function( + self, node: ast.AsyncFunctionDef | ast.FunctionDef + ) -> ast.AsyncFunctionDef | ast.FunctionDef: + if self.test_framework == "unittest" and not any( + isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "timeout_decorator.timeout" + for d in node.decorator_list + ): + timeout_decorator = ast.Call( + func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()), + args=[ast.Constant(value=15)], + keywords=[], + ) + node.decorator_list.append(timeout_decorator) + + # Initialize counter for this test function + if node.name not in self.async_call_counter: + self.async_call_counter[node.name] = 0 + + new_body = [] + + for _i, stmt in enumerate(node.body): + transformed_stmt, added_env_assignment = self._instrument_statement(stmt, node.name) + + if added_env_assignment: + current_call_index = self.async_call_counter[node.name] + self.async_call_counter[node.name] += 1 + + env_assignment = ast.Assign( + targets=[ + ast.Subscript( + value=ast.Attribute( + value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load() + ), + slice=ast.Constant(value="CODEFLASH_CURRENT_LINE_ID"), + ctx=ast.Store(), + ) + ], + value=ast.Constant(value=f"{current_call_index}"), + lineno=stmt.lineno if hasattr(stmt, "lineno") else 1, + ) + new_body.append(env_assignment) + self.did_instrument = True + + new_body.append(transformed_stmt) + + node.body = new_body + return node + + def _instrument_statement(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]: + for node in ast.walk(stmt): + if ( + isinstance(node, ast.Await) + and isinstance(node.value, ast.Call) + and self._is_target_call(node.value) + and self._call_in_positions(node.value) + ): + # Check if this call is in 
one of our target positions + return stmt, True # Return original statement but signal we added env var + + return stmt, False + + def _is_target_call(self, call_node: ast.Call) -> bool: + """Check if this call node is calling our target async function.""" + if isinstance(call_node.func, ast.Name): + return call_node.func.id == self.function_object.function_name + if isinstance(call_node.func, ast.Attribute): + return call_node.func.attr == self.function_object.function_name + return False + + def _call_in_positions(self, call_node: ast.Call) -> bool: + if not hasattr(call_node, "lineno") or not hasattr(call_node, "col_offset"): + return False + + return node_in_call_position(call_node, self.call_positions) + + class FunctionImportedAsVisitor(ast.NodeVisitor): """Checks if a function has been imported as an alias. We only care about the alias then. @@ -352,6 +485,44 @@ def instrument_source_module_with_async_decorators( return False, None +def inject_async_profiling_into_existing_test( + test_path: Path, + call_positions: list[CodePosition], + function_to_optimize: FunctionToOptimize, + tests_project_root: Path, + test_framework: str, + mode: TestingMode = TestingMode.BEHAVIOR, +) -> tuple[bool, str | None]: + """Inject profiling for async function calls by setting environment variables before each call.""" + with test_path.open(encoding="utf8") as f: + test_code = f.read() + + try: + tree = ast.parse(test_code) + except SyntaxError: + logger.exception(f"Syntax error in code in file - {test_path}") + return False, None + + test_module_path = module_name_from_file_path(test_path, tests_project_root) + import_visitor = FunctionImportedAsVisitor(function_to_optimize) + import_visitor.visit(tree) + func = import_visitor.imported_as + + async_instrumenter = AsyncCallInstrumenter(func, test_module_path, test_framework, call_positions, mode=mode) + tree = async_instrumenter.visit(tree) + + if not async_instrumenter.did_instrument: + return False, None + + # Add necessary imports + new_imports = [ast.Import(names=[ast.alias(name="os")])] + if test_framework == "unittest": + new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")])) + + tree.body = [*new_imports, *tree.body] + return True, isort.code(ast.unparse(tree), float_to_top=True) + + def inject_profiling_into_existing_test( test_path: Path, call_positions: list[CodePosition], @@ -361,7 +532,9 @@ def inject_profiling_into_existing_test( mode: TestingMode = TestingMode.BEHAVIOR, ) -> tuple[bool, str | None]: if function_to_optimize.is_async: - return False, None + return inject_async_profiling_into_existing_test( + test_path, call_positions, function_to_optimize, tests_project_root, test_framework, mode + ) with test_path.open(encoding="utf8") as f: test_code = f.read() diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 8417148ef..e1e094661 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -100,6 +100,7 @@ class BestOptimization(BaseModel): winning_benchmarking_test_results: TestResults winning_replay_benchmarking_test_results: Optional[TestResults] = None line_profiler_test_results: dict + async_throughput: Optional[int] = None @dataclass(frozen=True) @@ -274,6 +275,7 @@ class OptimizedCandidateResult(BaseModel): replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None optimization_candidate_index: int total_candidate_timing: int + async_throughput: Optional[int] = None class GeneratedTests(BaseModel): @@ -380,6 +382,7 @@ class 
OriginalCodeBaseline(BaseModel): line_profile_results: dict runtime: int coverage_results: Optional[CoverageData] + async_throughput: Optional[int] = None class CoverageStatus(Enum): @@ -563,6 +566,7 @@ class TestResults(BaseModel): # noqa: PLW1641 # also we don't support deletion of test results elements - caution is advised test_results: list[FunctionTestInvocation] = [] test_result_idx: dict[str, int] = {} + perf_stdout: Optional[str] = None def add(self, function_test_invocation: FunctionTestInvocation) -> None: unique_id = function_test_invocation.unique_invocation_loop_id diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 99f6d42f0..bec15fe69 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -77,14 +77,20 @@ TestType, ) from codeflash.result.create_pr import check_create_pr, existing_tests_source_for -from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic +from codeflash.result.critic import ( + coverage_critic, + performance_gain, + quantity_of_tests_critic, + speedup_critic, + throughput_gain, +) from codeflash.result.explanation import Explanation from codeflash.telemetry.posthog_cf import ph from codeflash.verification.concolic_testing import generate_concolic_tests from codeflash.verification.equivalence import compare_test_results from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results -from codeflash.verification.parse_test_output import parse_test_results +from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results, parse_test_results from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests from codeflash.verification.verification_utils import get_test_file_path from codeflash.verification.verifier import generate_tests @@ -566,7 +572,11 @@ def determine_best_candidate( tree = Tree(f"Candidate #{candidate_index} - Runtime Information") benchmark_tree = None if speedup_critic( - candidate_result, original_code_baseline.runtime, best_runtime_until_now=None + candidate_result, + original_code_baseline.runtime, + best_runtime_until_now=None, + original_async_throughput=original_code_baseline.async_throughput, + best_throughput_until_now=None, ) and quantity_of_tests_critic(candidate_result): tree.add("This candidate is faster than the original code. 
🚀") # TODO: Change this description tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}") @@ -577,6 +587,17 @@ def determine_best_candidate( ) tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") + if ( + original_code_baseline.async_throughput is not None + and candidate_result.async_throughput is not None + ): + throughput_gain_value = throughput_gain( + original_throughput=original_code_baseline.async_throughput, + optimized_throughput=candidate_result.async_throughput, + ) + tree.add(f"Original async throughput: {original_code_baseline.async_throughput} executions") + tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions") + tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%") line_profile_test_results = self.line_profiler_step( code_context=code_context, original_helper_code=original_helper_code, @@ -612,6 +633,7 @@ def determine_best_candidate( replay_performance_gain=replay_perf_gain if self.args.benchmark else None, winning_benchmarking_test_results=candidate_result.benchmarking_test_results, winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results, + async_throughput=candidate_result.async_throughput, ) valid_optimizations.append(best_optimization) # queue corresponding refined optimization for best optimization @@ -636,6 +658,15 @@ def determine_best_candidate( ) tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") + if ( + original_code_baseline.async_throughput is not None + and candidate_result.async_throughput is not None + ): + throughput_gain_value = throughput_gain( + original_throughput=original_code_baseline.async_throughput, + optimized_throughput=candidate_result.async_throughput, + ) + tree.add(f"Throughput gain: {throughput_gain_value * 100:.1f}%") console.print(tree) if self.args.benchmark and benchmark_tree: console.print(benchmark_tree) @@ -674,6 +705,7 @@ def determine_best_candidate( replay_performance_gain=valid_opt.replay_performance_gain, winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results, winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results, + async_throughput=valid_opt.async_throughput, ) valid_candidates_with_shorter_code.append(new_best_opt) diff_lens_list.append( @@ -1176,6 +1208,8 @@ def find_and_process_best_optimization( function_name=function_to_optimize_qualified_name, file_path=self.function_to_optimize.file_path, benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None, + original_async_throughput=original_code_baseline.async_throughput, + best_async_throughput=best_optimization.async_throughput, ) self.replace_function_and_helpers_with_optimized_code( @@ -1258,6 +1292,23 @@ def process_review( original_runtimes_all=original_runtime_by_test, optimized_runtimes_all=optimized_runtime_by_test, ) + original_throughput_str = None + optimized_throughput_str = None + throughput_improvement_str = None + + if ( + self.function_to_optimize.is_async + and original_code_baseline.async_throughput is not None + and best_optimization.async_throughput is not None + ): + original_throughput_str = f"{original_code_baseline.async_throughput} operations/second" + optimized_throughput_str = f"{best_optimization.async_throughput} operations/second" + throughput_improvement_value = throughput_gain( + 
original_throughput=original_code_baseline.async_throughput, + optimized_throughput=best_optimization.async_throughput, + ) + throughput_improvement_str = f"{throughput_improvement_value * 100:.1f}%" + new_explanation_raw_str = self.aiservice_client.get_new_explanation( source_code=code_context.read_writable_code.flat, dependency_code=code_context.read_only_context_code, @@ -1271,6 +1322,9 @@ def process_review( annotated_tests=generated_tests_str, optimization_id=best_optimization.candidate.optimization_id, original_explanation=best_optimization.candidate.explanation, + original_throughput=original_throughput_str, + optimized_throughput=optimized_throughput_str, + throughput_improvement=throughput_improvement_str, ) new_explanation = Explanation( raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message, @@ -1281,6 +1335,8 @@ def process_review( function_name=explanation.function_name, file_path=explanation.file_path, benchmark_details=explanation.benchmark_details, + original_async_throughput=explanation.original_async_throughput, + best_async_throughput=explanation.best_async_throughput, ) self.log_successful_optimization(new_explanation, generated_tests, exp_type) @@ -1409,6 +1465,7 @@ def establish_original_code_baseline( return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.") if not coverage_critic(coverage_results, self.args.test_framework): return Failure("The threshold for test coverage was not met.") + if test_framework == "pytest": line_profile_results = self.line_profiler_step( code_context=code_context, original_helper_code=original_helper_code, candidate_index=0 @@ -1502,6 +1559,14 @@ def establish_original_code_baseline( console.rule() logger.debug(f"Total original code runtime (ns): {total_timing}") + async_throughput = None + if self.function_to_optimize.is_async: + async_throughput = calculate_function_throughput_from_test_results( + benchmarking_results, self.function_to_optimize.function_name + ) + logger.debug(f"Original async function throughput: {async_throughput} calls/second") + console.rule() + if self.args.benchmark: replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks( self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root @@ -1517,6 +1582,7 @@ def establish_original_code_baseline( runtime=total_timing, coverage_results=coverage_results, line_profile_results=line_profile_results, + async_throughput=async_throughput, ), functions_to_remove, ) @@ -1659,6 +1725,14 @@ def run_optimized_candidate( console.rule() logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}") + + candidate_async_throughput = None + if self.function_to_optimize.is_async: + candidate_async_throughput = calculate_function_throughput_from_test_results( + candidate_benchmarking_results, self.function_to_optimize.function_name + ) + logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second") + if self.args.benchmark: candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks( self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root @@ -1678,6 +1752,7 @@ def run_optimized_candidate( else None, optimization_candidate_index=optimization_candidate_index, total_candidate_timing=total_candidate_timing, + async_throughput=candidate_async_throughput, ) ) @@ -1769,8 +1844,10 @@ def run_and_parse_tests( coverage_database_file=coverage_database_file, 
coverage_config_file=coverage_config_file, ) - else: - results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file) + if testing_type == TestingMode.PERFORMANCE: + results.perf_stdout = run_result.stdout + return results, coverage_results + results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file) return results, coverage_results def submit_test_generation_tasks( diff --git a/codeflash/result/critic.py b/codeflash/result/critic.py index 8aea5ebae..d0ff62176 100644 --- a/codeflash/result/critic.py +++ b/codeflash/result/critic.py @@ -8,6 +8,7 @@ COVERAGE_THRESHOLD, MIN_IMPROVEMENT_THRESHOLD, MIN_TESTCASE_PASSED_THRESHOLD, + MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD, ) from codeflash.models.models import TestType @@ -25,20 +26,41 @@ def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) -> return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns +def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> float: + """Calculate the throughput gain of an optimized code over the original code. + + This value multiplied by 100 gives the percentage improvement in throughput. + For throughput, higher values are better (more executions per time period). + """ + if original_throughput == 0: + return 0.0 + return (optimized_throughput - original_throughput) / original_throughput + + def speedup_critic( candidate_result: OptimizedCandidateResult, original_code_runtime: int, best_runtime_until_now: int | None, *, disable_gh_action_noise: bool = False, + original_async_throughput: int | None = None, + best_throughput_until_now: int | None = None, ) -> bool: """Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user. - Ensure that the optimization is actually faster than the original code, above the noise floor. - The noise floor is a function of the original code runtime. Currently, the noise floor is 2xMIN_IMPROVEMENT_THRESHOLD - when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime. - The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance, also we want to be more confident there. + Evaluates both runtime performance and async throughput improvements. + + For runtime performance: + - Ensures the optimization is actually faster than the original code, above the noise floor. + - The noise floor is a function of the original code runtime. Currently, the noise floor is 2xMIN_IMPROVEMENT_THRESHOLD + when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime. + - The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance. 
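To make the new thresholds concrete, a small worked example against the helpers in this file (the numbers are invented for illustration):

from codeflash.result.critic import performance_gain, throughput_gain

# Runtime: 10_000 ns -> 8_000 ns is a 25% gain, well above the noise floor.
assert performance_gain(original_runtime_ns=10_000, optimized_runtime_ns=8_000) == 0.25

# Throughput: 100 -> 120 completed executions is a 20% gain, above the new 10%
# MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD from config_consts.py.
assert throughput_gain(original_throughput=100, optimized_throughput=120) == 0.2

# When both metrics are available, the acceptance logic in speedup_critic admits
# a candidate if either the runtime criterion or the throughput criterion is met.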
+ + For async throughput (when available): + - Evaluates throughput improvements using MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD + - Throughput improvements complement runtime improvements for async functions """ + # Runtime performance evaluation noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD if not disable_gh_action_noise and env_utils.is_ci(): noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode @@ -46,10 +68,31 @@ def speedup_critic( perf_gain = performance_gain( original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime ) - if best_runtime_until_now is None: - # collect all optimizations with this - return bool(perf_gain > noise_floor) - return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now) + runtime_improved = perf_gain > noise_floor + + # Check runtime comparison with best so far + runtime_is_best = best_runtime_until_now is None or candidate_result.best_test_runtime < best_runtime_until_now + + throughput_improved = True # Default to True if no throughput data + throughput_is_best = True # Default to True if no throughput data + + if original_async_throughput is not None and candidate_result.async_throughput is not None: + if original_async_throughput > 0: + throughput_gain_value = throughput_gain( + original_throughput=original_async_throughput, optimized_throughput=candidate_result.async_throughput + ) + throughput_improved = throughput_gain_value > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD + + throughput_is_best = ( + best_throughput_until_now is None or candidate_result.async_throughput > best_throughput_until_now + ) + + if original_async_throughput is not None and candidate_result.async_throughput is not None: + # When throughput data is available, accept if EITHER throughput OR runtime improves significantly + throughput_acceptance = throughput_improved and throughput_is_best + runtime_acceptance = runtime_improved and runtime_is_best + return throughput_acceptance or runtime_acceptance + return runtime_improved and runtime_is_best def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool: diff --git a/codeflash/result/explanation.py b/codeflash/result/explanation.py index eb12beeb6..9fa5d02a5 100644 --- a/codeflash/result/explanation.py +++ b/codeflash/result/explanation.py @@ -11,6 +11,7 @@ from codeflash.code_utils.time_utils import humanize_runtime from codeflash.models.models import BenchmarkDetail, TestResults +from codeflash.result.critic import performance_gain, throughput_gain @dataclass(frozen=True, config={"arbitrary_types_allowed": True}) @@ -23,9 +24,29 @@ class Explanation: function_name: str file_path: Path benchmark_details: Optional[list[BenchmarkDetail]] = None + original_async_throughput: Optional[int] = None + best_async_throughput: Optional[int] = None @property def perf_improvement_line(self) -> str: + runtime_improvement = self.speedup + + if ( + self.original_async_throughput is not None + and self.best_async_throughput is not None + and self.original_async_throughput > 0 + ): + throughput_improvement = throughput_gain( + original_throughput=self.original_async_throughput, + optimized_throughput=self.best_async_throughput, + ) + + # Use throughput metrics if throughput improvement is better or runtime got worse + if throughput_improvement > runtime_improvement or runtime_improvement <= 0: + throughput_pct = f"{throughput_improvement * 
100:,.0f}%" + throughput_x = f"{throughput_improvement + 1:,.2f}x" + return f"{throughput_pct} improvement ({throughput_x} faster)." + return f"{self.speedup_pct} improvement ({self.speedup_x} faster)." @property @@ -45,6 +66,24 @@ def to_console_string(self) -> str: # TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts original_runtime_human = humanize_runtime(self.original_runtime_ns) best_runtime_human = humanize_runtime(self.best_runtime_ns) + + # Determine if we're showing throughput or runtime improvements + runtime_improvement = self.speedup + is_using_throughput_metric = False + + if ( + self.original_async_throughput is not None + and self.best_async_throughput is not None + and self.original_async_throughput > 0 + ): + throughput_improvement = throughput_gain( + original_throughput=self.original_async_throughput, + optimized_throughput=self.best_async_throughput, + ) + + if throughput_improvement > runtime_improvement or runtime_improvement <= 0: + is_using_throughput_metric = True + benchmark_info = "" if self.benchmark_details: @@ -85,10 +124,18 @@ def to_console_string(self) -> str: console.print(table) benchmark_info = cast("StringIO", console.file).getvalue() + "\n" # Cast for mypy + if is_using_throughput_metric: + performance_description = ( + f"Throughput improved from {self.original_async_throughput} to {self.best_async_throughput} operations/second " + f"(runtime: {original_runtime_human} → {best_runtime_human})\n\n" + ) + else: + performance_description = f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n" + return ( f"Optimized {self.function_name} in {self.file_path}\n" f"{self.perf_improvement_line}\n" - f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n" + + performance_description + (benchmark_info if benchmark_info else "") + self.raw_explanation_message + " \n\n" diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 4af1eec50..3b19d94c8 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -40,6 +40,30 @@ def parse_func(file_path: Path) -> XMLParser: matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!") +start_pattern = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!") +end_pattern = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!") + + +def calculate_function_throughput_from_test_results(test_results: TestResults, function_name: str) -> int: + """Calculate function throughput from TestResults by extracting performance stdout. + + A completed execution is defined as having both a start tag and matching end tag from performance wrappers. + Start: !$######test_module:test_function:function_name:loop_index:iteration_id######$! + End: !######test_module:test_function:function_name:loop_index:iteration_id:duration######! 
+ """ + start_matches = start_pattern.findall(test_results.perf_stdout or "") + end_matches = end_pattern.findall(test_results.perf_stdout or "") + + end_matches_truncated = [end_match[:5] for end_match in end_matches] + end_matches_set = set(end_matches_truncated) + + function_throughput = 0 + for start_match in start_matches: + if start_match in end_matches_set and len(start_match) > 2 and start_match[2] == function_name: + function_throughput += 1 + return function_throughput + + def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults: test_results = TestResults() if not file_location.exists(): diff --git a/codeflash/verification/pytest_plugin.py b/codeflash/verification/pytest_plugin.py index 85cd4d13c..4dbdcf762 100644 --- a/codeflash/verification/pytest_plugin.py +++ b/codeflash/verification/pytest_plugin.py @@ -450,3 +450,25 @@ def make_progress_id(i: int, n: int = count) -> str: metafunc.parametrize( "__pytest_loop_step_number", range(count), indirect=True, ids=make_progress_id, scope=scope ) + + @pytest.hookimpl(tryfirst=True) + def pytest_runtest_setup(self, item: pytest.Item) -> None: + test_module_name = item.module.__name__ if item.module else "unknown_module" + + test_class_name = None + if item.cls: + test_class_name = item.cls.__name__ + + test_function_name = item.name + if "[" in test_function_name: + test_function_name = test_function_name.split("[", 1)[0] + + os.environ["CODEFLASH_TEST_MODULE"] = test_module_name + os.environ["CODEFLASH_TEST_CLASS"] = test_class_name or "" + os.environ["CODEFLASH_TEST_FUNCTION"] = test_function_name + + @pytest.hookimpl(trylast=True) + def pytest_runtest_teardown(self, item: pytest.Item) -> None: # noqa: ARG002 + """Clean up test context environment variables after each test.""" + for var in ["CODEFLASH_TEST_MODULE", "CODEFLASH_TEST_CLASS", "CODEFLASH_TEST_FUNCTION"]: + os.environ.pop(var, None) diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index 85e347641..7cb2d2c3f 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -98,7 +98,7 @@ def run_behavioral_tests( coverage_cmd.extend(shlex.split(pytest_cmd, posix=IS_POSIX)[1:]) blocklist_args = [f"-p no:{plugin}" for plugin in BEHAVIORAL_BLOCKLISTED_PLUGINS if plugin != "cov"] - + logger.info(f"{' '.join(coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files)}") results = execute_test_subprocess( coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files, cwd=cwd, diff --git a/tests/scripts/end_to_end_test_async.py b/tests/scripts/end_to_end_test_async.py index f9ef1d806..5aed8f8ca 100644 --- a/tests/scripts/end_to_end_test_async.py +++ b/tests/scripts/end_to_end_test_async.py @@ -6,14 +6,14 @@ def run_test(expected_improvement_pct: int) -> bool: config = TestConfig( - file_path="workload.py", - expected_unit_tests=1, + file_path="main.py", + expected_unit_tests=0, min_improvement_x=0.1, coverage_expectations=[ CoverageExpectation( - function_name="process_data_list", + function_name="retry_with_backoff", expected_coverage=100.0, - expected_lines=[5, 7, 8, 9, 10, 12], + expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], ) ], ) diff --git a/tests/test_async_run_and_parse_tests.py b/tests/test_async_run_and_parse_tests.py index 1c5ddae63..b83be5c5a 100644 --- a/tests/test_async_run_and_parse_tests.py +++ b/tests/test_async_run_and_parse_tests.py @@ -77,6 +77,10 @@ async def test_async_sort(): 
test_env = os.environ.copy() test_env["CODEFLASH_TEST_ITERATION"] = "0" test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_async_bubble_sort_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_async_sort" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" test_type = TestType.EXISTING_UNIT_TEST # Create function optimizer and set up test files @@ -107,7 +111,7 @@ async def test_async_sort(): results_list = test_results.test_results assert results_list[0].id.function_getting_tested == "async_sorter" - assert results_list[0].id.test_class_name == "PytestPluginManager" + assert results_list[0].id.test_class_name is None assert results_list[0].id.test_function_name == "test_async_sort" assert results_list[0].did_pass assert results_list[0].runtime is None or results_list[0].runtime >= 0 @@ -197,6 +201,10 @@ async def test_async_class_sort(): test_env = os.environ.copy() test_env["CODEFLASH_TEST_ITERATION"] = "0" test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_async_class_bubble_sort_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_async_class_sort" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" test_type = TestType.EXISTING_UNIT_TEST func_optimizer = opt.create_function_optimizer(func) @@ -233,7 +241,7 @@ async def test_async_class_sort(): assert sorter_result.id.function_getting_tested == "sorter" - assert sorter_result.id.test_class_name == "PytestPluginManager" + assert sorter_result.id.test_class_name is None assert sorter_result.id.test_function_name == "test_async_class_sort" assert sorter_result.did_pass assert sorter_result.runtime is None or sorter_result.runtime >= 0 @@ -306,6 +314,10 @@ async def test_async_perf(): test_env = os.environ.copy() test_env["CODEFLASH_TEST_ITERATION"] = "0" test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_async_perf_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_async_perf" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" test_type = TestType.EXISTING_UNIT_TEST func_optimizer = opt.create_function_optimizer(func) @@ -459,6 +471,10 @@ async def async_error_function(lst): test_env = os.environ.copy() test_env["CODEFLASH_TEST_ITERATION"] = "0" test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_async_error_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_async_error" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" test_type = TestType.EXISTING_UNIT_TEST func_optimizer = opt.create_function_optimizer(func) @@ -553,6 +569,10 @@ async def test_async_multi(): test_env = os.environ.copy() test_env["CODEFLASH_TEST_ITERATION"] = "0" test_env["CODEFLASH_LOOP_INDEX"] = "3" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_async_multi_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_async_multi" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" test_type = TestType.EXISTING_UNIT_TEST func_optimizer = opt.create_function_optimizer(func) @@ -664,6 +684,10 @@ async def test_async_edge_cases(): test_env = os.environ.copy() test_env["CODEFLASH_TEST_ITERATION"] = "0" test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_async_edge_temp" + 
test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_async_edge_cases" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" test_type = TestType.EXISTING_UNIT_TEST func_optimizer = opt.create_function_optimizer(func) @@ -796,6 +820,10 @@ def test_sync_sort(): test_env = os.environ.copy() test_env["CODEFLASH_TEST_ITERATION"] = "0" test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_sync_in_async_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_sync_sort" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" test_type = TestType.EXISTING_UNIT_TEST func_optimizer = opt.create_function_optimizer(func) @@ -962,6 +990,10 @@ async def test_mixed_sorting(): test_env = os.environ.copy() test_env["CODEFLASH_TEST_ITERATION"] = "0" test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_mixed_sort_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_mixed_sorting" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" test_type = TestType.EXISTING_UNIT_TEST func_optimizer = opt.create_function_optimizer(async_func) diff --git a/tests/test_async_wrapper_sqlite_validation.py b/tests/test_async_wrapper_sqlite_validation.py index 4386ba5ab..5cf7252f6 100644 --- a/tests/test_async_wrapper_sqlite_validation.py +++ b/tests/test_async_wrapper_sqlite_validation.py @@ -19,11 +19,15 @@ class TestAsyncWrapperSQLiteValidation: @pytest.fixture - def test_env_setup(self): + def test_env_setup(self, request): original_env = {} test_env = { "CODEFLASH_LOOP_INDEX": "1", "CODEFLASH_TEST_ITERATION": "0", + "CODEFLASH_TEST_MODULE": __name__, + "CODEFLASH_TEST_CLASS": "TestAsyncWrapperSQLiteValidation", + "CODEFLASH_TEST_FUNCTION": request.node.name, + "CODEFLASH_CURRENT_LINE_ID": "test_unit", } for key, value in test_env.items(): @@ -57,6 +61,7 @@ async def simple_async_add(a: int, b: int) -> int: await asyncio.sleep(0.001) return a + b + os.environ['CODEFLASH_CURRENT_LINE_ID'] = 'simple_async_add_59' result = await simple_async_add(5, 3) assert result == 8 @@ -278,10 +283,3 @@ async def schema_test_func() -> str: assert columns == expected_columns con.close() - def test_sync_test_context_extraction(self): - from codeflash.code_utils.codeflash_wrap_decorator import extract_test_context_from_frame - - test_module, test_class, test_func = extract_test_context_from_frame() - assert test_module == __name__ - assert test_class == "TestAsyncWrapperSQLiteValidation" - assert test_func == "test_sync_test_context_extraction" diff --git a/tests/test_critic.py b/tests/test_critic.py index 27df4dde9..3004c53d0 100644 --- a/tests/test_critic.py +++ b/tests/test_critic.py @@ -14,7 +14,13 @@ TestResults, TestType, ) -from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic +from codeflash.result.critic import ( + coverage_critic, + performance_gain, + quantity_of_tests_critic, + speedup_critic, + throughput_gain, +) def test_performance_gain() -> None: @@ -429,3 +435,159 @@ def test_coverage_critic() -> None: ) assert coverage_critic(unittest_coverage, "unittest") is True + + +def test_throughput_gain() -> None: + """Test throughput_gain calculation.""" + # Test basic throughput improvement + assert throughput_gain(original_throughput=100, optimized_throughput=150) == 0.5 # 50% improvement + + # Test no improvement + assert throughput_gain(original_throughput=100, 
optimized_throughput=100) == 0.0 + + # Test regression + assert throughput_gain(original_throughput=100, optimized_throughput=80) == -0.2 # 20% regression + + # Test zero original throughput (edge case) + assert throughput_gain(original_throughput=0, optimized_throughput=50) == 0.0 + + # Test large improvement + assert throughput_gain(original_throughput=50, optimized_throughput=200) == 3.0 # 300% improvement + + +def test_speedup_critic_with_async_throughput() -> None: + """Test speedup_critic with async throughput evaluation.""" + original_code_runtime = 10000 # 10 microseconds + original_async_throughput = 100 + + # Test case 1: Both runtime and throughput improve significantly + candidate_result = OptimizedCandidateResult( + max_loop_count=5, + best_test_runtime=8000, # 20% runtime improvement + behavior_test_results=TestResults(), + benchmarking_test_results=TestResults(), + optimization_candidate_index=0, + total_candidate_timing=8000, + async_throughput=120, # 20% throughput improvement + ) + + assert speedup_critic( + candidate_result=candidate_result, + original_code_runtime=original_code_runtime, + best_runtime_until_now=None, + original_async_throughput=original_async_throughput, + best_throughput_until_now=None, + disable_gh_action_noise=True + ) + + # Test case 2: Runtime improves significantly, throughput doesn't meet threshold (should pass) + candidate_result = OptimizedCandidateResult( + max_loop_count=5, + best_test_runtime=8000, # 20% runtime improvement + behavior_test_results=TestResults(), + benchmarking_test_results=TestResults(), + optimization_candidate_index=0, + total_candidate_timing=8000, + async_throughput=105, # Only 5% throughput improvement (below 10% threshold) + ) + + assert speedup_critic( + candidate_result=candidate_result, + original_code_runtime=original_code_runtime, + best_runtime_until_now=None, + original_async_throughput=original_async_throughput, + best_throughput_until_now=None, + disable_gh_action_noise=True + ) + + # Test case 3: Throughput improves significantly, runtime doesn't meet threshold (should pass) + candidate_result = OptimizedCandidateResult( + max_loop_count=5, + best_test_runtime=9800, # Only 2% runtime improvement (below 5% threshold) + behavior_test_results=TestResults(), + benchmarking_test_results=TestResults(), + optimization_candidate_index=0, + total_candidate_timing=9800, + async_throughput=120, # 20% throughput improvement + ) + + assert speedup_critic( + candidate_result=candidate_result, + original_code_runtime=original_code_runtime, + best_runtime_until_now=None, + original_async_throughput=original_async_throughput, + best_throughput_until_now=None, + disable_gh_action_noise=True + ) + + # Test case 4: No throughput data - should fall back to runtime-only evaluation + candidate_result = OptimizedCandidateResult( + max_loop_count=5, + best_test_runtime=8000, # 20% runtime improvement + behavior_test_results=TestResults(), + benchmarking_test_results=TestResults(), + optimization_candidate_index=0, + total_candidate_timing=8000, + async_throughput=None, # No throughput data + ) + + assert speedup_critic( + candidate_result=candidate_result, + original_code_runtime=original_code_runtime, + best_runtime_until_now=None, + original_async_throughput=None, # No original throughput data + best_throughput_until_now=None, + disable_gh_action_noise=True + ) + + # Test case 5: Test best_throughput_until_now comparison + candidate_result = OptimizedCandidateResult( + max_loop_count=5, + best_test_runtime=8000, # 20% runtime 
improvement + behavior_test_results=TestResults(), + benchmarking_test_results=TestResults(), + optimization_candidate_index=0, + total_candidate_timing=8000, + async_throughput=115, # 15% throughput improvement + ) + + # Should pass when no best throughput yet + assert speedup_critic( + candidate_result=candidate_result, + original_code_runtime=original_code_runtime, + best_runtime_until_now=None, + original_async_throughput=original_async_throughput, + best_throughput_until_now=None, + disable_gh_action_noise=True + ) + + # Should fail when there's a better throughput already + assert not speedup_critic( + candidate_result=candidate_result, + original_code_runtime=original_code_runtime, + best_runtime_until_now=7000, # Better runtime already exists + original_async_throughput=original_async_throughput, + best_throughput_until_now=120, # Better throughput already exists + disable_gh_action_noise=True + ) + + # Test case 6: Zero original throughput (edge case) + candidate_result = OptimizedCandidateResult( + max_loop_count=5, + best_test_runtime=8000, # 20% runtime improvement + behavior_test_results=TestResults(), + benchmarking_test_results=TestResults(), + optimization_candidate_index=0, + total_candidate_timing=8000, + async_throughput=50, + ) + + # Should pass when original throughput is 0 (throughput evaluation skipped) + assert speedup_critic( + candidate_result=candidate_result, + original_code_runtime=original_code_runtime, + best_runtime_until_now=None, + original_async_throughput=0, # Zero original throughput + best_throughput_until_now=None, + disable_gh_action_noise=True + ) diff --git a/tests/test_extract_test_context_from_frame.py b/tests/test_extract_test_context_from_frame.py deleted file mode 100644 index f33a65fa6..000000000 --- a/tests/test_extract_test_context_from_frame.py +++ /dev/null @@ -1,217 +0,0 @@ -from __future__ import annotations - -from unittest.mock import Mock, patch - -import pytest - -from codeflash.code_utils.codeflash_wrap_decorator import ( - _extract_class_name_tracer, - _get_module_name_cf_tracer, - extract_test_context_from_frame, -) - - -@pytest.fixture -def mock_instance(): - mock_obj = Mock() - mock_obj.__class__.__name__ = "TestClassName" - return mock_obj - - -@pytest.fixture -def mock_class(): - mock_cls = Mock() - mock_cls.__name__ = "TestClassMethod" - return mock_cls - - -class TestExtractClassNameTracer: - - def test_extract_class_name_with_self(self, mock_instance): - frame_locals = {"self": mock_instance} - result = _extract_class_name_tracer(frame_locals) - - assert result == "TestClassName" - - def test_extract_class_name_with_cls(self, mock_class): - frame_locals = {"cls": mock_class} - result = _extract_class_name_tracer(frame_locals) - - assert result == "TestClassMethod" - - def test_extract_class_name_self_no_class(self, mock_class): - class NoClassMock: - @property - def __class__(self): - raise AttributeError("no __class__ attribute") - - mock_instance = NoClassMock() - frame_locals = {"self": mock_instance, "cls": mock_class} - result = _extract_class_name_tracer(frame_locals) - - assert result == "TestClassMethod" - - def test_extract_class_name_no_self_or_cls(self): - frame_locals = {"some_var": "value"} - result = _extract_class_name_tracer(frame_locals) - - assert result is None - - def test_extract_class_name_exception_handling(self): - class ExceptionMock: - @property - def __class__(self): - raise Exception("Test exception") - - mock_instance = ExceptionMock() - frame_locals = {"self": mock_instance} - result = 
_extract_class_name_tracer(frame_locals) - - assert result is None - - def test_extract_class_name_with_attribute_error(self): - class AttributeErrorMock: - @property - def __class__(self): - raise AttributeError("Wrapt-like error") - - mock_instance = AttributeErrorMock() - frame_locals = {"self": mock_instance} - result = _extract_class_name_tracer(frame_locals) - - assert result is None - - -class TestGetModuleNameCfTracer: - - def test_get_module_name_with_valid_frame(self): - mock_frame = Mock() - mock_module = Mock() - mock_module.__name__ = "test_module_name" - - with patch("inspect.getmodule", return_value=mock_module): - result = _get_module_name_cf_tracer(mock_frame) - assert result == "test_module_name" - - def test_get_module_name_from_frame_globals(self): - mock_frame = Mock() - mock_frame.f_globals = {"__name__": "module_from_globals"} - - with patch("inspect.getmodule", side_effect=Exception("Module not found")): - result = _get_module_name_cf_tracer(mock_frame) - assert result == "module_from_globals" - - def test_get_module_name_no_name_in_globals(self): - mock_frame = Mock() - mock_frame.f_globals = {} - - with patch("inspect.getmodule", side_effect=Exception("Module not found")): - result = _get_module_name_cf_tracer(mock_frame) - assert result == "unknown_module" - - def test_get_module_name_none_frame(self): - result = _get_module_name_cf_tracer(None) - assert result == "unknown_module" - - def test_get_module_name_module_no_name_attribute(self): - mock_frame = Mock() - mock_module = Mock(spec=[]) - mock_frame.f_globals = {"__name__": "fallback_name"} - - with patch("inspect.getmodule", return_value=mock_module): - result = _get_module_name_cf_tracer(mock_frame) - assert result == "fallback_name" - - -class TestExtractTestContextFromFrame: - - def test_direct_test_function_call(self): - def test_example_function(): - return extract_test_context_from_frame() - - result = test_example_function() - module_name, class_name, function_name = result - - assert module_name == __name__ - assert class_name == "TestExtractTestContextFromFrame" - assert function_name == "test_example_function" - - def test_with_test_class_method(self): - class TestExampleClass: - def test_method(self): - return extract_test_context_from_frame() - - instance = TestExampleClass() - result = instance.test_method() - module_name, class_name, function_name = result - - assert module_name == __name__ - assert class_name == "TestExampleClass" - assert function_name == "test_method" - - def test_function_without_test_prefix(self): - result = extract_test_context_from_frame() - module_name, class_name, function_name = result - - assert module_name == __name__ - assert class_name == "TestExtractTestContextFromFrame" - assert function_name == "test_function_without_test_prefix" - - @patch('inspect.currentframe') - def test_no_test_context_raises_runtime_error(self, mock_current_frame): - mock_frame = Mock() - mock_frame.f_back = None - mock_frame.f_code.co_name = "regular_function" - mock_frame.f_code.co_filename = "/path/to/regular_file.py" - mock_frame.f_locals = {} - mock_frame.f_globals = {"__name__": "regular_module"} - - mock_current_frame.return_value = mock_frame - - with pytest.raises(RuntimeError, match="No test function found in call stack"): - extract_test_context_from_frame() - - def test_real_call_stack_context(self): - def nested_function(): - def deeper_function(): - return extract_test_context_from_frame() - return deeper_function() - - result = nested_function() - module_name, class_name, 
function_name = result - - assert module_name == __name__ - assert class_name == "TestExtractTestContextFromFrame" - assert function_name == "test_real_call_stack_context" - - - -class TestIntegrationScenarios: - - def test_pytest_class_method_scenario(self): - class TestExampleIntegration: - def test_integration_method(self): - return extract_test_context_from_frame() - - instance = TestExampleIntegration() - result = instance.test_integration_method() - module_name, class_name, function_name = result - - assert module_name == __name__ - assert class_name == "TestExampleIntegration" - assert function_name == "test_integration_method" - - def test_nested_helper_functions(self): - def outer_helper(): - def inner_helper(): - def deepest_helper(): - return extract_test_context_from_frame() - return deepest_helper() - return inner_helper() - - result = outer_helper() - module_name, class_name, function_name = result - - assert module_name == __name__ - assert class_name == "TestIntegrationScenarios" - assert function_name == "test_nested_helper_functions" diff --git a/tests/test_instrument_async_tests.py b/tests/test_instrument_async_tests.py index 97c4dd659..cdce5bf82 100644 --- a/tests/test_instrument_async_tests.py +++ b/tests/test_instrument_async_tests.py @@ -535,3 +535,194 @@ def test_qualified_name_with_nested_parents(): is_async=False ) assert func_mixed_parents.qualified_name == 'MyClass.outer_function.inner_function' + + +def test_inject_profiling_async_multiple_calls_same_test(temp_dir): + """Test that multiple async function calls within the same test function get correctly numbered 0, 1, 2, etc.""" + source_module_code = ''' +import asyncio + +async def async_sorter(items): + """Simple async sorter for testing.""" + await asyncio.sleep(0.001) + return sorted(items) +''' + + source_file = temp_dir / "async_sorter.py" + source_file.write_text(source_module_code) + + test_code_multiple_calls = ''' +import asyncio +import pytest +from async_sorter import async_sorter + +@pytest.mark.asyncio +async def test_single_call(): + result = await async_sorter([42]) + assert result == [42] + +@pytest.mark.asyncio +async def test_multiple_calls(): + result1 = await async_sorter([3, 1, 2]) + result2 = await async_sorter([5, 4]) + result3 = await async_sorter([9, 8, 7, 6]) + assert result1 == [1, 2, 3] + assert result2 == [4, 5] + assert result3 == [6, 7, 8, 9] +''' + + test_file = temp_dir / "test_async_sorter.py" + test_file.write_text(test_code_multiple_calls) + + func = FunctionToOptimize( + function_name="async_sorter", + parents=[], + file_path=Path("async_sorter.py"), + is_async=True + ) + + # First instrument the source module with async decorators + from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators + source_success, instrumented_source = instrument_source_module_with_async_decorators( + source_file, func, TestingMode.BEHAVIOR + ) + + assert source_success + assert instrumented_source is not None + assert '@codeflash_behavior_async' in instrumented_source + + # Write the instrumented source back + source_file.write_text(instrumented_source) + + # Now test injection with multiple call positions + # Parse the test file to get exact positions for async calls + import ast + tree = ast.parse(test_code_multiple_calls) + call_positions = [] + for node in ast.walk(tree): + if isinstance(node, ast.Await) and isinstance(node.value, ast.Call): + if hasattr(node.value.func, 'id') and node.value.func.id == 'async_sorter': + 
call_positions.append(CodePosition(node.lineno, node.col_offset)) + elif hasattr(node.value.func, 'attr') and node.value.func.attr == 'async_sorter': + call_positions.append(CodePosition(node.lineno, node.col_offset)) + + # Should find 4 calls total: 1 in test_single_call + 3 in test_multiple_calls + assert len(call_positions) == 4 + + success, instrumented_test_code = inject_profiling_into_existing_test( + test_file, + call_positions, + func, + temp_dir, + "pytest", + mode=TestingMode.BEHAVIOR + ) + + assert success + assert instrumented_test_code is not None + + # Verify the instrumentation adds correct line_id assignments + # Each test function should start from 0 + assert "os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'" in instrumented_test_code + + # Count occurrences of each line_id to verify numbering + line_id_0_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'") + line_id_1_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '1'") + line_id_2_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '2'") + + # Should have: + # - 2 occurrences of '0' (first call in each test function) + # - 1 occurrence of '1' (second call in test_multiple_calls) + # - 1 occurrence of '2' (third call in test_multiple_calls) + assert line_id_0_count == 2, f"Expected 2 occurrences of line_id '0', got {line_id_0_count}" + assert line_id_1_count == 1, f"Expected 1 occurrence of line_id '1', got {line_id_1_count}" + assert line_id_2_count == 1, f"Expected 1 occurrence of line_id '2', got {line_id_2_count}" + + # Verify no higher numbers + line_id_3_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '3'") + assert line_id_3_count == 0, f"Unexpected occurrence of line_id '3'" + + # Check that imports are added + assert 'import os' in instrumented_test_code + + +def test_sync_functions_do_not_get_async_instrumentation(temp_dir): + """Test that sync functions do NOT get async instrumentation (os.environ assignments).""" + # Create a sync function module + sync_module_code = ''' +def sync_sorter(items): + """Simple sync sorter for testing.""" + return sorted(items) +''' + + source_file = temp_dir / "sync_sorter.py" + source_file.write_text(sync_module_code) + + # Create test code with sync function calls + sync_test_code = ''' +import pytest +from sync_sorter import sync_sorter + +def test_single_call(): + result = sync_sorter([42]) + assert result == [42] + +def test_multiple_calls(): + result1 = sync_sorter([3, 1, 2]) + result2 = sync_sorter([5, 4]) + result3 = sync_sorter([9, 8, 7, 6]) + assert result1 == [1, 2, 3] + assert result2 == [4, 5] + assert result3 == [6, 7, 8, 9] +''' + + test_file = temp_dir / "test_sync_sorter.py" + test_file.write_text(sync_test_code) + + sync_func = FunctionToOptimize( + function_name="sync_sorter", + parents=[], + file_path=Path("sync_sorter.py"), + is_async=False # SYNC function + ) + + # Parse the test file to get exact positions for sync calls + import ast + tree = ast.parse(sync_test_code) + call_positions = [] + for node in ast.walk(tree): + if isinstance(node, ast.Call): + if hasattr(node.func, 'id') and node.func.id == 'sync_sorter': + call_positions.append(CodePosition(node.lineno, node.col_offset)) + elif hasattr(node.func, 'attr') and node.func.attr == 'sync_sorter': + call_positions.append(CodePosition(node.lineno, node.col_offset)) + + # Should find 4 calls total: 1 in test_single_call + 3 in test_multiple_calls + assert len(call_positions) == 4 + + 
success, instrumented_test_code = inject_profiling_into_existing_test( + test_file, + call_positions, + sync_func, + temp_dir, + "pytest", + mode=TestingMode.BEHAVIOR + ) + + assert success + assert instrumented_test_code is not None + + # Verify the sync function does NOT get async instrumentation + assert "os.environ['CODEFLASH_CURRENT_LINE_ID']" not in instrumented_test_code + + # But should get proper sync instrumentation + assert 'codeflash_wrap' in instrumented_test_code + assert 'codeflash_loop_index' in instrumented_test_code + assert 'sqlite3' in instrumented_test_code  # sync behavior mode includes sqlite + + # Verify the line_id values are correct for sync functions (statement-based) + # Sync functions use statement index, not per-test-function counter + assert "'0'" in instrumented_test_code  # first call in test_single_call + assert instrumented_test_code.count("'0'") >= 2  # '0' appears again for the first call in test_multiple_calls + assert "'1'" in instrumented_test_code  # second call in test_multiple_calls + assert "'2'" in instrumented_test_code  # third call in test_multiple_calls
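
For reference, the arithmetic these throughput and critic tests pin down can be summarised with a small sketch. Both throughput_gain below and the passes_async_critic gate are illustrative stand-ins written for this note, not the actual codeflash implementations or signatures; the 5% runtime and 10% throughput noise floors are taken from the comments in the test cases above.

def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> float:
    # Relative gain; a zero baseline yields 0.0 so throughput evaluation is skipped.
    if original_throughput == 0:
        return 0.0
    return (optimized_throughput - original_throughput) / original_throughput


def passes_async_critic(
    original_runtime: int,
    candidate_runtime: int,
    original_throughput: int | None,
    candidate_throughput: int | None,
    runtime_noise_floor: float = 0.05,      # assumed from "below 5% threshold" comment
    throughput_noise_floor: float = 0.10,   # assumed from "below 10% threshold" comment
) -> bool:
    runtime_ok = (original_runtime / candidate_runtime - 1) > runtime_noise_floor
    if original_throughput in (None, 0) or candidate_throughput is None:
        return runtime_ok  # no throughput data: fall back to runtime-only evaluation
    throughput_ok = (
        throughput_gain(
            original_throughput=original_throughput,
            optimized_throughput=candidate_throughput,
        )
        > throughput_noise_floor
    )
    # Either a clear runtime win or a clear throughput win is enough.
    return runtime_ok or throughput_ok


# Values mirror the assertions in the tests above.
assert throughput_gain(original_throughput=100, optimized_throughput=80) == -0.2
assert throughput_gain(original_throughput=0, optimized_throughput=50) == 0.0
assert throughput_gain(original_throughput=50, optimized_throughput=200) == 3.0
assert passes_async_critic(10000, 9800, 100, 120)  # throughput win alone is enough
assert passes_async_critic(10000, 8000, 100, 105)  # runtime win alone is enough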
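
Likewise, a rough model of the numbering the async instrumentation tests expect, assuming only what the assertions check: each test function restarts the line-id counter at '0' and every awaited call to the target gets the next id via the CODEFLASH_CURRENT_LINE_ID environment variable. This is an illustrative reconstruction, not the literal output of inject_profiling_into_existing_test; the real flow also decorates the source function with @codeflash_behavior_async, as the test above asserts.

import asyncio
import os


async def async_sorter(items):
    # Stand-in for the module under test in the fixture above.
    await asyncio.sleep(0)
    return sorted(items)


async def test_multiple_calls_instrumented():
    os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'  # first call in this test function
    result1 = await async_sorter([3, 1, 2])
    os.environ['CODEFLASH_CURRENT_LINE_ID'] = '1'  # second call
    result2 = await async_sorter([5, 4])
    os.environ['CODEFLASH_CURRENT_LINE_ID'] = '2'  # third call
    result3 = await async_sorter([9, 8, 7, 6])
    assert (result1, result2, result3) == ([1, 2, 3], [4, 5], [6, 7, 8, 9])


asyncio.run(test_multiple_calls_instrumented())

The companion sync test verifies the inverse: non-async targets get no CODEFLASH_CURRENT_LINE_ID bookkeeping and instead go through codeflash_wrap with statement-based line ids.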