Skip to content

Commit dc6336a

Browse files
task execution context and abort checkpoints
1 parent 690e7b7 commit dc6336a

File tree

4 files changed

+152
-98
lines changed

4 files changed

+152
-98
lines changed

codeflash/lsp/beta.py

Lines changed: 99 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import asyncio
44
import contextlib
5+
import contextvars
56
import os
7+
import threading
68
from dataclasses import dataclass
79
from pathlib import Path
810
from typing import TYPE_CHECKING, Optional
@@ -27,8 +29,8 @@
2729
get_functions_within_git_diff,
2830
)
2931
from codeflash.either import is_successful
30-
from codeflash.lsp.features.perform_optimization import sync_perform_optimization
31-
from codeflash.lsp.server import CodeflashLanguageServer
32+
from codeflash.lsp.features.perform_optimization import get_cancelled_reponse, sync_perform_optimization
33+
from codeflash.lsp.server import CodeflashServerSingleton
3234

3335
if TYPE_CHECKING:
3436
from argparse import Namespace
@@ -47,6 +49,7 @@ class OptimizableFunctionsParams:
4749
class FunctionOptimizationInitParams:
4850
textDocument: types.TextDocumentIdentifier # noqa: N815
4951
functionName: str # noqa: N815
52+
task_id: str
5053

5154

5255
@dataclass
@@ -84,30 +87,24 @@ class WriteConfigParams:
8487
config: any
8588

8689

87-
server = CodeflashLanguageServer("codeflash-language-server", "v1.0")
90+
server = CodeflashServerSingleton.get()
8891

8992

9093
@server.feature("getOptimizableFunctionsInCurrentDiff")
91-
def get_functions_in_current_git_diff(
92-
server: CodeflashLanguageServer, _params: OptimizableFunctionsParams
93-
) -> dict[str, str | dict[str, list[str]]]:
94+
def get_functions_in_current_git_diff(_params: OptimizableFunctionsParams) -> dict[str, str | dict[str, list[str]]]:
9495
functions = get_functions_within_git_diff(uncommitted_changes=True)
95-
file_to_qualified_names = _group_functions_by_file(server, functions)
96+
file_to_qualified_names = _group_functions_by_file(functions)
9697
return {"functions": file_to_qualified_names, "status": "success"}
9798

9899

99100
@server.feature("getOptimizableFunctionsInCommit")
100-
def get_functions_in_commit(
101-
server: CodeflashLanguageServer, params: OptimizableFunctionsInCommitParams
102-
) -> dict[str, str | dict[str, list[str]]]:
101+
def get_functions_in_commit(params: OptimizableFunctionsInCommitParams) -> dict[str, str | dict[str, list[str]]]:
103102
functions = get_functions_inside_a_commit(params.commit_hash)
104-
file_to_qualified_names = _group_functions_by_file(server, functions)
103+
file_to_qualified_names = _group_functions_by_file(functions)
105104
return {"functions": file_to_qualified_names, "status": "success"}
106105

107106

108-
def _group_functions_by_file(
109-
server: CodeflashLanguageServer, functions: dict[str, list[FunctionToOptimize]]
110-
) -> dict[str, list[str]]:
107+
def _group_functions_by_file(functions: dict[str, list[FunctionToOptimize]]) -> dict[str, list[str]]:
111108
file_to_funcs_to_optimize, _ = filter_functions(
112109
modified_functions=functions,
113110
tests_root=server.optimizer.test_cfg.tests_root,
@@ -123,9 +120,7 @@ def _group_functions_by_file(
123120

124121

125122
@server.feature("getOptimizableFunctions")
126-
def get_optimizable_functions(
127-
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
128-
) -> dict[str, list[str]]:
123+
def get_optimizable_functions(params: OptimizableFunctionsParams) -> dict[str, list[str]]:
129124
document_uri = params.textDocument.uri
130125
document = server.workspace.get_text_document(document_uri)
131126

@@ -172,7 +167,7 @@ def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:
172167

173168

174169
@server.feature("writeConfig")
175-
def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) -> dict[str, any]:
170+
def write_config(params: WriteConfigParams) -> dict[str, any]:
176171
cfg = params.config
177172
cfg_file = Path(params.config_file) if params.config_file else None
178173

@@ -196,7 +191,7 @@ def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) ->
196191

197192

