|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from pathlib import Path |
| 4 | +from typing import TYPE_CHECKING, Any |
| 5 | + |
| 6 | +from lsprotocol.types import INITIALIZE |
| 7 | +from pygls import uris |
| 8 | +from pygls.protocol import LanguageServerProtocol, lsp_method |
| 9 | +from pygls.server import LanguageServer |
| 10 | + |
| 11 | +if TYPE_CHECKING: |
| 12 | + from lsprotocol.types import InitializeParams, InitializeResult |
| 13 | + |
| 14 | + |
| 15 | +class CodeflashLanguageServerProtocol(LanguageServerProtocol): |
| 16 | + _server: CodeflashLanguageServer |
| 17 | + |
| 18 | + @lsp_method(INITIALIZE) |
| 19 | + def lsp_initialize(self, params: InitializeParams) -> InitializeResult: |
| 20 | + server = self._server |
| 21 | + initialize_result: InitializeResult = super().lsp_initialize(params) |
| 22 | + |
| 23 | + workspace_uri = params.root_uri |
| 24 | + if workspace_uri: |
| 25 | + workspace_path = uris.to_fs_path(workspace_uri) |
| 26 | + pyproject_toml_path = self._find_pyproject_toml(workspace_path) |
| 27 | + if pyproject_toml_path: |
| 28 | + server.initialize_optimizer(pyproject_toml_path) |
| 29 | + server.show_message(f"Found pyproject.toml at: {pyproject_toml_path}") |
| 30 | + else: |
| 31 | + server.show_message("No pyproject.toml found in workspace.") |
| 32 | + else: |
| 33 | + server.show_message("No workspace URI provided.") |
| 34 | + |
| 35 | + return initialize_result |
| 36 | + |
| 37 | + def _find_pyproject_toml(self, workspace_path: str) -> Path | None: |
| 38 | + workspace_path_obj = Path(workspace_path) |
| 39 | + for file_path in workspace_path_obj.rglob("pyproject.toml"): |
| 40 | + return file_path.resolve() |
| 41 | + return None |
| 42 | + |
| 43 | + |
| 44 | +class CodeflashLanguageServer(LanguageServer): |
| 45 | + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 |
| 46 | + super().__init__(*args, **kwargs) |
| 47 | + self.optimizer = None |
| 48 | + |
| 49 | + def initialize_optimizer(self, config_file: Path) -> None: |
| 50 | + from codeflash.cli_cmds.cli import parse_args, process_pyproject_config |
| 51 | + from codeflash.optimization.optimizer import Optimizer |
| 52 | + |
| 53 | + args = parse_args() |
| 54 | + args.config_file = config_file |
| 55 | + args = process_pyproject_config(args) |
| 56 | + self.optimizer = Optimizer(args) |
0 commit comments