Skip to content

Commit 4eaf0c4

Browse files
message id tags
1 parent de9837a commit 4eaf0c4

File tree

5 files changed

+53
-27
lines changed

5 files changed

+53
-27
lines changed

codeflash/cli_cmds/console.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,16 @@ def paneled_text(
8080
console.print(panel)
8181

8282

83-
def code_print(code_str: str, file_name: Optional[str] = None, function_name: Optional[str] = None) -> None:
83+
def code_print(
84+
code_str: str,
85+
file_name: Optional[str] = None,
86+
function_name: Optional[str] = None,
87+
lsp_message_id: Optional[str] = None,
88+
) -> None:
8489
if is_LSP_enabled():
85-
lsp_log(LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name))
90+
lsp_log(
91+
LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id)
92+
)
8693
return
8794
"""Print code with syntax highlighting."""
8895
from rich.syntax import Syntax

codeflash/lsp/beta.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def cleanup_optimizer(_params: any) -> dict[str, str]:
353353

354354
@server.feature("initializeFunctionOptimization")
355355
def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]:
356-
with execution_context(task_id=params.task_id):
356+
with execution_context(task_id=getattr(params, "task_id", None)):
357357
document_uri = params.textDocument.uri
358358
document = server.workspace.get_text_document(document_uri)
359359
file_path = Path(document.path)
@@ -423,7 +423,7 @@ def initialize_function_optimization(params: FunctionOptimizationInitParams) ->
423423

424424
@server.feature("performFunctionOptimization")
425425
async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]:
426-
with execution_context(task_id=params.task_id):
426+
with execution_context(task_id=getattr(params, "task_id", None)):
427427
loop = asyncio.get_running_loop()
428428
cancel_event = threading.Event()
429429

codeflash/lsp/lsp_logger.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import logging
44
import sys
55
from dataclasses import dataclass
6-
from typing import Any, Callable
6+
from typing import Any, Callable, Optional
77

88
from codeflash.lsp.helpers import is_LSP_enabled
9-
from codeflash.lsp.lsp_message import LspTextMessage, message_delimiter
9+
from codeflash.lsp.lsp_message import LSPMessageId, LspTextMessage, message_delimiter
1010

1111
root_logger = None
1212

13+
message_id_prefix = "id:"
14+
1315

1416
@dataclass
1517
class LspMessageTags:
@@ -18,6 +20,7 @@ class LspMessageTags:
1820
lsp: bool = False # lsp (lsp only)
1921
force_lsp: bool = False # force_lsp (you can use this to force a message to be sent to the LSP even if the level is not supported)
2022
loading: bool = False # loading (you can use this to indicate that the message is a loading message)
23+
message_id: Optional[LSPMessageId] = None # example: id:best_candidate
2124
highlight: bool = False # highlight (you can use this to highlight the message by wrapping it in ``)
2225
h1: bool = False # h1
2326
h2: bool = False # h2
@@ -52,24 +55,27 @@ def extract_tags(msg: str) -> tuple[LspMessageTags, str]:
5255
tags = {tag.strip() for tag in tags_str.split(",")}
5356
message_tags = LspMessageTags()
5457
# manually check and set to avoid repeated membership tests
55-
if "lsp" in tags:
56-
message_tags.lsp = True
57-
if "!lsp" in tags:
58-
message_tags.not_lsp = True
59-
if "force_lsp" in tags:
60-
message_tags.force_lsp = True
61-
if "loading" in tags:
62-
message_tags.loading = True
63-
if "highlight" in tags:
64-
message_tags.highlight = True
65-
if "h1" in tags:
66-
message_tags.h1 = True
67-
if "h2" in tags:
68-
message_tags.h2 = True
69-
if "h3" in tags:
70-
message_tags.h3 = True
71-
if "h4" in tags:
72-
message_tags.h4 = True
58+
for tag in tags:
59+
if tag.startswith(message_id_prefix):
60+
message_tags.message_id = LSPMessageId(tag[len(message_id_prefix) :]).value
61+
elif tag == "lsp":
62+
message_tags.lsp = True
63+
elif tag == "!lsp":
64+
message_tags.not_lsp = True
65+
elif tag == "force_lsp":
66+
message_tags.force_lsp = True
67+
elif tag == "loading":
68+
message_tags.loading = True
69+
elif tag == "highlight":
70+
message_tags.highlight = True
71+
elif tag == "h1":
72+
message_tags.h1 = True
73+
elif tag == "h2":
74+
message_tags.h2 = True
75+
elif tag == "h3":
76+
message_tags.h3 = True
77+
elif tag == "h4":
78+
message_tags.h4 = True
7379
return message_tags, content
7480

7581
return LspMessageTags(), msg
@@ -114,7 +120,7 @@ def enhanced_log(
114120
if is_normal_text_message:
115121
clean_msg = add_heading_tags(clean_msg, tags)
116122
clean_msg = add_highlight_tags(clean_msg, tags)
117-
clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading).serialize()
123+
clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading, message_id=tags.message_id).serialize()
118124

119125
actual_log_fn(clean_msg, *args, **kwargs)
120126

codeflash/lsp/lsp_message.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import enum
34
import json
45
from dataclasses import asdict, dataclass
56
from pathlib import Path
@@ -14,10 +15,17 @@
1415
message_delimiter = "\u241f"
1516

1617

18+
# allow the client to know which message it is receiving
19+
class LSPMessageId(enum.Enum):
20+
BEST_CANDIDATE = "best_candidate"
21+
CANDIDATE = "candidate"
22+
23+
1724
@dataclass
1825
class LspMessage:
1926
# to show a loading indicator if the operation is taking time like generating candidates or tests
2027
takes_time: bool = False
28+
message_id: Optional[str] = None
2129

2230
def _loop_through(self, obj: Any) -> Any: # noqa: ANN401
2331
if isinstance(obj, list):

codeflash/optimization/function_optimizer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
6767
from codeflash.either import Failure, Success, is_successful
6868
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
69-
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage
69+
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
7070
from codeflash.models.ExperimentMetadata import ExperimentMetadata
7171
from codeflash.models.models import (
7272
BestOptimization,
@@ -510,7 +510,11 @@ def determine_best_candidate(
510510
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
511511
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
512512
logger.info(f"h3|Optimization candidate {candidate_index}/{processor.candidate_len}:")
513-
code_print(candidate.source_code.flat, file_name=f"candidate_{candidate_index}.py")
513+
code_print(
514+
candidate.source_code.flat,
515+
file_name=f"candidate_{candidate_index}.py",
516+
lsp_message_id=LSPMessageId.CANDIDATE.value,
517+
)
514518
# map ast normalized code to diff len, unnormalized code
515519
# map opt id to the shortest unnormalized code
516520
try:
@@ -1291,6 +1295,7 @@ def find_and_process_best_optimization(
12911295
best_optimization.candidate.source_code.flat,
12921296
file_name="best_candidate.py",
12931297
function_name=self.function_to_optimize.function_name,
1298+
lsp_message_id=LSPMessageId.BEST_CANDIDATE.value,
12941299
)
12951300
processed_benchmark_info = None
12961301
if self.args.benchmark:

0 commit comments

Comments
 (0)