198193
@server.feature("getConfigSuggestions")
199-
def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> dict[str, any]:
194+
def get_config_suggestions(_params: any) -> dict[str, any]:
200195
module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root)
201196
tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root)
202197
test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
@@ -212,7 +207,7 @@ def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> di
212207

213208
# should be called the first thing to initialize and validate the project
214209
@server.feature("initProject")
215-
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
210+
def init_project(params: ValidateProjectParams) -> dict[str, str]:
216211
# Always process args in the init project, the extension can call
217212
server.args_processed_before = False
218213

@@ -255,14 +250,12 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
255250
"existingConfig": config,
256251
}
257252

258-
args = process_args(server)
253+
args = process_args()
259254

260255
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root}
261256

262257

263-
def _initialize_optimizer_if_api_key_is_valid(
264-
server: CodeflashLanguageServer, api_key: Optional[str] = None
265-
) -> dict[str, str]:
258+
def _initialize_optimizer_if_api_key_is_valid(api_key: Optional[str] = None) -> dict[str, str]:
266259
user_id = get_user_id(api_key=api_key)
267260
if user_id is None:
268261
return {"status": "error", "message": "api key not found or invalid"}
@@ -273,12 +266,12 @@ def _initialize_optimizer_if_api_key_is_valid(
273266

274267
from codeflash.optimization.optimizer import Optimizer
275268

276-
new_args = process_args(server)
269+
new_args = process_args()
277270
server.optimizer = Optimizer(new_args)
278271
return {"status": "success", "user_id": user_id}
279272

280273

281-
def process_args(server: CodeflashLanguageServer) -> Namespace:
274+
def process_args() -> Namespace:
282275
if server.args_processed_before:
283276
return server.args
284277
new_args = process_pyproject_config(server.args)
@@ -288,15 +281,15 @@ def process_args(server: CodeflashLanguageServer) -> Namespace:
288281

289282

290283
@server.feature("apiKeyExistsAndValid")
291-
def check_api_key(server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
284+
def check_api_key(_params: any) -> dict[str, str]:
292285
try:
293-
return _initialize_optimizer_if_api_key_is_valid(server)
286+
return _initialize_optimizer_if_api_key_is_valid()
294287
except Exception:
295288
return {"status": "error", "message": "something went wrong while validating the api key"}
296289

297290

298291
@server.feature("provideApiKey")
299-
def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams) -> dict[str, str]:
292+
def provide_api_key(params: ProvideApiKeyParams) -> dict[str, str]:
300293
try:
301294
api_key = params.api_key
302295
if not api_key.startswith("cf-"):
@@ -306,7 +299,7 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
306299
get_codeflash_api_key.cache_clear()
307300
get_user_id.cache_clear()
308301

309-
init_result = _initialize_optimizer_if_api_key_is_valid(server, api_key)
302+
init_result = _initialize_optimizer_if_api_key_is_valid(api_key)
310303
if init_result["status"] == "error":
311304
return {"status": "error", "message": "Api key is not valid"}
312305

@@ -319,87 +312,101 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
319312
return {"status": "error", "message": "something went wrong while saving the api key"}
320313

321314

315+
@contextlib.contextmanager
316+
def execution_context(**kwargs: str) -> None:
317+
"""Temporarily set context values for the current async task."""
318+
# Create a fresh copy per use
319+
current = {**server.execution_context_vars.get(), **kwargs}
320+
token = server.execution_context_vars.set(current)
321+
try:
322+
yield
323+
finally:
324+
server.execution_context_vars.reset(token)
325+
326+
322327
@server.feature("initializeFunctionOptimization")
323-
def initialize_function_optimization(
324-
server: CodeflashLanguageServer, params: FunctionOptimizationInitParams
325-
) -> dict[str, str]:
326-
document_uri = params.textDocument.uri
327-
document = server.workspace.get_text_document(document_uri)
328+
def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]:
329+
with execution_context(task_id=params.task_id):
330+
document_uri = params.textDocument.uri
331+
document = server.workspace.get_text_document(document_uri)
328332

329-
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info")
333+
server.show_message_log(
334+
f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info"
335+
)
330336

331-
if server.optimizer is None:
332-
_initialize_optimizer_if_api_key_is_valid(server)
337+
if server.optimizer is None:
338+
_initialize_optimizer_if_api_key_is_valid()
333339

