diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py
index 67b4bacd2..4f0215f12 100644
--- a/codeflash/lsp/beta.py
+++ b/codeflash/lsp/beta.py
@@ -2,7 +2,9 @@
 import asyncio
 import contextlib
+import contextvars
 import os
+import threading
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional

@@ -27,8 +29,8 @@
     get_functions_within_git_diff,
 )
 from codeflash.either import is_successful
-from codeflash.lsp.features.perform_optimization import sync_perform_optimization
-from codeflash.lsp.server import CodeflashLanguageServer
+from codeflash.lsp.features.perform_optimization import get_cancelled_response, sync_perform_optimization
+from codeflash.lsp.server import CodeflashServerSingleton

 if TYPE_CHECKING:
     from argparse import Namespace
@@ -47,6 +49,7 @@ class OptimizableFunctionsParams:
 class FunctionOptimizationInitParams:
     textDocument: types.TextDocumentIdentifier  # noqa: N815
     functionName: str  # noqa: N815
+    task_id: str


 @dataclass
@@ -84,30 +87,24 @@ class WriteConfigParams:
     config: any


-server = CodeflashLanguageServer("codeflash-language-server", "v1.0")
+server = CodeflashServerSingleton.get()


 @server.feature("getOptimizableFunctionsInCurrentDiff")
-def get_functions_in_current_git_diff(
-    server: CodeflashLanguageServer, _params: OptimizableFunctionsParams
-) -> dict[str, str | dict[str, list[str]]]:
+def get_functions_in_current_git_diff(_params: OptimizableFunctionsParams) -> dict[str, str | dict[str, list[str]]]:
     functions = get_functions_within_git_diff(uncommitted_changes=True)
-    file_to_qualified_names = _group_functions_by_file(server, functions)
+    file_to_qualified_names = _group_functions_by_file(functions)
     return {"functions": file_to_qualified_names, "status": "success"}


 @server.feature("getOptimizableFunctionsInCommit")
-def get_functions_in_commit(
-    server: CodeflashLanguageServer, params: OptimizableFunctionsInCommitParams
-) -> dict[str, str | dict[str, list[str]]]:
+def get_functions_in_commit(params: OptimizableFunctionsInCommitParams) -> dict[str, str | dict[str, list[str]]]:
     functions = get_functions_inside_a_commit(params.commit_hash)
-    file_to_qualified_names = _group_functions_by_file(server, functions)
+    file_to_qualified_names = _group_functions_by_file(functions)
     return {"functions": file_to_qualified_names, "status": "success"}


-def _group_functions_by_file(
-    server: CodeflashLanguageServer, functions: dict[str, list[FunctionToOptimize]]
-) -> dict[str, list[str]]:
+def _group_functions_by_file(functions: dict[str, list[FunctionToOptimize]]) -> dict[str, list[str]]:
     file_to_funcs_to_optimize, _ = filter_functions(
         modified_functions=functions,
         tests_root=server.optimizer.test_cfg.tests_root,
@@ -123,9 +120,7 @@ def _group_functions_by_file(


 @server.feature("getOptimizableFunctions")
-def get_optimizable_functions(
-    server: CodeflashLanguageServer, params: OptimizableFunctionsParams
-) -> dict[str, list[str]]:
+def get_optimizable_functions(params: OptimizableFunctionsParams) -> dict[str, list[str]]:
     document_uri = params.textDocument.uri
     document = server.workspace.get_text_document(document_uri)

@@ -172,7 +167,7 @@ def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:


 @server.feature("writeConfig")
-def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) -> dict[str, any]:
+def write_config(params: WriteConfigParams) -> dict[str, any]:
     cfg = params.config
     cfg_file = Path(params.config_file) if params.config_file else None
@@ -196,7 +191,7 @@ def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) -> dict[str, any]:


 @server.feature("getConfigSuggestions")
-def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> dict[str, any]:
+def get_config_suggestions(_params: any) -> dict[str, any]:
     module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root)
     tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root)
     test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
@@ -212,7 +207,7 @@ def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> dict[str, any]:


 # should be called the first thing to initialize and validate the project
 @server.feature("initProject")
-def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
+def init_project(params: ValidateProjectParams) -> dict[str, str]:
     # Always process args in the init project, the extension can call
     server.args_processed_before = False

@@ -255,14 +250,12 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
             "existingConfig": config,
         }

-    args = process_args(server)
+    args = process_args()

     return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root}


-def _initialize_optimizer_if_api_key_is_valid(
-    server: CodeflashLanguageServer, api_key: Optional[str] = None
-) -> dict[str, str]:
+def _initialize_optimizer_if_api_key_is_valid(api_key: Optional[str] = None) -> dict[str, str]:
     user_id = get_user_id(api_key=api_key)
     if user_id is None:
         return {"status": "error", "message": "api key not found or invalid"}
@@ -273,12 +266,12 @@ def _initialize_optimizer_if_api_key_is_valid(

     from codeflash.optimization.optimizer import Optimizer

-    new_args = process_args(server)
+    new_args = process_args()
     server.optimizer = Optimizer(new_args)
     return {"status": "success", "user_id": user_id}


-def process_args(server: CodeflashLanguageServer) -> Namespace:
+def process_args() -> Namespace:
     if server.args_processed_before:
         return server.args
     new_args = process_pyproject_config(server.args)
@@ -288,15 +281,15 @@ def process_args(server: CodeflashLanguageServer) -> Namespace:


 @server.feature("apiKeyExistsAndValid")
-def check_api_key(server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
+def check_api_key(_params: any) -> dict[str, str]:
     try:
-        return _initialize_optimizer_if_api_key_is_valid(server)
+        return _initialize_optimizer_if_api_key_is_valid()
     except Exception:
         return {"status": "error", "message": "something went wrong while validating the api key"}


 @server.feature("provideApiKey")
-def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams) -> dict[str, str]:
+def provide_api_key(params: ProvideApiKeyParams) -> dict[str, str]:
     try:
         api_key = params.api_key
         if not api_key.startswith("cf-"):
@@ -306,7 +299,7 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams) -> dict[str, str]:
         get_codeflash_api_key.cache_clear()
         get_user_id.cache_clear()

-        init_result = _initialize_optimizer_if_api_key_is_valid(server, api_key)
+        init_result = _initialize_optimizer_if_api_key_is_valid(api_key)
         if init_result["status"] == "error":
             return {"status": "error", "message": "Api key is not valid"}
@@ -319,87 +312,101 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams) -> dict[str, str]:
         return {"status": "error", "message": "something went wrong while saving the api key"}


+@contextlib.contextmanager
+def execution_context(**kwargs: str) -> None:
+    """Temporarily set context values for the current async task."""
+    # Create a fresh copy per use
+    current = {**server.execution_context_vars.get(), **kwargs}
+    token = server.execution_context_vars.set(current)
+    try:
+        yield
+    finally:
+        server.execution_context_vars.reset(token)
+
+
 @server.feature("initializeFunctionOptimization")
-def initialize_function_optimization(
-    server: CodeflashLanguageServer, params: FunctionOptimizationInitParams
-) -> dict[str, str]:
-    document_uri = params.textDocument.uri
-    document = server.workspace.get_text_document(document_uri)
+def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]:
+    with execution_context(task_id=params.task_id):
+        document_uri = params.textDocument.uri
+        document = server.workspace.get_text_document(document_uri)

-    server.show_message_log(f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info")
+        server.show_message_log(
+            f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info"
+        )

-    if server.optimizer is None:
-        _initialize_optimizer_if_api_key_is_valid(server)
+        if server.optimizer is None:
+            _initialize_optimizer_if_api_key_is_valid()

-    server.optimizer.worktree_mode()
+        server.optimizer.worktree_mode()

-    original_args, _ = server.optimizer.original_args_and_test_cfg
+        original_args, _ = server.optimizer.original_args_and_test_cfg

-    server.optimizer.args.function = params.functionName
-    original_relative_file_path = Path(document.path).relative_to(original_args.project_root)
-    server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
-    server.optimizer.args.previous_checkpoint_functions = False
+        server.optimizer.args.function = params.functionName
+        original_relative_file_path = Path(document.path).relative_to(original_args.project_root)
+        server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
+        server.optimizer.args.previous_checkpoint_functions = False

-    server.show_message_log(
-        f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
-    )
+        server.show_message_log(
+            f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
+        )

-    optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
+        optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()

-    if count == 0:
-        server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
-        server.cleanup_the_optimizer()
-        return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
+        if count == 0:
+            server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
+            server.cleanup_the_optimizer()
+            return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}

-    fto = optimizable_funcs.popitem()[1][0]
+        fto = optimizable_funcs.popitem()[1][0]

-    module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
-    if not module_prep_result:
-        return {
-            "functionName": params.functionName,
-            "status": "error",
-            "message": "Failed to prepare module for optimization",
-        }
+        module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
+        if not module_prep_result:
+            return {
+                "functionName": params.functionName,
+                "status": "error",
+                "message": "Failed to prepare module for optimization",
+            }

-    validated_original_code, original_module_ast = module_prep_result
+        validated_original_code, original_module_ast = module_prep_result

-    function_optimizer = server.optimizer.create_function_optimizer(
-        fto,
-        function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
-        original_module_ast=original_module_ast,
-        original_module_path=fto.file_path,
-        function_to_tests={},
-    )
+        function_optimizer = server.optimizer.create_function_optimizer(
+            fto,
+            function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
+            original_module_ast=original_module_ast,
+            original_module_path=fto.file_path,
+            function_to_tests={},
+        )

-    server.optimizer.current_function_optimizer = function_optimizer
-    if not function_optimizer:
-        return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
+        server.optimizer.current_function_optimizer = function_optimizer
+        if not function_optimizer:
+            return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}

-    initialization_result = function_optimizer.can_be_optimized()
-    if not is_successful(initialization_result):
-        return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
+        initialization_result = function_optimizer.can_be_optimized()
+        if not is_successful(initialization_result):
+            return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}

-    server.current_optimization_init_result = initialization_result.unwrap()
-    server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
+        server.current_optimization_init_result = initialization_result.unwrap()
+        server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")

-    files = [function_optimizer.function_to_optimize.file_path]
+        files = [function_optimizer.function_to_optimize.file_path]

-    _, _, original_helpers = server.current_optimization_init_result
-    files.extend([str(helper_path) for helper_path in original_helpers])
+        _, _, original_helpers = server.current_optimization_init_result
+        files.extend([str(helper_path) for helper_path in original_helpers])

-    return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
+        return {"functionName": params.functionName, "status": "success", "files_inside_context": files}


 @server.feature("performFunctionOptimization")
-async def perform_function_optimization(
-    server: CodeflashLanguageServer, params: FunctionOptimizationParams
-) -> dict[str, str]:
-    loop = asyncio.get_running_loop()
-    try:
-        result = await loop.run_in_executor(None, sync_perform_optimization, server, params)
-    except asyncio.CancelledError:
-        return {"status": "canceled", "message": "Task was canceled"}
-    else:
-        return result
-    finally:
-        server.cleanup_the_optimizer()
+async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]:
+    with execution_context(task_id=params.task_id):
+        loop = asyncio.get_running_loop()
+        server.cancel_event = threading.Event()
+
+        try:
+            ctx = contextvars.copy_context()
+            return await loop.run_in_executor(None, ctx.run, sync_perform_optimization, params)
+        except asyncio.CancelledError:
+            server.cancel_event.set()
+            return get_cancelled_response()
+        finally:
+            server.cleanup_the_optimizer()
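Note on the `contextvars.copy_context()` dance in `perform_function_optimization` above: `run_in_executor` hands `sync_perform_optimization` to a worker thread, and worker threads do not inherit the event-loop task's context, so without `ctx.run` the thread would read the `ContextVar` default instead of the `task_id` set by `execution_context`. A minimal standalone sketch of the difference (`request_ctx` and `worker` are illustrative names, not part of this patch):

```python
import asyncio
import contextvars

# Illustrative stand-in for server.execution_context_vars.
request_ctx: contextvars.ContextVar[dict] = contextvars.ContextVar("request_ctx", default={})


def worker() -> str:
    # Reads the ContextVar from whatever context the thread runs this call in.
    return request_ctx.get().get("task_id", "<missing>")


async def main() -> None:
    token = request_ctx.set({"task_id": "abc-123"})
    try:
        loop = asyncio.get_running_loop()
        naive = await loop.run_in_executor(None, worker)  # thread sees the default
        ctx = contextvars.copy_context()  # snapshot of this task's context
        propagated = await loop.run_in_executor(None, ctx.run, worker)
        print(naive, propagated)  # "<missing> abc-123"
    finally:
        request_ctx.reset(token)


asyncio.run(main())
```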
diff --git a/codeflash/lsp/features/perform_optimization.py b/codeflash/lsp/features/perform_optimization.py
index 0b16a32d4..c7648c7d9 100644
--- a/codeflash/lsp/features/perform_optimization.py
+++ b/codeflash/lsp/features/perform_optimization.py
@@ -1,13 +1,30 @@
+from __future__ import annotations
+
 import contextlib
 import os
+from typing import TYPE_CHECKING

 from codeflash.cli_cmds.console import code_print
 from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree
 from codeflash.either import is_successful
-from codeflash.lsp.server import CodeflashLanguageServer
+from codeflash.lsp.server import CodeflashServerSingleton
+
+if TYPE_CHECKING:
+    import threading
+
+
+def get_cancelled_response() -> dict[str, str]:
+    return {"status": "canceled", "message": "Task was canceled"}
+
+
+def abort_if_cancelled(cancel_event: threading.Event) -> None:
+    if cancel_event.is_set():
+        raise RuntimeError("cancelled")


-def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[str, str]:  # noqa: ANN001
+def sync_perform_optimization(params) -> dict[str, str]:  # noqa: ANN001
+    server = CodeflashServerSingleton.get()
+    cancel_event = server.cancel_event
     server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
     should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result
     function_optimizer = server.optimizer.current_function_optimizer
@@ -18,6 +35,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[str, str]:
         file_name=current_function.file_path,
         function_name=current_function.function_name,
     )
+    abort_if_cancelled(cancel_event)

     optimizable_funcs = {current_function.file_path: [current_function]}

@@ -26,9 +44,11 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[str, str]:
     function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
     function_optimizer.function_to_tests = function_to_tests
+    abort_if_cancelled(cancel_event)

     test_setup_result = function_optimizer.generate_and_instrument_tests(
         code_context, should_run_experiment=should_run_experiment
     )
+    abort_if_cancelled(cancel_event)
     if not is_successful(test_setup_result):
         return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
     (
@@ -52,6 +72,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[str, str]:
         original_conftest_content=original_conftest_content,
     )

+    abort_if_cancelled(cancel_event)
     if not is_successful(baseline_setup_result):
         return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}

@@ -76,6 +97,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[str, str]:
         concolic_test_str=concolic_test_str,
     )

+    abort_if_cancelled(cancel_event)
     if not best_optimization:
         server.show_message_log(
             f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
         )
@@ -93,6 +115,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[str, str]:
         server.optimizer.current_worktree, relative_file_paths, function_to_optimize_qualified_name
     )

+    abort_if_cancelled(cancel_event)
     if not patch_path:
         return {
             "functionName": params.functionName,
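The `abort_if_cancelled` calls threaded between the pipeline stages above implement cooperative cancellation: Python offers no safe way to kill a running thread, so the handler sets a `threading.Event` on `asyncio.CancelledError` and the executor thread unwinds itself at its next checkpoint. A self-contained sketch of the pattern (`multi_stage_job` is a hypothetical stand-in for the optimization pipeline, not code from this patch):

```python
import threading


def abort_if_cancelled(cancel_event: threading.Event) -> None:
    # Raising unwinds the worker from any depth; returning status codes would
    # have to be threaded through every stage by hand.
    if cancel_event.is_set():
        raise RuntimeError("cancelled")


def multi_stage_job(cancel_event: threading.Event) -> str:
    for _stage in ("discover_tests", "generate_tests", "measure_baseline"):
        abort_if_cancelled(cancel_event)  # checkpoint between stages
        # ... uninterruptible work for the stage would run here ...
    return "done"


cancel = threading.Event()
print(multi_stage_job(cancel))  # no cancellation requested: "done"

cancel.set()  # request cancellation; honored at the next checkpoint
try:
    multi_stage_job(cancel)
except RuntimeError as exc:
    print(f"aborted: {exc}")
```

The trade-off is latency: a cancellation request only takes effect at the next checkpoint, so whichever stage is currently running still completes before the thread stops.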
diff --git a/codeflash/lsp/lsp_message.py b/codeflash/lsp/lsp_message.py
index 214927857..ec4b00a6c 100644
--- a/codeflash/lsp/lsp_message.py
+++ b/codeflash/lsp/lsp_message.py
@@ -6,6 +6,7 @@
 from typing import Any, Optional

 from codeflash.lsp.helpers import replace_quotes_with_backticks, simplify_worktree_paths
+from codeflash.lsp.server import CodeflashServerSingleton

 json_primitive_types = (str, float, int, bool)
 max_code_lines_before_collapse = 45
@@ -34,8 +35,10 @@ def type(self) -> str:
         raise NotImplementedError

     def serialize(self) -> str:
+        lsp_server_instance = CodeflashServerSingleton.get()
+        current_task_id = lsp_server_instance.execution_context_vars.get().get("task_id", None)
         data = self._loop_through(asdict(self))
-        ordered = {"type": self.type(), **data}
+        ordered = {"type": self.type(), "task_id": current_task_id, **data}
         return message_delimiter + json.dumps(ordered) + message_delimiter


diff --git a/codeflash/lsp/server.py b/codeflash/lsp/server.py
index ffd2a4acb..32df41ebb 100644
--- a/codeflash/lsp/server.py
+++ b/codeflash/lsp/server.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any
+import contextvars
+from typing import TYPE_CHECKING

 from lsprotocol.types import LogMessageParams, MessageType
 from pygls.lsp.server import LanguageServer
@@ -17,13 +18,33 @@ class CodeflashLanguageServerProtocol(LanguageServerProtocol):
     _server: CodeflashLanguageServer


+class CodeflashServerSingleton:
+    _instance: CodeflashLanguageServer | None = None
+
+    @classmethod
+    def get(cls) -> CodeflashLanguageServer:
+        if cls._instance is None:
+            cls._instance = CodeflashLanguageServer(
+                "codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol
+            )
+        return cls._instance
+
+    def __init__(self) -> None:
+        # This is a singleton; use get() instead of instantiating it directly.
+        ...
+
+
 class CodeflashLanguageServer(LanguageServer):
-    def __init__(self, *args: Any, **kwargs: Any) -> None:  # noqa: ANN401
-        super().__init__(*args, **kwargs)
+    def __init__(self, name: str, version: str, protocol_cls: type[LanguageServerProtocol]) -> None:
+        super().__init__(name, version, protocol_cls=protocol_cls)
         self.optimizer: Optimizer | None = None
         self.args_processed_before: bool = False
         self.args = None
         self.current_optimization_init_result: tuple[bool, CodeOptimizationContext, dict[Path, str]] | None = None
+        self.execution_context_vars: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar(
+            "execution_context_vars",
+            default={},  # noqa: B039
+        )

     def prepare_optimizer_arguments(self, config_file: Path) -> None:
         from codeflash.cli_cmds.cli import parse_args
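A final subtlety: `execution_context` builds a fresh dict (`{**get(), **kwargs}`) before calling `set`, and the `# noqa: B039` above acknowledges the deliberately mutable `default={}`. The copy matters because mutating the dict returned by `.get()` in place would corrupt the shared default and leak state into unrelated requests. A standalone sketch of the failure mode being avoided (`ctx_vars` is an illustrative stand-in, not the real attribute):

```python
import contextvars

ctx_vars: contextvars.ContextVar[dict] = contextvars.ContextVar("ctx_vars", default={})


def set_value_badly(key: str, value: str) -> None:
    # Mutates the shared default dict in place: every context that never
    # called set() now sees this key.
    ctx_vars.get()[key] = value


def set_value_safely(key: str, value: str) -> contextvars.Token:
    # Copy-on-write, as execution_context does: the default stays untouched
    # and reset() restores the previous value cleanly.
    return ctx_vars.set({**ctx_vars.get(), key: value})


fresh = contextvars.copy_context()
set_value_badly("task_id", "leaky")
print(fresh.run(ctx_vars.get))  # {'task_id': 'leaky'} - leaked into a fresh context

token = set_value_safely("task_id", "scoped")
print(ctx_vars.get())  # {'task_id': 'scoped'} - confined to this context
ctx_vars.reset(token)  # back to the previous (default) value
```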