diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index f51d57ac0..7082ab104 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -4,7 +4,6 @@ import json import os -import sys from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Any, Optional @@ -89,12 +88,13 @@ def make_cfapi_request( @lru_cache(maxsize=1) -def get_user_id(api_key: Optional[str] = None) -> Optional[str]: +def get_user_id(api_key: Optional[str] = None) -> Optional[str]: # noqa: PLR0911 """Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint. :param api_key: The API key to use. If None, uses get_codeflash_api_key(). :return: The userid or None if the request fails. """ + lsp_enabled = is_LSP_enabled() if not api_key and not ensure_codeflash_api_key(): return None @@ -115,19 +115,21 @@ def get_user_id(api_key: Optional[str] = None) -> Optional[str]: if min_version and version.parse(min_version) > version.parse(__version__): msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`." console.print(f"[bold red]{msg}[/bold red]") - if is_LSP_enabled(): + if lsp_enabled: logger.debug(msg) return f"Error: {msg}" - sys.exit(1) + exit_with_message(msg, error_on_exit=True) return userid logger.error("Failed to retrieve userid from the response.") return None - # Handle 403 (Invalid API key) - exit with error message if response.status_code == 403: + error_title = "Invalid Codeflash API key. The API key you provided is not valid." + if lsp_enabled: + return f"Error: {error_title}" msg = ( - "Invalid Codeflash API key. The API key you provided is not valid.\n" + f"{error_title}\n" "Please generate a new one at https://app.codeflash.ai/app/apikeys ,\n" "then set it as a CODEFLASH_API_KEY environment variable.\n" "For more information, refer to the documentation at \n" diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 507b5a169..37e0dd94e 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -17,6 +17,7 @@ from codeflash.cli_cmds.console import logger, paneled_text from codeflash.code_utils.config_parser import find_pyproject_toml, get_all_closest_config_files +from codeflash.lsp.helpers import is_LSP_enabled ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE) @@ -352,6 +353,10 @@ def restore_conftest(path_to_content_map: dict[Path, str]) -> None: def exit_with_message(message: str, *, error_on_exit: bool = False) -> None: + """Don't Call it inside the lsp process, it will terminate the lsp server.""" + if is_LSP_enabled(): + logger.error(message) + return paneled_text(message, panel_args={"style": "red"}) sys.exit(1 if error_on_exit else 0) diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index 4200edb7d..f4991bef0 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -13,6 +13,7 @@ from codeflash.code_utils.code_utils import exit_with_message from codeflash.code_utils.formatter import format_code from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc +from codeflash.lsp.helpers import is_LSP_enabled def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa @@ -70,7 +71,10 @@ def get_codeflash_api_key() -> str: except Exception as e: logger.debug(f"Failed to automatically save API key to shell config: {e}") - api_key = env_api_key or shell_api_key + # Prefer the shell configuration over environment variables for lsp, + # as the API key may change in the RC file during lsp runtime. Since the LSP client (extension) can restart + # within the same process, the environment variable could become outdated. + api_key = shell_api_key or env_api_key if is_LSP_enabled() else env_api_key or shell_api_key api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/getting-started/codeflash-github-actions#add-your-api-key-to-your-repository-secrets]." # noqa if not api_key: diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 240d44880..46f898db7 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -72,11 +72,6 @@ class ValidateProjectParams: skip_validation: bool = False -@dataclass -class OnPatchAppliedParams: - task_id: str - - @dataclass class OptimizableFunctionsInCommitParams: commit_hash: str @@ -91,6 +86,11 @@ class WriteConfigParams: server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol) +@server.feature("server/listFeatures") +def list_features(_params: any) -> list[str]: + return list(server.protocol.fm.features) + + @server.feature("getOptimizableFunctionsInCurrentDiff") def get_functions_in_current_git_diff(_params: OptimizableFunctionsParams) -> dict[str, str | dict[str, list[str]]]: functions = get_functions_within_git_diff(uncommitted_changes=True) @@ -251,7 +251,7 @@ def init_project(params: ValidateProjectParams) -> dict[str, str]: "existingConfig": config, } - args = _init() + args = process_args() return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root} @@ -269,8 +269,9 @@ def _check_api_key_validity(api_key: Optional[str]) -> dict[str, str]: if user_id is None: return {"status": "error", "message": "api key not found or invalid"} - if user_id.startswith("Error: "): - error_msg = user_id[7:] + error_prefix = "Error: " + if user_id.startswith(error_prefix): + error_msg = user_id[len(error_prefix) :] return {"status": "error", "message": error_msg} return {"status": "success", "user_id": user_id}