334-
server.optimizer.worktree_mode()
340+
server.optimizer.worktree_mode()
335341

336-
original_args, _ = server.optimizer.original_args_and_test_cfg
342+
original_args, _ = server.optimizer.original_args_and_test_cfg
337343

338-
server.optimizer.args.function = params.functionName
339-
original_relative_file_path = Path(document.path).relative_to(original_args.project_root)
340-
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
341-
server.optimizer.args.previous_checkpoint_functions = False
344+
server.optimizer.args.function = params.functionName
345+
original_relative_file_path = Path(document.path).relative_to(original_args.project_root)
346+
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
347+
server.optimizer.args.previous_checkpoint_functions = False
342348

343-
server.show_message_log(
344-
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
345-
)
349+
server.show_message_log(
350+
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
351+
)
346352

347-
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
353+
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
348354

349-
if count == 0:
350-
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
351-
server.cleanup_the_optimizer()
352-
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
355+
if count == 0:
356+
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
357+
server.cleanup_the_optimizer()
358+
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
353359

354-
fto = optimizable_funcs.popitem()[1][0]
360+
fto = optimizable_funcs.popitem()[1][0]
355361

356-
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
357-
if not module_prep_result:
358-
return {
359-
"functionName": params.functionName,
360-
"status": "error",
361-
"message": "Failed to prepare module for optimization",
362-
}
362+
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
363+
if not module_prep_result:
364+
return {
365+
"functionName": params.functionName,
366+
"status": "error",
367+
"message": "Failed to prepare module for optimization",
368+
}
363369

364-
validated_original_code, original_module_ast = module_prep_result
370+
validated_original_code, original_module_ast = module_prep_result
365371

366-
function_optimizer = server.optimizer.create_function_optimizer(
367-
fto,
368-
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
369-
original_module_ast=original_module_ast,
370-
original_module_path=fto.file_path,
371-
function_to_tests={},
372-
)
372+
function_optimizer = server.optimizer.create_function_optimizer(
373+
fto,
374+
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
375+
original_module_ast=original_module_ast,
376+
original_module_path=fto.file_path,
377+
function_to_tests={},
378+
)
373379

374-
server.optimizer.current_function_optimizer = function_optimizer
375-
if not function_optimizer:
376-
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
380+
server.optimizer.current_function_optimizer = function_optimizer
381+
if not function_optimizer:
382+
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
377383

378-
initialization_result = function_optimizer.can_be_optimized()
379-
if not is_successful(initialization_result):
380-
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
384+
initialization_result = function_optimizer.can_be_optimized()
385+
if not is_successful(initialization_result):
386+
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
381387

382-
server.current_optimization_init_result = initialization_result.unwrap()
383-
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
388+
server.current_optimization_init_result = initialization_result.unwrap()
389+
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
384390

385-
files = [function_optimizer.function_to_optimize.file_path]
391+
files = [function_optimizer.function_to_optimize.file_path]
386392

387-
_, _, original_helpers = server.current_optimization_init_result
388-
files.extend([str(helper_path) for helper_path in original_helpers])
393+
_, _, original_helpers = server.current_optimization_init_result
394+
files.extend([str(helper_path) for helper_path in original_helpers])
389395

390-
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
396+
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
391397

392398

393399
@server.feature("performFunctionOptimization")
394-
async def perform_function_optimization(
395-
server: CodeflashLanguageServer, params: FunctionOptimizationParams
396-
) -> dict[str, str]:
397-
loop = asyncio.get_running_loop()
398-
try:
399-
result = await loop.run_in_executor(None, sync_perform_optimization, server, params)
400-
except asyncio.CancelledError:
401-
return {"status": "canceled", "message": "Task was canceled"}
402-
else:
403-
return result
404-
finally:
405-
server.cleanup_the_optimizer()
400+
async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]:
401+
with execution_context(task_id=params.task_id):
402+
loop = asyncio.get_running_loop()
403+
server.cancel_event = threading.Event()
404+
405+
try:
406+
ctx = contextvars.copy_context()
407+
return await loop.run_in_executor(None, ctx.run, sync_perform_optimization, params)
408+
except asyncio.CancelledError:
409+
server.cancel_event.set()
410+
return get_cancelled_reponse()
411+
finally:
412+
server.cleanup_the_optimizer()

0 commit comments

Comments
 (0)