Skip to content

Commit 90cdf3e

Browse files
move searching for pyproject file from the server initialization to the validation feature
1 parent d3e6427 commit 90cdf3e

File tree

2 files changed

+57
-26
lines changed

2 files changed

+57
-26
lines changed

codeflash/lsp/beta.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ class ProvideApiKeyParams:
5050
api_key: str
5151

5252

53+
@dataclass
54+
class ValidateProjectParams:
55+
root_path_abs: str
56+
config_file: Optional[str] = None
57+
skip_validation: bool = False
58+
59+
5360
@dataclass
5461
class OnPatchAppliedParams:
5562
patch_id: str
@@ -160,17 +167,60 @@ def initialize_function_optimization(
160167
return {"functionName": params.functionName, "status": "success"}
161168

162169

163-
@server.feature("validateProject")
164-
def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizationParams) -> dict[str, str]:
170+
def _find_pyproject_toml(workspace_path: str) -> Path | None:
171+
workspace_path_obj = Path(workspace_path)
172+
max_depth = 2
173+
base_depth = len(workspace_path_obj.parts)
174+
175+
for root, dirs, files in os.walk(workspace_path_obj):
176+
depth = len(Path(root).parts) - base_depth
177+
if depth > max_depth:
178+
# stop going deeper into this branch
179+
dirs.clear()
180+
continue
181+
182+
if "pyproject.toml" in files:
183+
file_path = Path(root) / "pyproject.toml"
184+
with file_path.open("r", encoding="utf-8", errors="ignore") as f:
185+
for line in f:
186+
if line.strip() == "[tool.codeflash]":
187+
return file_path.resolve()
188+
return None
189+
190+
191+
# should be called the first thing to initialize and validate the project
192+
@server.feature("initProject")
193+
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
165194
from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml
166195

196+
pyproject_toml_path: Path | None = getattr(params, "config_file", None)
197+
198+
if server.args is None:
199+
if pyproject_toml_path is not None:
200+
# if there is a config file provided use it
201+
server.prepare_optimizer_arguments(pyproject_toml_path)
202+
else:
203+
# otherwise look for it
204+
pyproject_toml_path = _find_pyproject_toml(params.root_path_abs)
205+
server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info")
206+
if pyproject_toml_path:
207+
server.prepare_optimizer_arguments(pyproject_toml_path)
208+
else:
209+
return {
210+
"status": "error",
211+
"message": "No pyproject.toml found in workspace.",
212+
} # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth
213+
214+
if getattr(params, "skip_validation", False):
215+
return {"status": "success", "moduleRoot": server.args.module_root, "pyprojectPath": pyproject_toml_path}
216+
167217
server.show_message_log("Validating project...", "Info")
168-
config = is_valid_pyproject_toml(server.args.config_file)
218+
config = is_valid_pyproject_toml(pyproject_toml_path)
169219
if config is None:
170220
server.show_message_log("pyproject.toml is not valid", "Error")
171221
return {
172222
"status": "error",
173-
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions
223+
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions,
174224
}
175225

176226
args = process_args(server)
@@ -183,7 +233,7 @@ def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizat
183233
except Exception:
184234
return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}
185235

186-
return {"status": "success", "moduleRoot": args.module_root}
236+
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path}
187237

188238

189239
def _initialize_optimizer_if_api_key_is_valid(

codeflash/lsp/server.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from __future__ import annotations
22

3-
from pathlib import Path
43
from typing import TYPE_CHECKING, Any
54

65
from lsprotocol.types import INITIALIZE, LogMessageParams, MessageType
7-
from pygls import uris
86
from pygls.protocol import LanguageServerProtocol, lsp_method
97
from pygls.server import LanguageServer
108

119
if TYPE_CHECKING:
10+
from pathlib import Path
11+
1212
from lsprotocol.types import InitializeParams, InitializeResult
1313

1414
from codeflash.optimization.optimizer import Optimizer
@@ -19,28 +19,9 @@ class CodeflashLanguageServerProtocol(LanguageServerProtocol):
1919

2020
@lsp_method(INITIALIZE)
2121
def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
22-
server = self._server
2322
initialize_result: InitializeResult = super().lsp_initialize(params)
24-
25-
workspace_uri = params.root_uri
26-
if workspace_uri:
27-
workspace_path = uris.to_fs_path(workspace_uri)
28-
pyproject_toml_path = self._find_pyproject_toml(workspace_path)
29-
if pyproject_toml_path:
30-
server.prepare_optimizer_arguments(pyproject_toml_path)
31-
else:
32-
server.show_message("No pyproject.toml found in workspace.")
33-
else:
34-
server.show_message("No workspace URI provided.")
35-
3623
return initialize_result
3724

38-
def _find_pyproject_toml(self, workspace_path: str) -> Path | None:
39-
workspace_path_obj = Path(workspace_path)
40-
for file_path in workspace_path_obj.rglob("pyproject.toml"):
41-
return file_path.resolve()
42-
return None
43-
4425

4526
class CodeflashLanguageServer(LanguageServer):
4627
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401

0 commit comments

Comments
 (0)