Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions codeflash/api/cfapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def get_user_id() -> Optional[str]:
if min_version and version.parse(min_version) > version.parse(__version__):
msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`."
console.print(f"[bold red]{msg}[/bold red]")
if console.quiet: # lsp
logger.debug(msg)
return f"Error: {msg}"
sys.exit(1)
return userid

Expand Down
7 changes: 6 additions & 1 deletion codeflash/code_utils/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =

@lru_cache(maxsize=1)
def get_codeflash_api_key() -> str:
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
if console.quiet: # lsp
# prefer shell config over env var in lsp mode
api_key = read_api_key_from_shell_config()
else:
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()

api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/getting-started/codeflash-github-actions#add-your-api-key-to-your-repository-secrets]." # noqa
if not api_key:
msg = (
Expand Down
2 changes: 1 addition & 1 deletion codeflash/code_utils/shell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_shell_rc_path() -> Path:


def get_api_key_export_line(api_key: str) -> str:
return f"{SHELL_RC_EXPORT_PREFIX}{api_key}"
return f'{SHELL_RC_EXPORT_PREFIX}"{api_key}"'


def save_api_key_to_rc(api_key: str) -> Result[str, str]:
Expand Down
54 changes: 54 additions & 0 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from pygls import uris

from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
from codeflash.code_utils.shell_utils import save_api_key_to_rc
from codeflash.either import is_successful
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol

Expand All @@ -28,6 +30,11 @@ class FunctionOptimizationParams:
functionName: str # noqa: N815


@dataclass
class ProvideApiKeyParams:
api_key: str


server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)


Expand Down Expand Up @@ -118,6 +125,53 @@ def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOpt
return {"functionName": params.functionName, "status": "success", "discovered_tests": num_discovered_tests}


def _initialize_optimizer_if_valid(server: CodeflashLanguageServer) -> dict[str, str]:
user_id = get_user_id()
if user_id is None:
return {"status": "error", "message": "api key not found or invalid"}

if user_id.startswith("Error: "):
error_msg = user_id[7:]
return {"status": "error", "message": error_msg}

from codeflash.optimization.optimizer import Optimizer

server.optimizer = Optimizer(server.args)
return {"status": "success", "user_id": user_id}


@server.feature("apiKeyExistsAndValid")
def check_api_key(server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
try:
return _initialize_optimizer_if_valid(server)
except Exception:
return {"status": "error", "message": "something went wrong while validating the api key"}


@server.feature("provideApiKey")
def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams) -> dict[str, str]:
try:
api_key = params.api_key
if not api_key.startswith("cf-"):
return {"status": "error", "message": "Api key is not valid"}

result = save_api_key_to_rc(api_key)
if not is_successful(result):
return {"status": "error", "message": result.failure()}

# clear cache to ensure the new api key is used
get_codeflash_api_key.cache_clear()
get_user_id.cache_clear()

init_result = _initialize_optimizer_if_valid(server)
if init_result["status"] == "error":
return {"status": "error", "message": "Api key is not valid"}

return {"status": "success", "message": "Api key saved successfully", "user_id": init_result["user_id"]}
except Exception:
return {"status": "error", "message": "something went wrong while saving the api key"}


@server.feature("prepareOptimization")
def prepare_optimization(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
current_function = server.optimizer.current_function_being_optimized
Expand Down
10 changes: 6 additions & 4 deletions codeflash/lsp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
workspace_path = uris.to_fs_path(workspace_uri)
pyproject_toml_path = self._find_pyproject_toml(workspace_path)
if pyproject_toml_path:
server.initialize_optimizer(pyproject_toml_path)
server.prepare_optimizer_arguments(pyproject_toml_path)
server.show_message(f"Found pyproject.toml at: {pyproject_toml_path}")
else:
server.show_message("No pyproject.toml found in workspace.")
Expand All @@ -45,16 +45,17 @@ class CodeflashLanguageServer(LanguageServer):
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
super().__init__(*args, **kwargs)
self.optimizer = None
self.args = None

def initialize_optimizer(self, config_file: Path) -> None:
def prepare_optimizer_arguments(self, config_file: Path) -> None:
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
from codeflash.optimization.optimizer import Optimizer

args = parse_args()
args.config_file = config_file
args.no_pr = True # LSP server should not create PRs
args = process_pyproject_config(args)
self.optimizer = Optimizer(args)
self.args = args
# avoid initializing the optimizer during initialization, because it can cause an error if the api key is invalid

def show_message_log(self, message: str, message_type: str) -> None:
"""Send a log message to the client's output channel.
Expand All @@ -70,6 +71,7 @@ def show_message_log(self, message: str, message_type: str) -> None:
"Warning": MessageType.Warning,
"Error": MessageType.Error,
"Log": MessageType.Log,
"Debug": MessageType.Debug,
}

lsp_message_type = type_mapping.get(message_type, MessageType.Info)
Expand Down
Loading