Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions codeflash/cli_cmds/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,16 @@ def paneled_text(
console.print(panel)


def code_print(code_str: str, file_name: Optional[str] = None, function_name: Optional[str] = None) -> None:
def code_print(
code_str: str,
file_name: Optional[str] = None,
function_name: Optional[str] = None,
lsp_message_id: Optional[str] = None,
) -> None:
if is_LSP_enabled():
lsp_log(LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name))
lsp_log(
LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id)
)
return
"""Print code with syntax highlighting."""
from rich.syntax import Syntax
Expand Down
4 changes: 2 additions & 2 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def cleanup_optimizer(_params: any) -> dict[str, str]:

@server.feature("initializeFunctionOptimization")
def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]:
with execution_context(task_id=params.task_id):
with execution_context(task_id=getattr(params, "task_id", None)):
document_uri = params.textDocument.uri
document = server.workspace.get_text_document(document_uri)
file_path = Path(document.path)
Expand Down Expand Up @@ -423,7 +423,7 @@ def initialize_function_optimization(params: FunctionOptimizationInitParams) ->

@server.feature("performFunctionOptimization")
async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]:
with execution_context(task_id=params.task_id):
with execution_context(task_id=getattr(params, "task_id", None)):
loop = asyncio.get_running_loop()
cancel_event = threading.Event()

Expand Down
48 changes: 27 additions & 21 deletions codeflash/lsp/lsp_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import logging
import sys
from dataclasses import dataclass
from typing import Any, Callable
from typing import Any, Callable, Optional

from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.lsp.lsp_message import LspTextMessage, message_delimiter
from codeflash.lsp.lsp_message import LSPMessageId, LspTextMessage, message_delimiter

root_logger = None

message_id_prefix = "id:"


@dataclass
class LspMessageTags:
Expand All @@ -18,6 +20,7 @@ class LspMessageTags:
lsp: bool = False # lsp (lsp only)
force_lsp: bool = False # force_lsp (you can use this to force a message to be sent to the LSP even if the level is not supported)
loading: bool = False # loading (you can use this to indicate that the message is a loading message)
message_id: Optional[LSPMessageId] = None # example: id:best_candidate
highlight: bool = False # highlight (you can use this to highlight the message by wrapping it in ``)
h1: bool = False # h1
h2: bool = False # h2
Expand Down Expand Up @@ -52,24 +55,27 @@ def extract_tags(msg: str) -> tuple[LspMessageTags, str]:
tags = {tag.strip() for tag in tags_str.split(",")}
message_tags = LspMessageTags()
# manually check and set to avoid repeated membership tests
if "lsp" in tags:
message_tags.lsp = True
if "!lsp" in tags:
message_tags.not_lsp = True
if "force_lsp" in tags:
message_tags.force_lsp = True
if "loading" in tags:
message_tags.loading = True
if "highlight" in tags:
message_tags.highlight = True
if "h1" in tags:
message_tags.h1 = True
if "h2" in tags:
message_tags.h2 = True
if "h3" in tags:
message_tags.h3 = True
if "h4" in tags:
message_tags.h4 = True
for tag in tags:
if tag.startswith(message_id_prefix):
message_tags.message_id = LSPMessageId(tag[len(message_id_prefix) :]).value
elif tag == "lsp":
message_tags.lsp = True
elif tag == "!lsp":
message_tags.not_lsp = True
elif tag == "force_lsp":
message_tags.force_lsp = True
elif tag == "loading":
message_tags.loading = True
elif tag == "highlight":
message_tags.highlight = True
elif tag == "h1":
message_tags.h1 = True
elif tag == "h2":
message_tags.h2 = True
elif tag == "h3":
message_tags.h3 = True
elif tag == "h4":
message_tags.h4 = True
return message_tags, content

return LspMessageTags(), msg
Expand Down Expand Up @@ -118,7 +124,7 @@ def enhanced_log(
if is_normal_text_message:
clean_msg = add_heading_tags(clean_msg, tags)
clean_msg = add_highlight_tags(clean_msg, tags)
clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading).serialize()
clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading, message_id=tags.message_id).serialize()

actual_log_fn(clean_msg, *args, **kwargs)

Expand Down
8 changes: 8 additions & 0 deletions codeflash/lsp/lsp_message.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
import json
from dataclasses import asdict, dataclass
from pathlib import Path
Expand All @@ -14,10 +15,17 @@
message_delimiter = "\u241f"


# allow the client to know which message it is receiving
class LSPMessageId(enum.Enum):
BEST_CANDIDATE = "best_candidate"
CANDIDATE = "candidate"


@dataclass
class LspMessage:
# to show a loading indicator if the operation is taking time like generating candidates or tests
takes_time: bool = False
message_id: Optional[str] = None

def _loop_through(self, obj: Any) -> Any: # noqa: ANN401
if isinstance(obj, list):
Expand Down
9 changes: 7 additions & 2 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
from codeflash.either import Failure, Success, is_successful
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
BestOptimization,
Expand Down Expand Up @@ -510,7 +510,11 @@ def determine_best_candidate(
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
logger.info(f"h3|Optimization candidate {candidate_index}/{processor.candidate_len}:")
code_print(candidate.source_code.flat, file_name=f"candidate_{candidate_index}.py")
code_print(
candidate.source_code.flat,
file_name=f"candidate_{candidate_index}.py",
lsp_message_id=LSPMessageId.CANDIDATE.value,
)
# map ast normalized code to diff len, unnormalized code
# map opt id to the shortest unnormalized code
try:
Expand Down Expand Up @@ -1291,6 +1295,7 @@ def find_and_process_best_optimization(
best_optimization.candidate.source_code.flat,
file_name="best_candidate.py",
function_name=self.function_to_optimize.function_name,
lsp_message_id=LSPMessageId.BEST_CANDIDATE.value,
)
processed_benchmark_info = None
if self.args.benchmark:
Expand Down
Loading