Skip to content

Commit 3a30d9e

Browse files
fixes for the api key and the lsp gracefull shutdown
1 parent 231352f commit 3a30d9e

File tree

6 files changed

+60
-66
lines changed

6 files changed

+60
-66
lines changed

codeflash/api/cfapi.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def make_cfapi_request(
4040
payload: dict[str, Any] | None = None,
4141
extra_headers: dict[str, str] | None = None,
4242
*,
43+
api_key: str | None = None,
4344
suppress_errors: bool = False,
4445
) -> Response:
4546
"""Make an HTTP request using the specified method, URL, headers, and JSON payload.
@@ -51,7 +52,7 @@ def make_cfapi_request(
5152
:return: The response object from the API.
5253
"""
5354
url = f"{CFAPI_BASE_URL}/cfapi{endpoint}"
54-
cfapi_headers = {"Authorization": f"Bearer {get_codeflash_api_key()}"}
55+
cfapi_headers = {"Authorization": f"Bearer {api_key or get_codeflash_api_key()}"}
5556
if extra_headers:
5657
cfapi_headers.update(extra_headers)
5758
try:
@@ -83,15 +84,17 @@ def make_cfapi_request(
8384

8485

8586
@lru_cache(maxsize=1)
86-
def get_user_id() -> Optional[str]:
87+
def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
8788
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
8889
8990
:return: The userid or None if the request fails.
9091
"""
9192
if not ensure_codeflash_api_key():
9293
return None
9394

94-
response = make_cfapi_request(endpoint="/cli-get-user", method="GET", extra_headers={"cli_version": __version__})
95+
response = make_cfapi_request(
96+
endpoint="/cli-get-user", method="GET", extra_headers={"cli_version": __version__}, api_key=api_key
97+
)
9598
if response.status_code == 200:
9699
if "min_version" not in response.text:
97100
return response.text

codeflash/lsp/beta.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
from dataclasses import dataclass
66
from pathlib import Path
7-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Optional
88

99
import git
1010
from pygls import uris
@@ -104,6 +104,8 @@ def get_optimizable_functions(
104104
) -> dict[str, list[str]]:
105105
file_path = Path(uris.to_fs_path(params.textDocument.uri))
106106
server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info")
107+
if not server.optimizer:
108+
return {"status": "error", "message": "optimizer not initialized"}
107109

108110
server.optimizer.args.file = file_path
109111
server.optimizer.args.function = None # Always get ALL functions, not just one
@@ -184,8 +186,10 @@ def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizat
184186
return {"status": "success", "moduleRoot": args.module_root}
185187

186188

187-
def _initialize_optimizer_if_api_key_is_valid(server: CodeflashLanguageServer) -> dict[str, str]:
188-
user_id = get_user_id()
189+
def _initialize_optimizer_if_api_key_is_valid(
190+
server: CodeflashLanguageServer, api_key: Optional[str] = None
191+
) -> dict[str, str]:
192+
user_id = get_user_id(api_key=api_key)
189193
if user_id is None:
190194
return {"status": "error", "message": "api key not found or invalid"}
191195

@@ -224,19 +228,19 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
224228
if not api_key.startswith("cf-"):
225229
return {"status": "error", "message": "Api key is not valid"}
226230

227-
result = save_api_key_to_rc(api_key)
228-
if not is_successful(result):
229-
return {"status": "error", "message": result.failure()}
230-
231231
# clear cache to ensure the new api key is used
232232
get_codeflash_api_key.cache_clear()
233233
get_user_id.cache_clear()
234234

235-
init_result = _initialize_optimizer_if_api_key_is_valid(server)
235+
init_result = _initialize_optimizer_if_api_key_is_valid(server, api_key)
236236
if init_result["status"] == "error":
237237
return {"status": "error", "message": "Api key is not valid"}
238238

239-
return {"status": "success", "message": "Api key saved successfully", "user_id": init_result["user_id"]}
239+
user_id = init_result["user_id"]
240+
result = save_api_key_to_rc(api_key)
241+
if not is_successful(result):
242+
return {"status": "error", "message": result.failure()}
243+
return {"status": "success", "message": "Api key saved successfully", "user_id": user_id} # noqa: TRY300
240244
except Exception:
241245
return {"status": "error", "message": "something went wrong while saving the api key"}
242246

codeflash/lsp/lsp_logger.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from __future__ import annotations
22

3+
import logging
4+
import sys
35
from dataclasses import dataclass
46
from typing import Any, Callable, Optional
57

68
from codeflash.lsp.helpers import is_LSP_enabled
79
from codeflash.lsp.lsp_message import LspTextMessage
810

11+
root_logger = None
12+
913

1014
@dataclass
1115
class LspMessageTags:
@@ -109,3 +113,27 @@ def enhanced_log(
109113
clean_msg = LspTextMessage(text=clean_msg, takes_time=final_tags.loading).serialize()
110114

111115
actual_log_fn(clean_msg, *args, **kwargs)
116+
117+
118+
# Configure logging to stderr for VS Code output channel
119+
def setup_logging() -> logging.Logger:
120+
global root_logger # noqa: PLW0603
121+
if root_logger:
122+
return root_logger
123+
# Clear any existing handlers to prevent conflicts
124+
logger = logging.getLogger()
125+
logger.handlers.clear()
126+
127+
# Set up stderr handler for VS Code output channel with [LSP-Server] prefix
128+
handler = logging.StreamHandler(sys.stderr)
129+
handler.setLevel(logging.DEBUG)
130+
131+
# Configure root logger
132+
logger.addHandler(handler)
133+
134+
# Also configure the pygls logger specifically
135+
pygls_logger = logging.getLogger("pygls")
136+
pygls_logger.setLevel(logging.INFO)
137+
138+
root_logger = logger
139+
return logger

codeflash/lsp/server.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from __future__ import annotations
22

3-
import sys
43
from pathlib import Path
5-
from threading import Event
6-
from typing import TYPE_CHECKING, Any, Optional, TextIO
4+
from typing import TYPE_CHECKING, Any
75

86
from lsprotocol.types import INITIALIZE, LogMessageParams, MessageType
97
from pygls import uris
108
from pygls.protocol import LanguageServerProtocol, lsp_method
11-
from pygls.server import LanguageServer, StdOutTransportAdapter, aio_readline
9+
from pygls.server import LanguageServer
1210

1311
if TYPE_CHECKING:
1412
from lsprotocol.types import InitializeParams, InitializeResult
@@ -85,6 +83,8 @@ def show_message_log(self, message: str, message_type: str) -> None:
8583
self.lsp.notify("window/logMessage", log_params)
8684

8785
def cleanup_the_optimizer(self) -> None:
86+
if not self.optimizer:
87+
return
8888
try:
8989
self.optimizer.cleanup_temporary_paths()
9090
# restore args and test cfg
@@ -96,26 +96,7 @@ def cleanup_the_optimizer(self) -> None:
9696
except Exception:
9797
self.show_message_log("Failed to cleanup optimizer", "Error")
9898

99-
def start_io(self, stdin: Optional[TextIO] = None, stdout: Optional[TextIO] = None) -> None:
100-
self.show_message_log("Starting IO server", "Info")
101-
102-
self._stop_event = Event()
103-
transport = StdOutTransportAdapter(stdin or sys.stdin.buffer, stdout or sys.stdout.buffer)
104-
self.lsp.connection_made(transport)
105-
try:
106-
self.loop.run_until_complete(
107-
aio_readline(
108-
self.loop,
109-
self.thread_pool_executor,
110-
self._stop_event,
111-
stdin or sys.stdin.buffer,
112-
self.lsp.data_received,
113-
)
114-
)
115-
except BrokenPipeError:
116-
self.show_message_log("Connection to the client is lost! Shutting down the server.", "Error")
117-
except (KeyboardInterrupt, SystemExit):
118-
pass
119-
finally:
120-
self.cleanup_the_optimizer()
121-
self.shutdown()
99+
def shutdown(self) -> None:
100+
"""Gracefully shutdown the server."""
101+
self.cleanup_the_optimizer()
102+
super().shutdown()

codeflash/lsp/server_entry.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,13 @@
77
executed directly by users.
88
"""
99

10-
import logging
11-
import sys
12-
1310
from codeflash.lsp.beta import server
14-
15-
16-
# Configure logging to stderr for VS Code output channel
17-
def setup_logging() -> logging.Logger:
18-
# Clear any existing handlers to prevent conflicts
19-
root_logger = logging.getLogger()
20-
root_logger.handlers.clear()
21-
22-
# Set up stderr handler for VS Code output channel with [LSP-Server] prefix
23-
handler = logging.StreamHandler(sys.stderr)
24-
handler.setLevel(logging.DEBUG)
25-
26-
# Configure root logger
27-
root_logger.addHandler(handler)
28-
29-
# Also configure the pygls logger specifically
30-
pygls_logger = logging.getLogger("pygls")
31-
pygls_logger.setLevel(logging.INFO)
32-
33-
return root_logger
34-
11+
from codeflash.lsp.lsp_logger import setup_logging
3512

3613
if __name__ == "__main__":
3714
# Set up logging
38-
log = setup_logging()
39-
log.info("Starting Codeflash Language Server...")
15+
root_logger = setup_logging()
16+
root_logger.info("Starting Codeflash Language Server...")
4017

4118
# Start the language server
4219
server.start_io()

codeflash/optimization/function_optimizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ def _process_line_profiler_results(self) -> OptimizedCandidate | None:
164164

165165
def _process_refinement_results(self) -> OptimizedCandidate | None:
166166
"""Process refinement results and add to queue."""
167-
logger.info("loading|Refining generated code for improved quality and performance...")
167+
if self.future_all_refinements:
168+
logger.info("loading|Refining generated code for improved quality and performance...")
168169
concurrent.futures.wait(self.future_all_refinements)
169170
refinement_response = []
170171

0 commit comments

Comments
 (0)