diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index a731bf5a6..e15333d75 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -4,15 +4,19 @@ import os import platform import time -from typing import TYPE_CHECKING, Any +from pathlib import Path +from typing import TYPE_CHECKING, Any, cast import requests from pydantic.json import pydantic_encoder from codeflash.cli_cmds.console import console, logger +from codeflash.code_utils.code_replacer import is_zero_diff +from codeflash.code_utils.code_utils import unified_diff_strings from codeflash.code_utils.config_consts import N_CANDIDATES_EFFECTIVE, N_CANDIDATES_LP_EFFECTIVE from codeflash.code_utils.env_utils import get_codeflash_api_key from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name +from codeflash.code_utils.time_utils import humanize_runtime from codeflash.lsp.helpers import is_LSP_enabled from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate @@ -20,11 +24,10 @@ from codeflash.version import __version__ as codeflash_version if TYPE_CHECKING: - from pathlib import Path - from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import AIServiceRefinerRequest + from codeflash.result.explanation import Explanation class AiServiceClient: @@ -529,6 +532,85 @@ def generate_regression_tests( # noqa: D417 ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text}) return None + def get_optimization_impact( + self, + original_code: dict[Path, str], + new_code: dict[Path, str], + explanation: Explanation, + existing_tests_source: str, + generated_original_test_source: str, + function_trace_id: str, + coverage_message: str, + replay_tests: str, + root_dir: Path, + concolic_tests: str, # noqa: ARG002 + ) -> str: + """Compute the optimization impact of current Pull Request. + + Args: + original_code: dict -> data structure mapping file paths to function definition for original code + new_code: dict -> data structure mapping file paths to function definition for optimized code + explanation: Explanation -> data structure containing runtime information + existing_tests_source: str -> existing tests table + generated_original_test_source: str -> annotated generated tests + function_trace_id: str -> traceid of function + coverage_message: str -> coverage information + replay_tests: str -> replay test table + root_dir: Path -> path of git directory + concolic_tests: str -> concolic_tests (not used) + + Returns: + ------- + - 'high' or 'low' optimization impact + + """ + diff_str = "\n".join( + [ + unified_diff_strings( + code1=original_code[p], + code2=new_code[p], + fromfile=Path(p).relative_to(root_dir).as_posix(), + tofile=Path(p).relative_to(root_dir).as_posix(), + ) + for p in original_code + if not is_zero_diff(original_code[p], new_code[p]) + ] + ) + code_diff = f"```diff\n{diff_str}\n```" + logger.info("!lsp|Computing Optimization Impact…") + payload = { + "code_diff": code_diff, + "explanation": explanation.raw_explanation_message, + "existing_tests": existing_tests_source, + "generated_tests": generated_original_test_source, + "trace_id": function_trace_id, + "coverage_message": coverage_message, + "replay_tests": replay_tests, + "speedup": f"{(100 * float(explanation.speedup)):.2f}%", + "loop_count": explanation.winning_benchmarking_test_results.number_of_loops(), + "benchmark_details": explanation.benchmark_details if explanation.benchmark_details else None, + "optimized_runtime": humanize_runtime(explanation.best_runtime_ns), + "original_runtime": humanize_runtime(explanation.original_runtime_ns), + } + console.rule() + try: + response = self.make_ai_service_request("/optimization_impact", payload=payload, timeout=600) + except requests.exceptions.RequestException as e: + logger.exception(f"Error generating optimization refinements: {e}") + ph("cli-optimize-error-caught", {"error": str(e)}) + return "" + + if response.status_code == 200: + return cast("str", response.json()["impact"]) + try: + error = cast("str", response.json()["error"]) + except Exception: + error = response.text + logger.error(f"Error generating impact candidates: {response.status_code} - {error}") + ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) + console.rule() + return "" + class LocalAiServiceClient(AiServiceClient): """Client for interacting with the local AI service.""" diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 31fc688c9..58f0e172f 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1458,14 +1458,22 @@ def process_review( } raise_pr = not self.args.no_pr + staging_review = self.args.staging_review - if raise_pr or self.args.staging_review: + if raise_pr or staging_review: data["root_dir"] = git_root_dir() - - if raise_pr and not self.args.staging_review: + try: + # modify argument of staging vs pr based on the impact + opt_impact_response = self.aiservice_client.get_optimization_impact(**data) + if opt_impact_response == "low": + raise_pr = False + staging_review = True + except Exception as e: + logger.debug(f"optimization impact response failed, investigate {e}") + if raise_pr and not staging_review: data["git_remote"] = self.args.git_remote check_create_pr(**data) - elif self.args.staging_review: + elif staging_review: response = create_staging(**data) if response.status_code == 200: staging_url = f"https://app.codeflash.ai/review-optimizations/{self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id}" @@ -1504,7 +1512,7 @@ def process_review( self.revert_code_and_helpers(original_helper_code) return - if self.args.staging_review: + if staging_review: # always revert code and helpers when staging review self.revert_code_and_helpers(original_helper_code) return