Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 99 additions & 92 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import asyncio
import contextlib
import contextvars
import os
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional
Expand All @@ -27,8 +29,8 @@
get_functions_within_git_diff,
)
from codeflash.either import is_successful
from codeflash.lsp.features.perform_optimization import sync_perform_optimization
from codeflash.lsp.server import CodeflashLanguageServer
from codeflash.lsp.features.perform_optimization import get_cancelled_reponse, sync_perform_optimization
from codeflash.lsp.server import CodeflashServerSingleton

if TYPE_CHECKING:
from argparse import Namespace
Expand All @@ -47,6 +49,7 @@ class OptimizableFunctionsParams:
class FunctionOptimizationInitParams:
textDocument: types.TextDocumentIdentifier # noqa: N815
functionName: str # noqa: N815
task_id: str


@dataclass
Expand Down Expand Up @@ -84,30 +87,24 @@ class WriteConfigParams:
config: any


server = CodeflashLanguageServer("codeflash-language-server", "v1.0")
server = CodeflashServerSingleton.get()


@server.feature("getOptimizableFunctionsInCurrentDiff")
def get_functions_in_current_git_diff(
server: CodeflashLanguageServer, _params: OptimizableFunctionsParams
) -> dict[str, str | dict[str, list[str]]]:
def get_functions_in_current_git_diff(_params: OptimizableFunctionsParams) -> dict[str, str | dict[str, list[str]]]:
functions = get_functions_within_git_diff(uncommitted_changes=True)
file_to_qualified_names = _group_functions_by_file(server, functions)
file_to_qualified_names = _group_functions_by_file(functions)
return {"functions": file_to_qualified_names, "status": "success"}


@server.feature("getOptimizableFunctionsInCommit")
def get_functions_in_commit(
server: CodeflashLanguageServer, params: OptimizableFunctionsInCommitParams
) -> dict[str, str | dict[str, list[str]]]:
def get_functions_in_commit(params: OptimizableFunctionsInCommitParams) -> dict[str, str | dict[str, list[str]]]:
functions = get_functions_inside_a_commit(params.commit_hash)
file_to_qualified_names = _group_functions_by_file(server, functions)
file_to_qualified_names = _group_functions_by_file(functions)
return {"functions": file_to_qualified_names, "status": "success"}


def _group_functions_by_file(
server: CodeflashLanguageServer, functions: dict[str, list[FunctionToOptimize]]
) -> dict[str, list[str]]:
def _group_functions_by_file(functions: dict[str, list[FunctionToOptimize]]) -> dict[str, list[str]]:
file_to_funcs_to_optimize, _ = filter_functions(
modified_functions=functions,
tests_root=server.optimizer.test_cfg.tests_root,
Expand All @@ -123,9 +120,7 @@ def _group_functions_by_file(


@server.feature("getOptimizableFunctions")
def get_optimizable_functions(
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
) -> dict[str, list[str]]:
def get_optimizable_functions(params: OptimizableFunctionsParams) -> dict[str, list[str]]:
document_uri = params.textDocument.uri
document = server.workspace.get_text_document(document_uri)

Expand Down Expand Up @@ -172,7 +167,7 @@ def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:


@server.feature("writeConfig")
def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) -> dict[str, any]:
def write_config(params: WriteConfigParams) -> dict[str, any]:
cfg = params.config
cfg_file = Path(params.config_file) if params.config_file else None

Expand All @@ -196,7 +191,7 @@ def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) ->


@server.feature("getConfigSuggestions")
def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> dict[str, any]:
def get_config_suggestions(_params: any) -> dict[str, any]:
module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root)
tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root)
test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
Expand All @@ -212,7 +207,7 @@ def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> di

# should be called the first thing to initialize and validate the project
@server.feature("initProject")
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
def init_project(params: ValidateProjectParams) -> dict[str, str]:
# Always process args in the init project, the extension can call
server.args_processed_before = False

Expand Down Expand Up @@ -255,14 +250,12 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
"existingConfig": config,
}

args = process_args(server)
args = process_args()

return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root}


