diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index 97c80f656..fad6eaa4d 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -2,6 +2,7 @@ import json import os +from dataclasses import dataclass from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Any, Optional @@ -26,14 +27,24 @@ from packaging import version -if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local": - CFAPI_BASE_URL = "http://localhost:3001" - CFWEBAPP_BASE_URL = "http://localhost:3000" - logger.info(f"Using local CF API at {CFAPI_BASE_URL}.") - console.rule() -else: - CFAPI_BASE_URL = "https://app.codeflash.ai" - CFWEBAPP_BASE_URL = "https://app.codeflash.ai" + +@dataclass +class BaseUrls: + cfapi_base_url: Optional[str] = None + cfwebapp_base_url: Optional[str] = None + + +@lru_cache(maxsize=1) +def get_cfapi_base_urls() -> BaseUrls: + if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local": + cfapi_base_url = "http://localhost:3001" + cfwebapp_base_url = "http://localhost:3000" + logger.info(f"Using local CF API at {cfapi_base_url}.") + console.rule() + else: + cfapi_base_url = "https://app.codeflash.ai" + cfwebapp_base_url = "https://app.codeflash.ai" + return BaseUrls(cfapi_base_url=cfapi_base_url, cfwebapp_base_url=cfwebapp_base_url) def make_cfapi_request( @@ -53,8 +64,9 @@ def make_cfapi_request( :param suppress_errors: If True, suppress error logging for HTTP errors. :return: The response object from the API. 
""" - url = f"{CFAPI_BASE_URL}/cfapi{endpoint}" - cfapi_headers = {"Authorization": f"Bearer {api_key or get_codeflash_api_key()}"} + url = f"{get_cfapi_base_urls().cfapi_base_url}/cfapi{endpoint}" + final_api_key = api_key or get_codeflash_api_key() + cfapi_headers = {"Authorization": f"Bearer {final_api_key}"} if extra_headers: cfapi_headers.update(extra_headers) try: diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index bbc83d943..fca5313c4 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -167,8 +167,8 @@ def ask_run_end_to_end_test(args: Namespace) -> None: console.rule() if run_tests: - bubble_sort_path, bubble_sort_test_path = create_bubble_sort_file_and_test(args) - run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path) + file_path = create_find_common_tags_file(args, "find_common_tags.py") + run_end_to_end_test(args, file_path) def is_valid_pyproject_toml(pyproject_toml_path: Path) -> tuple[bool, dict[str, Any] | None, str]: # noqa: PLR0911 @@ -1207,6 +1207,35 @@ def enter_api_key_and_save_to_rc() -> None: os.environ["CODEFLASH_API_KEY"] = api_key +def create_find_common_tags_file(args: Namespace, file_name: str) -> Path: + find_common_tags_content = """def find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]: + if not articles: + return set() + + common_tags = articles[0]["tags"] + for article in articles[1:]: + common_tags = [tag for tag in common_tags if tag in article["tags"]] + return set(common_tags) +""" + + file_path = Path(args.module_root) / file_name + lsp_enabled = is_LSP_enabled() + if file_path.exists() and not lsp_enabled: + from rich.prompt import Confirm + + overwrite = Confirm.ask( + f"๐Ÿค” {file_path} already exists. 
Do you want to overwrite it?", default=True, show_default=False + ) + if not overwrite: + apologize_and_exit() + console.rule() + + file_path.write_text(find_common_tags_content, encoding="utf8") + logger.info(f"Created demo optimization file: {file_path}") + + return file_path + + def create_bubble_sort_file_and_test(args: Namespace) -> tuple[str, str]: bubble_sort_content = """from typing import Union, List def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]: @@ -1276,7 +1305,7 @@ def test_sort(): return str(bubble_sort_path), str(bubble_sort_test_path) -def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None: +def run_end_to_end_test(args: Namespace, find_common_tags_path: Path) -> None: try: check_formatter_installed(args.formatter_cmds) except Exception: @@ -1285,7 +1314,7 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test ) return - command = ["codeflash", "--file", "bubble_sort.py", "--function", "sorter"] + command = ["codeflash", "--file", "find_common_tags.py", "--function", "find_common_tags"] if args.no_pr: command.append("--no-pr") if args.verbose: @@ -1316,10 +1345,8 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test console.rule() # Delete the bubble_sort.py file after the test logger.info("๐Ÿงน Cleaning upโ€ฆ") - for path in [bubble_sort_path, bubble_sort_test_path]: - console.rule() - Path(path).unlink(missing_ok=True) - logger.info(f"๐Ÿ—‘๏ธ Deleted {path}") + find_common_tags_path.unlink(missing_ok=True) + logger.info(f"๐Ÿ—‘๏ธ Deleted {find_common_tags_path}") def ask_for_telemetry() -> bool: diff --git a/codeflash/code_utils/git_worktree_utils.py b/codeflash/code_utils/git_worktree_utils.py index e63690a66..11faa8902 100644 --- a/codeflash/code_utils/git_worktree_utils.py +++ b/codeflash/code_utils/git_worktree_utils.py @@ -81,7 +81,7 @@ def remove_worktree(worktree_dir: Path) -> None: def 
create_diff_patch_from_worktree( - worktree_dir: Path, files: list[str], fto_name: Optional[str] = None + worktree_dir: Path, files: list[Path], fto_name: Optional[str] = None ) -> Optional[Path]: repository = git.Repo(worktree_dir, search_parent_directories=True) uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True) diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 46f898db7..9d96eab41 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -7,7 +7,7 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from codeflash.api.cfapi import get_codeflash_api_key, get_user_id from codeflash.cli_cmds.cli import process_pyproject_config @@ -16,12 +16,14 @@ VsCodeSetupInfo, configure_pyproject_toml, create_empty_pyproject_toml, + create_find_common_tags_file, get_formatter_cmds, get_suggestions, get_valid_subdirs, is_valid_pyproject_toml, ) from codeflash.code_utils.git_utils import git_root_dir +from codeflash.code_utils.git_worktree_utils import create_worktree_snapshot_commit from codeflash.code_utils.shell_utils import save_api_key_to_rc from codeflash.discovery.functions_to_optimize import ( filter_functions, @@ -39,6 +41,7 @@ from lsprotocol import types from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.lsp.server import WrappedInitializationResultT @dataclass @@ -55,11 +58,15 @@ class FunctionOptimizationInitParams: @dataclass class FunctionOptimizationParams: - textDocument: types.TextDocumentIdentifier # noqa: N815 functionName: str # noqa: N815 task_id: str +@dataclass +class DemoOptimizationParams: + functionName: str # noqa: N815 + + @dataclass class ProvideApiKeyParams: api_key: str @@ -257,10 +264,8 @@ def init_project(params: ValidateProjectParams) -> dict[str, str]: def _initialize_optimizer_if_api_key_is_valid(api_key: 
Optional[str] = None) -> dict[str, str]: key_check_result = _check_api_key_validity(api_key) - if key_check_result.get("status") != "success": - return key_check_result - - _init() + if key_check_result.get("status") == "success": + _init() return key_check_result @@ -303,8 +308,8 @@ def _init() -> Namespace: def check_api_key(_params: any) -> dict[str, str]: try: return _initialize_optimizer_if_api_key_is_valid() - except Exception: - return {"status": "error", "message": "something went wrong while validating the api key"} + except Exception as ex: + return {"status": "error", "message": "something went wrong while validating the api key " + str(ex)} @server.feature("provideApiKey") @@ -353,6 +358,56 @@ def cleanup_optimizer(_params: any) -> dict[str, str]: return {"status": "success"} +def _initialize_current_function_optimizer() -> Union[dict[str, str], WrappedInitializationResultT]: + """Initialize the current function optimizer. + + Returns: + Union[dict[str, str], WrappedInitializationResultT]: + error dict with status error, + or a wrapped InitializationResult if the optimizer is initialized. 
+ + """ + if not server.optimizer: + return {"status": "error", "message": "Optimizer not initialized yet."} + + function_name = server.optimizer.args.function + optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions() + + if count == 0: + server.show_message_log(f"No optimizable functions found for {function_name}", "Warning") + server.cleanup_the_optimizer() + return {"functionName": function_name, "status": "error", "message": "not found", "args": None} + + fto = optimizable_funcs.popitem()[1][0] + + module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path) + if not module_prep_result: + return { + "functionName": function_name, + "status": "error", + "message": "Failed to prepare module for optimization", + } + + validated_original_code, original_module_ast = module_prep_result + + function_optimizer = server.optimizer.create_function_optimizer( + fto, + function_to_optimize_source_code=validated_original_code[fto.file_path].source_code, + original_module_ast=original_module_ast, + original_module_path=fto.file_path, + function_to_tests={}, + ) + + server.optimizer.current_function_optimizer = function_optimizer + if not function_optimizer: + return {"functionName": function_name, "status": "error", "message": "No function optimizer found"} + + initialization_result = function_optimizer.can_be_optimized() + if not is_successful(initialization_result): + return {"functionName": function_name, "status": "error", "message": initialization_result.failure()} + return initialization_result + + @server.feature("initializeFunctionOptimization") def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]: with execution_context(task_id=getattr(params, "task_id", None)): @@ -377,45 +432,14 @@ def initialize_function_optimization(params: FunctionOptimizationInitParams) -> f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info" ) - optimizable_funcs, 
count, _ = server.optimizer.get_optimizable_functions() - - if count == 0: - server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning") - server.cleanup_the_optimizer() - return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None} - - fto = optimizable_funcs.popitem()[1][0] - - module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path) - if not module_prep_result: - return { - "functionName": params.functionName, - "status": "error", - "message": "Failed to prepare module for optimization", - } - - validated_original_code, original_module_ast = module_prep_result - - function_optimizer = server.optimizer.create_function_optimizer( - fto, - function_to_optimize_source_code=validated_original_code[fto.file_path].source_code, - original_module_ast=original_module_ast, - original_module_path=fto.file_path, - function_to_tests={}, - ) - - server.optimizer.current_function_optimizer = function_optimizer - if not function_optimizer: - return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} - - initialization_result = function_optimizer.can_be_optimized() - if not is_successful(initialization_result): - return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} + initialization_result = _initialize_current_function_optimizer() + if isinstance(initialization_result, dict): + return initialization_result server.current_optimization_init_result = initialization_result.unwrap() server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info") - files = [function_optimizer.function_to_optimize.file_path] + files = [document.path] _, _, original_helpers = server.current_optimization_init_result files.extend([str(helper_path) for helper_path in original_helpers]) @@ -423,6 +447,32 @@ def initialize_function_optimization(params: 
FunctionOptimizationInitParams) -> return {"functionName": params.functionName, "status": "success", "files_inside_context": files} +@server.feature("startDemoOptimization") +async def start_demo_optimization(params: DemoOptimizationParams) -> dict[str, str]: + try: + _init() + # start by creating the worktree so that the demo file is not created in user workspace + server.optimizer.worktree_mode() + file_path = create_find_common_tags_file(server.args, params.functionName + ".py") + # commit the new file for diff generation later + create_worktree_snapshot_commit(server.optimizer.current_worktree, "added sample optimization file") + + server.optimizer.args.file = file_path + server.optimizer.args.function = params.functionName + server.optimizer.args.previous_checkpoint_functions = False + + initialization_result = _initialize_current_function_optimizer() + if isinstance(initialization_result, dict): + return initialization_result + + server.current_optimization_init_result = initialization_result.unwrap() + return await perform_function_optimization( + FunctionOptimizationParams(functionName=params.functionName, task_id=None) + ) + finally: + server.cleanup_the_optimizer() + + @server.feature("performFunctionOptimization") async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]: with execution_context(task_id=getattr(params, "task_id", None)): diff --git a/codeflash/lsp/features/perform_optimization.py b/codeflash/lsp/features/perform_optimization.py index c29bd2f27..cd5b36d8b 100644 --- a/codeflash/lsp/features/perform_optimization.py +++ b/codeflash/lsp/features/perform_optimization.py @@ -59,6 +59,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: thr generated_perf_test_paths, instrumented_unittests_created_for_function, original_conftest_content, + function_references, ) = test_setup_result.unwrap() baseline_setup_result = function_optimizer.setup_and_establish_baseline( @@ -94,6 +95,7 @@ 
def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: thr generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove, concolic_test_str=concolic_test_str, + function_references=function_references, ) abort_if_cancelled(cancel_event) diff --git a/codeflash/lsp/server.py b/codeflash/lsp/server.py index e3e45c835..582e5033c 100644 --- a/codeflash/lsp/server.py +++ b/codeflash/lsp/server.py @@ -1,15 +1,16 @@ from __future__ import annotations +from pathlib import Path from typing import TYPE_CHECKING from lsprotocol.types import LogMessageParams, MessageType from pygls.lsp.server import LanguageServer from pygls.protocol import LanguageServerProtocol -if TYPE_CHECKING: - from pathlib import Path +from codeflash.either import Result +from codeflash.models.models import CodeOptimizationContext - from codeflash.models.models import CodeOptimizationContext +if TYPE_CHECKING: from codeflash.optimization.optimizer import Optimizer @@ -17,6 +18,10 @@ class CodeflashLanguageServerProtocol(LanguageServerProtocol): _server: CodeflashLanguageServer +InitializationResultT = tuple[bool, CodeOptimizationContext, dict[Path, str]] +WrappedInitializationResultT = Result[InitializationResultT, str] + + class CodeflashLanguageServer(LanguageServer): def __init__(self, name: str, version: str, protocol_cls: type[LanguageServerProtocol]) -> None: super().__init__(name, version, protocol_cls=protocol_cls) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 3a0629adb..2c46d4ce5 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -19,7 +19,7 @@ from rich.tree import Tree from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient -from codeflash.api.cfapi import CFWEBAPP_BASE_URL, add_code_context_hash, create_staging, mark_optimization_success +from codeflash.api.cfapi import 
add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar from codeflash.code_utils import env_utils @@ -1195,7 +1195,7 @@ def generate_tests_and_optimizations( if concolic_test_str: count_tests += 1 - logger.info(f"Generated '{count_tests}' tests for {self.function_to_optimize.function_name}") + logger.info(f"Generated '{count_tests}' tests for '{self.function_to_optimize.function_name}'") console.rule() generated_tests = GeneratedTestsList(generated_tests=tests) result = ( @@ -1508,7 +1508,7 @@ def process_review( response = create_staging(**data) if response.status_code == 200: trace_id = self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id - staging_url = f"{CFWEBAPP_BASE_URL}/review-optimizations/{trace_id}" + staging_url = f"{get_cfapi_base_urls().cfwebapp_base_url}/review-optimizations/{trace_id}" console.print( Panel( f"[bold green]โœ… Staging created:[/bold green]\n[link={staging_url}]{staging_url}[/link]",