Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions codeflash/code_utils/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def get_codeflash_api_key() -> str:
if env_api_key and not shell_api_key:
try:
from codeflash.either import is_successful

result = save_api_key_to_rc(env_api_key)
if is_successful(result):
logger.debug(f"Automatically saved API key from environment to shell config: {result.unwrap()}")
Expand Down
67 changes: 59 additions & 8 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
get_functions_within_git_diff,
)
from codeflash.either import is_successful
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol
from codeflash.lsp.server import CodeflashLanguageServer

if TYPE_CHECKING:
from argparse import Namespace
Expand All @@ -50,6 +50,13 @@ class ProvideApiKeyParams:
api_key: str


@dataclass
class ValidateProjectParams:
root_path_abs: str
config_file: Optional[str] = None
skip_validation: bool = False


@dataclass
class OnPatchAppliedParams:
patch_id: str
Expand All @@ -60,7 +67,8 @@ class OptimizableFunctionsInCommitParams:
commit_hash: str


server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
# server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
server = CodeflashLanguageServer("codeflash-language-server", "v1.0")


@server.feature("getOptimizableFunctionsInCurrentDiff")
Expand Down Expand Up @@ -160,17 +168,60 @@ def initialize_function_optimization(
return {"functionName": params.functionName, "status": "success"}


@server.feature("validateProject")
def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizationParams) -> dict[str, str]:
def _find_pyproject_toml(workspace_path: str) -> Path | None:
workspace_path_obj = Path(workspace_path)
max_depth = 2
base_depth = len(workspace_path_obj.parts)

for root, dirs, files in os.walk(workspace_path_obj):
depth = len(Path(root).parts) - base_depth
if depth > max_depth:
# stop going deeper into this branch
dirs.clear()
continue

if "pyproject.toml" in files:
file_path = Path(root) / "pyproject.toml"
with file_path.open("r", encoding="utf-8", errors="ignore") as f:
for line in f:
if line.strip() == "[tool.codeflash]":
return file_path.resolve()
return None


# should be called the first thing to initialize and validate the project
@server.feature("initProject")
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml

pyproject_toml_path: Path | None = getattr(params, "config_file", None)

if server.args is None:
if pyproject_toml_path is not None:
# if there is a config file provided use it
server.prepare_optimizer_arguments(pyproject_toml_path)
else:
# otherwise look for it
pyproject_toml_path = _find_pyproject_toml(params.root_path_abs)
server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info")
if pyproject_toml_path:
server.prepare_optimizer_arguments(pyproject_toml_path)
else:
return {
"status": "error",
"message": "No pyproject.toml found in workspace.",
} # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth

if getattr(params, "skip_validation", False):
return {"status": "success", "moduleRoot": server.args.module_root, "pyprojectPath": pyproject_toml_path}

server.show_message_log("Validating project...", "Info")
config = is_valid_pyproject_toml(server.args.config_file)
config = is_valid_pyproject_toml(pyproject_toml_path)
if config is None:
server.show_message_log("pyproject.toml is not valid", "Error")
return {
"status": "error",
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions,
}

args = process_args(server)
Expand All @@ -183,7 +234,7 @@ def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizat
except Exception:
return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}

return {"status": "success", "moduleRoot": args.module_root}
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path}


def _initialize_optimizer_if_api_key_is_valid(
Expand Down Expand Up @@ -328,7 +379,7 @@ def perform_function_optimization( # noqa: PLR0911

devnull_writer = open(os.devnull, "w") # noqa
with contextlib.redirect_stdout(devnull_writer):
function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
function_optimizer.function_to_tests = function_to_tests

test_setup_result = function_optimizer.generate_and_instrument_tests(
Expand Down
2 changes: 1 addition & 1 deletion codeflash/lsp/lsp_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def setup_logging() -> logging.Logger:
logger = logging.getLogger()
logger.handlers.clear()

# Set up stderr handler for VS Code output channel with [LSP-Server] prefix
# Set up stderr handler for VS Code output channel
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.DEBUG)

Expand Down
32 changes: 3 additions & 29 deletions codeflash/lsp/server.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,20 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

from lsprotocol.types import INITIALIZE, LogMessageParams, MessageType
from pygls import uris
from pygls.protocol import LanguageServerProtocol, lsp_method
from lsprotocol.types import LogMessageParams, MessageType
from pygls.protocol import LanguageServerProtocol
from pygls.server import LanguageServer

if TYPE_CHECKING:
from lsprotocol.types import InitializeParams, InitializeResult
from pathlib import Path

from codeflash.optimization.optimizer import Optimizer


class CodeflashLanguageServerProtocol(LanguageServerProtocol):
_server: CodeflashLanguageServer

@lsp_method(INITIALIZE)
def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
server = self._server
initialize_result: InitializeResult = super().lsp_initialize(params)

workspace_uri = params.root_uri
if workspace_uri:
workspace_path = uris.to_fs_path(workspace_uri)
pyproject_toml_path = self._find_pyproject_toml(workspace_path)
if pyproject_toml_path:
server.prepare_optimizer_arguments(pyproject_toml_path)
else:
server.show_message("No pyproject.toml found in workspace.")
else:
server.show_message("No workspace URI provided.")

return initialize_result

def _find_pyproject_toml(self, workspace_path: str) -> Path | None:
workspace_path_obj = Path(workspace_path)
for file_path in workspace_path_obj.rglob("pyproject.toml"):
return file_path.resolve()
return None


class CodeflashLanguageServer(LanguageServer):
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
Expand Down
Loading