diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index a7bc8a1fa..6d97764e7 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -42,6 +42,7 @@ def get_codeflash_api_key() -> str: if env_api_key and not shell_api_key: try: from codeflash.either import is_successful + result = save_api_key_to_rc(env_api_key) if is_successful(result): logger.debug(f"Automatically saved API key from environment to shell config: {result.unwrap()}") diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 03719be4e..c78e551be 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -24,7 +24,7 @@ get_functions_within_git_diff, ) from codeflash.either import is_successful -from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol +from codeflash.lsp.server import CodeflashLanguageServer if TYPE_CHECKING: from argparse import Namespace @@ -50,6 +50,13 @@ class ProvideApiKeyParams: api_key: str +@dataclass +class ValidateProjectParams: + root_path_abs: str + config_file: Optional[str] = None + skip_validation: bool = False + + @dataclass class OnPatchAppliedParams: patch_id: str @@ -60,7 +67,8 @@ class OptimizableFunctionsInCommitParams: commit_hash: str -server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol) +# server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol) +server = CodeflashLanguageServer("codeflash-language-server", "v1.0") @server.feature("getOptimizableFunctionsInCurrentDiff") @@ -160,17 +168,60 @@ def initialize_function_optimization( return {"functionName": params.functionName, "status": "success"} -@server.feature("validateProject") -def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizationParams) -> dict[str, str]: +def _find_pyproject_toml(workspace_path: str) -> Path | None: + workspace_path_obj = Path(workspace_path) + max_depth = 2 + base_depth = len(workspace_path_obj.parts) + + for root, dirs, files in os.walk(workspace_path_obj): + depth = len(Path(root).parts) - base_depth + if depth > max_depth: + # stop going deeper into this branch + dirs.clear() + continue + + if "pyproject.toml" in files: + file_path = Path(root) / "pyproject.toml" + with file_path.open("r", encoding="utf-8", errors="ignore") as f: + for line in f: + if line.strip() == "[tool.codeflash]": + return file_path.resolve() + return None + + +# should be called the first thing to initialize and validate the project +@server.feature("initProject") +def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]: from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml + pyproject_toml_path: Path | None = getattr(params, "config_file", None) + + if server.args is None: + if pyproject_toml_path is not None: + # if there is a config file provided use it + server.prepare_optimizer_arguments(pyproject_toml_path) + else: + # otherwise look for it + pyproject_toml_path = _find_pyproject_toml(params.root_path_abs) + server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info") + if pyproject_toml_path: + server.prepare_optimizer_arguments(pyproject_toml_path) + else: + return { + "status": "error", + "message": "No pyproject.toml found in workspace.", + } # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth + + if getattr(params, "skip_validation", False): + return {"status": "success", "moduleRoot": server.args.module_root, "pyprojectPath": pyproject_toml_path} + server.show_message_log("Validating project...", "Info") - config = is_valid_pyproject_toml(server.args.config_file) + config = is_valid_pyproject_toml(pyproject_toml_path) if config is None: server.show_message_log("pyproject.toml is not valid", "Error") return { "status": "error", - "message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions + "message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions, } args = process_args(server) @@ -183,7 +234,7 @@ def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizat except Exception: return {"status": "error", "message": "Repository has no commits (unborn HEAD)"} - return {"status": "success", "moduleRoot": args.module_root} + return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path} def _initialize_optimizer_if_api_key_is_valid( @@ -328,7 +379,7 @@ def perform_function_optimization( # noqa: PLR0911 devnull_writer = open(os.devnull, "w") # noqa with contextlib.redirect_stdout(devnull_writer): - function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs) + function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs) function_optimizer.function_to_tests = function_to_tests test_setup_result = function_optimizer.generate_and_instrument_tests( diff --git a/codeflash/lsp/lsp_logger.py b/codeflash/lsp/lsp_logger.py index 11a30cb1e..711431c19 100644 --- a/codeflash/lsp/lsp_logger.py +++ b/codeflash/lsp/lsp_logger.py @@ -124,7 +124,7 @@ def setup_logging() -> logging.Logger: logger = logging.getLogger() logger.handlers.clear() - # Set up stderr handler for VS Code output channel with [LSP-Server] prefix + # Set up stderr handler for VS Code output channel handler = logging.StreamHandler(sys.stderr) handler.setLevel(logging.DEBUG) diff --git a/codeflash/lsp/server.py b/codeflash/lsp/server.py index 06a4d4e34..daaa4efe5 100644 --- a/codeflash/lsp/server.py +++ b/codeflash/lsp/server.py @@ -1,15 +1,13 @@ from __future__ import annotations -from pathlib import Path from typing import TYPE_CHECKING, Any -from lsprotocol.types import INITIALIZE, LogMessageParams, MessageType -from pygls import uris -from pygls.protocol import LanguageServerProtocol, lsp_method +from lsprotocol.types import LogMessageParams, MessageType +from pygls.protocol import LanguageServerProtocol from pygls.server import LanguageServer if TYPE_CHECKING: - from lsprotocol.types import InitializeParams, InitializeResult + from pathlib import Path from codeflash.optimization.optimizer import Optimizer @@ -17,30 +15,6 @@ class CodeflashLanguageServerProtocol(LanguageServerProtocol): _server: CodeflashLanguageServer - @lsp_method(INITIALIZE) - def lsp_initialize(self, params: InitializeParams) -> InitializeResult: - server = self._server - initialize_result: InitializeResult = super().lsp_initialize(params) - - workspace_uri = params.root_uri - if workspace_uri: - workspace_path = uris.to_fs_path(workspace_uri) - pyproject_toml_path = self._find_pyproject_toml(workspace_path) - if pyproject_toml_path: - server.prepare_optimizer_arguments(pyproject_toml_path) - else: - server.show_message("No pyproject.toml found in workspace.") - else: - server.show_message("No workspace URI provided.") - - return initialize_result - - def _find_pyproject_toml(self, workspace_path: str) -> Path | None: - workspace_path_obj = Path(workspace_path) - for file_path in workspace_path_obj.rglob("pyproject.toml"): - return file_path.resolve() - return None - class CodeflashLanguageServer(LanguageServer): def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401