Skip to content

Commit bae3504

Browse files
lsp messages
1 parent 23c75fe commit bae3504

File tree

4 files changed

+79
-8
lines changed

4 files changed

+79
-8
lines changed

codeflash/lsp/beta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizat
194194
except Exception:
195195
return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}
196196

197-
return {"status": "success", "module_root": args.module_root}
197+
return {"status": "success", "moduleRoot": args.module_root}
198198

199199

200200
def _initialize_optimizer_if_api_key_is_valid(server: CodeflashLanguageServer) -> dict[str, str]:

codeflash/lsp/helpers.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from functools import lru_cache
33
from typing import Any, Callable
44

5+
skip_lsp_log_prefix = "!lsp:"
6+
57

68
@lru_cache(maxsize=1)
79
def is_LSP_enabled() -> bool:
@@ -10,18 +12,19 @@ def is_LSP_enabled() -> bool:
1012

1113
def enhanced_log(msg: str, actual_log_fn: Callable[[str, Any, Any], None], *args: Any, **kwargs: Any) -> None: # noqa: ANN401
1214
lsp_enabled = is_LSP_enabled()
15+
str_msg = isinstance(msg, str)
16+
skip_lsp_log = str_msg and msg.strip().startswith(skip_lsp_log_prefix)
17+
18+
if skip_lsp_log:
19+
msg = msg[len(skip_lsp_log_prefix) :]
1320

14-
# normal cli moded
21+
# normal cli mode
1522
if not lsp_enabled:
1623
actual_log_fn(msg, *args, **kwargs)
1724
return
1825

1926
#### LSP mode ####
20-
if type(msg) != str: # noqa: E721
21-
return
22-
23-
if msg.startswith("Nonzero return code"):
24-
# skip logging the failed tests msg to the client
27+
if skip_lsp_log or not str_msg:
2528
return
2629

2730
actual_log_fn(f"::::{msg}", *args, **kwargs)

codeflash/lsp/lsp_message.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from __future__ import annotations
2+
3+
import json
4+
from dataclasses import asdict, dataclass, field
5+
from pathlib import Path
6+
from typing import Any, Optional
7+
8+
json_primitive_types = (str, float, int, bool)
9+
10+
11+
@dataclass
12+
class LspMessage:
13+
takes_time: bool = field(
14+
default=False, kw_only=True
15+
) # to show a loading indicator if the operation is taking time like generating candidates or tests
16+
17+
def _loop_through(self, obj: Any) -> Any: # noqa: ANN401
18+
if isinstance(obj, list):
19+
return [self._loop_through(i) for i in obj]
20+
if isinstance(obj, dict):
21+
return {k: self._loop_through(v) for k, v in obj.items()}
22+
if isinstance(obj, json_primitive_types) or obj is None:
23+
return obj
24+
if isinstance(obj, Path):
25+
return obj.as_posix()
26+
return str(obj)
27+
28+
def type(self) -> str:
29+
raise NotImplementedError
30+
31+
def serialize(self) -> str:
32+
data = asdict(self)
33+
data["type"] = self.type()
34+
return json.dumps(data)
35+
36+
37+
@dataclass
38+
class LspTextMessage(LspMessage):
39+
text: str
40+
41+
def type(self) -> str:
42+
return "text"
43+
44+
45+
@dataclass
46+
class LspCodeMessage(LspMessage):
47+
code: str
48+
path: Optional[Path] = None
49+
function_name: Optional[str] = None
50+
51+
def type(self) -> str:
52+
return "code"
53+
54+
55+
@dataclass
56+
class LspMarkdownMessage(LspMessage):
57+
markdown: str
58+
59+
def type(self) -> str:
60+
return "markdown"
61+
62+
63+
@dataclass
64+
class LspStatsMessage(LspMessage):
65+
stats: dict[str, Any]
66+
67+
def type(self) -> str:
68+
return "stats"

codeflash/optimization/function_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1673,7 +1673,7 @@ def run_and_parse_tests(
16731673
return TestResults(), None
16741674
if run_result.returncode != 0 and testing_type == TestingMode.BEHAVIOR:
16751675
logger.debug(
1676-
f"Nonzero return code {run_result.returncode} when running tests in "
1676+
f"!lsp:Nonzero return code {run_result.returncode} when running tests in "
16771677
f"{', '.join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n"
16781678
f"stdout: {run_result.stdout}\n"
16791679
f"stderr: {run_result.stderr}\n"

0 commit comments

Comments
 (0)