def _initialize_optimizer_if_api_key_is_valid(
server: CodeflashLanguageServer, api_key: Optional[str] = None
) -> dict[str, str]:
def _initialize_optimizer_if_api_key_is_valid(api_key: Optional[str] = None) -> dict[str, str]:
user_id = get_user_id(api_key=api_key)
if user_id is None:
return {"status": "error", "message": "api key not found or invalid"}
Expand All @@ -273,12 +266,12 @@ def _initialize_optimizer_if_api_key_is_valid(

from codeflash.optimization.optimizer import Optimizer

new_args = process_args(server)
new_args = process_args()
server.optimizer = Optimizer(new_args)
return {"status": "success", "user_id": user_id}


def process_args(server: CodeflashLanguageServer) -> Namespace:
def process_args() -> Namespace:
if server.args_processed_before:
return server.args
new_args = process_pyproject_config(server.args)
Expand All @@ -288,15 +281,15 @@ def process_args(server: CodeflashLanguageServer) -> Namespace:


@server.feature("apiKeyExistsAndValid")
def check_api_key(server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
def check_api_key(_params: any) -> dict[str, str]:
try:
return _initialize_optimizer_if_api_key_is_valid(server)
return _initialize_optimizer_if_api_key_is_valid()
except Exception:
return {"status": "error", "message": "something went wrong while validating the api key"}


@server.feature("provideApiKey")
def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams) -> dict[str, str]:
def provide_api_key(params: ProvideApiKeyParams) -> dict[str, str]:
try:
api_key = params.api_key
if not api_key.startswith("cf-"):
Expand All @@ -306,7 +299,7 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
get_codeflash_api_key.cache_clear()
get_user_id.cache_clear()

init_result = _initialize_optimizer_if_api_key_is_valid(server, api_key)
init_result = _initialize_optimizer_if_api_key_is_valid(api_key)
if init_result["status"] == "error":
return {"status": "error", "message": "Api key is not valid"}

Expand All @@ -319,87 +312,101 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
return {"status": "error", "message": "something went wrong while saving the api key"}


@contextlib.contextmanager
def execution_context(**kwargs: str) -> None:
"""Temporarily set context values for the current async task."""
# Create a fresh copy per use
current = {**server.execution_context_vars.get(), **kwargs}
token = server.execution_context_vars.set(current)
try:
yield
finally:
server.execution_context_vars.reset(token)


@server.feature("initializeFunctionOptimization")
def initialize_function_optimization(
server: CodeflashLanguageServer, params: FunctionOptimizationInitParams
) -> dict[str, str]:
document_uri = params.textDocument.uri
document = server.workspace.get_text_document(document_uri)
def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]:
with execution_context(task_id=params.task_id):
document_uri = params.textDocument.uri
document = server.workspace.get_text_document(document_uri)

server.show_message_log(f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info")
server.show_message_log(
f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info"
)

if server.optimizer is None:
_initialize_optimizer_if_api_key_is_valid(server)
if server.optimizer is None:
_initialize_optimizer_if_api_key_is_valid()

server.optimizer.worktree_mode()
server.optimizer.worktree_mode()

original_args, _ = server.optimizer.original_args_and_test_cfg
original_args, _ = server.optimizer.original_args_and_test_cfg

server.optimizer.args.function = params.functionName
original_relative_file_path = Path(document.path).relative_to(original_args.project_root)
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
server.optimizer.args.previous_checkpoint_functions = False
server.optimizer.args.function = params.functionName
original_relative_file_path = Path(document.path).relative_to(original_args.project_root)
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
server.optimizer.args.previous_checkpoint_functions = False

server.show_message_log(
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
)
server.show_message_log(
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
)

optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()

if count == 0:
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
server.cleanup_the_optimizer()
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
if count == 0:
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
server.cleanup_the_optimizer()
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}

fto = optimizable_funcs.popitem()[1][0]
fto = optimizable_funcs.popitem()[1][0]

module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
if not module_prep_result:
return {
"functionName": params.functionName,
"status": "error",
"message": "Failed to prepare module for optimization",
}
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
if not module_prep_result:
return {
"functionName": params.functionName,
"status": "error",
"message": "Failed to prepare module for optimization",
}

validated_original_code, original_module_ast = module_prep_result
validated_original_code, original_module_ast = module_prep_result

function_optimizer = server.optimizer.create_function_optimizer(
fto,
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
original_module_ast=original_module_ast,
original_module_path=fto.file_path,
function_to_tests={},
)
function_optimizer = server.optimizer.create_function_optimizer(
fto,
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
original_module_ast=original_module_ast,
original_module_path=fto.file_path,
function_to_tests={},
)

server.optimizer.current_function_optimizer = function_optimizer
if not function_optimizer:
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
server.optimizer.current_function_optimizer = function_optimizer
if not function_optimizer:
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}

initialization_result = function_optimizer.can_be_optimized()
if not is_successful(initialization_result):
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
initialization_result = function_optimizer.can_be_optimized()
if not is_successful(initialization_result):
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}

server.current_optimization_init_result = initialization_result.unwrap()
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
server.current_optimization_init_result = initialization_result.unwrap()
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")

files = [function_optimizer.function_to_optimize.file_path]
files = [function_optimizer.function_to_optimize.file_path]

_, _, original_helpers = server.current_optimization_init_result
files.extend([str(helper_path) for helper_path in original_helpers])
_, _, original_helpers = server.current_optimization_init_result
files.extend([str(helper_path) for helper_path in original_helpers])

return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}


@server.feature("performFunctionOptimization")
async def perform_function_optimization(
server: CodeflashLanguageServer, params: FunctionOptimizationParams
) -> dict[str, str]:
loop = asyncio.get_running_loop()
try:
result = await loop.run_in_executor(None, sync_perform_optimization, server, params)
except asyncio.CancelledError:
return {"status": "canceled", "message": "Task was canceled"}
else:
return result
finally:
server.cleanup_the_optimizer()
async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]:
with execution_context(task_id=params.task_id):
loop = asyncio.get_running_loop()
server.cancel_event = threading.Event()

try:
ctx = contextvars.copy_context()
return await loop.run_in_executor(None, ctx.run, sync_perform_optimization, params)
except asyncio.CancelledError:
server.cancel_event.set()
return get_cancelled_reponse()
finally:
server.cleanup_the_optimizer()
Loading
Loading