Skip to content

Commit 0b425b9

Browse files
authored
Merge branch 'main' into generated-tests-markdown
2 parents eedd440 + fdd7ca9 commit 0b425b9

File tree

5 files changed

+171
-110
lines changed

5 files changed

+171
-110
lines changed

codeflash/lsp/beta.py

Lines changed: 121 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import asyncio
44
import contextlib
5+
import contextvars
56
import os
7+
import threading
68
from dataclasses import dataclass
79
from pathlib import Path
810
from typing import TYPE_CHECKING, Optional
@@ -27,8 +29,8 @@
2729
get_functions_within_git_diff,
2830
)
2931
from codeflash.either import is_successful
30-
from codeflash.lsp.features.perform_optimization import sync_perform_optimization
31-
from codeflash.lsp.server import CodeflashLanguageServer
32+
from codeflash.lsp.features.perform_optimization import get_cancelled_reponse, sync_perform_optimization
33+
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol
3234

3335
if TYPE_CHECKING:
3436
from argparse import Namespace
@@ -47,6 +49,7 @@ class OptimizableFunctionsParams:
4749
class FunctionOptimizationInitParams:
4850
textDocument: types.TextDocumentIdentifier # noqa: N815
4951
functionName: str # noqa: N815
52+
task_id: str
5053

5154

5255
@dataclass
@@ -84,30 +87,24 @@ class WriteConfigParams:
8487
config: any
8588

8689

87-
server = CodeflashLanguageServer("codeflash-language-server", "v1.0")
90+
server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
8891

8992

9093
@server.feature("getOptimizableFunctionsInCurrentDiff")
91-
def get_functions_in_current_git_diff(
92-
server: CodeflashLanguageServer, _params: OptimizableFunctionsParams
93-
) -> dict[str, str | dict[str, list[str]]]:
94+
def get_functions_in_current_git_diff(_params: OptimizableFunctionsParams) -> dict[str, str | dict[str, list[str]]]:
9495
functions = get_functions_within_git_diff(uncommitted_changes=True)
95-
file_to_qualified_names = _group_functions_by_file(server, functions)
96+
file_to_qualified_names = _group_functions_by_file(functions)
9697
return {"functions": file_to_qualified_names, "status": "success"}
9798

9899

99100
@server.feature("getOptimizableFunctionsInCommit")
100-
def get_functions_in_commit(
101-
server: CodeflashLanguageServer, params: OptimizableFunctionsInCommitParams
102-
) -> dict[str, str | dict[str, list[str]]]:
101+
def get_functions_in_commit(params: OptimizableFunctionsInCommitParams) -> dict[str, str | dict[str, list[str]]]:
103102
functions = get_functions_inside_a_commit(params.commit_hash)
104-
file_to_qualified_names = _group_functions_by_file(server, functions)
103+
file_to_qualified_names = _group_functions_by_file(functions)
105104
return {"functions": file_to_qualified_names, "status": "success"}
106105

107106

108-
def _group_functions_by_file(
109-
server: CodeflashLanguageServer, functions: dict[str, list[FunctionToOptimize]]
110-
) -> dict[str, list[str]]:
107+
def _group_functions_by_file(functions: dict[str, list[FunctionToOptimize]]) -> dict[str, list[str]]:
111108
file_to_funcs_to_optimize, _ = filter_functions(
112109
modified_functions=functions,
113110
tests_root=server.optimizer.test_cfg.tests_root,
@@ -123,9 +120,7 @@ def _group_functions_by_file(
123120

124121

125122
@server.feature("getOptimizableFunctions")
126-
def get_optimizable_functions(
127-
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
128-
) -> dict[str, list[str]]:
123+
def get_optimizable_functions(params: OptimizableFunctionsParams) -> dict[str, list[str]]:
129124
document_uri = params.textDocument.uri
130125
document = server.workspace.get_text_document(document_uri)
131126

@@ -172,7 +167,7 @@ def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:
172167

173168

174169
@server.feature("writeConfig")
175-
def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) -> dict[str, any]:
170+
def write_config(params: WriteConfigParams) -> dict[str, any]:
176171
cfg = params.config
177172
cfg_file = Path(params.config_file) if params.config_file else None
178173

@@ -196,7 +191,7 @@ def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) ->
196191

197192

198193
@server.feature("getConfigSuggestions")
199-
def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> dict[str, any]:
194+
def get_config_suggestions(_params: any) -> dict[str, any]:
200195
module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root)
201196
tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root)
202197
test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
@@ -212,9 +207,9 @@ def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> di
212207

213208
# should be called the first thing to initialize and validate the project
214209
@server.feature("initProject")
215-
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
210+
def init_project(params: ValidateProjectParams) -> dict[str, str]:
216211
# Always process args in the init project, the extension can call
217-
server.args_processed_before = False
212+
server.initialized = False
218213

219214
pyproject_toml_path: Path | None = getattr(params, "config_file", None) or getattr(server.args, "config_file", None)
220215
if pyproject_toml_path is not None:
@@ -255,19 +250,16 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
255250
"existingConfig": config,
256251
}
257252

258-
args = process_args(server)
259-
253+
args = _init()
260254
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root}
261255

262256

263-
def _initialize_optimizer_if_api_key_is_valid(
264-
server: CodeflashLanguageServer, api_key: Optional[str] = None
265-
) -> dict[str, str]:
257+
def _initialize_optimizer_if_api_key_is_valid(api_key: Optional[str] = None) -> dict[str, str]:
266258
key_check_result = _check_api_key_validity(api_key)
267259
if key_check_result.get("status") != "success":
268260
return key_check_result
269261

270-
_initialize_optimizer(server)
262+
_init()
271263
return key_check_result
272264

273265

@@ -283,134 +275,163 @@ def _check_api_key_validity(api_key: Optional[str]) -> dict[str, str]:
283275
return {"status": "success", "user_id": user_id}
284276

285277

286-
def _initialize_optimizer(server: CodeflashLanguageServer) -> None:
278+
def _initialize_optimizer(args: Namespace) -> None:
287279
from codeflash.optimization.optimizer import Optimizer
288280

289-
new_args = process_args(server)
290281
if not server.optimizer:
291-
server.optimizer = Optimizer(new_args)
282+
server.optimizer = Optimizer(args)
292283

293284

294-
def process_args(server: CodeflashLanguageServer) -> Namespace:
295-
if server.args_processed_before:
296-
return server.args
285+
def process_args() -> Namespace:
297286
new_args = process_pyproject_config(server.args)
298287
server.args = new_args
299-
server.args_processed_before = True
288+
return new_args
289+
290+
291+
def _init() -> Namespace:
292+
if server.initialized:
293+
return server.args
294+
new_args = process_args()
295+
_initialize_optimizer(new_args)
296+
server.initialized = True
300297
return new_args
301298

302299

303300
@server.feature("apiKeyExistsAndValid")
304-
def check_api_key(server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
301+
def check_api_key(_params: any) -> dict[str, str]:
305302
try:
306-
return _initialize_optimizer_if_api_key_is_valid(server)
303+
return _initialize_optimizer_if_api_key_is_valid()
307304
except Exception:
308305
return {"status": "error", "message": "something went wrong while validating the api key"}
309306

310307

311308
@server.feature("provideApiKey")
312-
def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams) -> dict[str, str]:
309+
def provide_api_key(params: ProvideApiKeyParams) -> dict[str, str]:
313310
try:
314311
api_key = params.api_key
315312
if not api_key.startswith("cf-"):
316313
return {"status": "error", "message": "Api key is not valid"}
317314

318-
# # clear cache to ensure the new api key is used
315+
# clear cache to ensure the new api key is used
319316
get_codeflash_api_key.cache_clear()
320317
get_user_id.cache_clear()
318+
321319
key_check_result = _check_api_key_validity(api_key)
322320
if key_check_result.get("status") != "success":
323321
return key_check_result
322+
324323
user_id = key_check_result["user_id"]
325324
result = save_api_key_to_rc(api_key)
325+
326326
# initialize optimizer with the new api key
327-
_initialize_optimizer(server)
327+
_init()
328328
if not is_successful(result):
329329
return {"status": "error", "message": result.failure()}
330330
return {"status": "success", "message": "Api key saved successfully", "user_id": user_id} # noqa: TRY300
331331
except Exception:
332332
return {"status": "error", "message": "something went wrong while saving the api key"}
333333

334334

335+
@contextlib.contextmanager
336+
def execution_context(**kwargs: str) -> None:
337+
"""Temporarily set context values for the current async task."""
338+
# Create a fresh copy per use
339+
current = {**server.execution_context_vars.get(), **kwargs}
340+
token = server.execution_context_vars.set(current)
341+
try:
342+
yield
343+
finally:
344+
server.execution_context_vars.reset(token)
345+
346+
347+
@server.feature("cleanupCurrentOptimizerSession")
348+
def cleanup_optimizer(_params: any) -> dict[str, str]:
349+
if not server.cleanup_the_optimizer():
350+
return {"status": "error", "message": "Failed to cleanup optimizer"}
351+
return {"status": "success"}
352+
353+
335354
@server.feature("initializeFunctionOptimization")
336-
def initialize_function_optimization(
337-
server: CodeflashLanguageServer, params: FunctionOptimizationInitParams
338-
) -> dict[str, str]:
339-
document_uri = params.textDocument.uri
340-
document = server.workspace.get_text_document(document_uri)
341-
file_path = Path(document.path)
355+
def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]:
356+
with execution_context(task_id=params.task_id):
357+
document_uri = params.textDocument.uri
358+
document = server.workspace.get_text_document(document_uri)
359+
file_path = Path(document.path)
342360

343-
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info")
361+
server.show_message_log(
362+
f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info"
363+
)
344364

345-
if server.optimizer is None:
346-
_initialize_optimizer_if_api_key_is_valid(server)
365+
if server.optimizer is None:
366+
_initialize_optimizer_if_api_key_is_valid()
347367

348-
server.optimizer.args.file = file_path
349-
server.optimizer.args.function = params.functionName
350-
server.optimizer.args.previous_checkpoint_functions = False
368+
server.optimizer.args.file = file_path
369+
server.optimizer.args.function = params.functionName
370+
server.optimizer.args.previous_checkpoint_functions = False
351371

352-
server.optimizer.worktree_mode()
372+
server.optimizer.worktree_mode()
353373

354-
server.show_message_log(
355-
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
356-
)
374+
server.show_message_log(
375+
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
376+
)
357377

358-
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
378+
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
359379

360-
if count == 0:
361-
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
362-
server.cleanup_the_optimizer()
363-
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
380+
if count == 0:
381+
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
382+
server.cleanup_the_optimizer()
383+
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
364384

365-
fto = optimizable_funcs.popitem()[1][0]
385+
fto = optimizable_funcs.popitem()[1][0]
366386

367-
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
368-
if not module_prep_result:
369-
return {
370-
"functionName": params.functionName,
371-
"status": "error",
372-
"message": "Failed to prepare module for optimization",
373-
}
387+
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
388+
if not module_prep_result:
389+
return {
390+
"functionName": params.functionName,
391+
"status": "error",
392+
"message": "Failed to prepare module for optimization",
393+
}
374394

375-
validated_original_code, original_module_ast = module_prep_result
395+
validated_original_code, original_module_ast = module_prep_result
376396

377-
function_optimizer = server.optimizer.create_function_optimizer(
378-
fto,
379-
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
380-
original_module_ast=original_module_ast,
381-
original_module_path=fto.file_path,
382-
function_to_tests={},
383-
)
397+
function_optimizer = server.optimizer.create_function_optimizer(
398+
fto,
399+
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
400+
original_module_ast=original_module_ast,
401+
original_module_path=fto.file_path,
402+
function_to_tests={},
403+
)
384404

385-
server.optimizer.current_function_optimizer = function_optimizer
386-
if not function_optimizer:
387-
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
405+
server.optimizer.current_function_optimizer = function_optimizer
406+
if not function_optimizer:
407+
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
388408

389-
initialization_result = function_optimizer.can_be_optimized()
390-
if not is_successful(initialization_result):
391-
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
409+
initialization_result = function_optimizer.can_be_optimized()
410+
if not is_successful(initialization_result):
411+
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
392412

393-
server.current_optimization_init_result = initialization_result.unwrap()
394-
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
413+
server.current_optimization_init_result = initialization_result.unwrap()
414+
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
395415

396-
files = [function_optimizer.function_to_optimize.file_path]
416+
files = [function_optimizer.function_to_optimize.file_path]
397417

398-
_, _, original_helpers = server.current_optimization_init_result
399-
files.extend([str(helper_path) for helper_path in original_helpers])
418+
_, _, original_helpers = server.current_optimization_init_result
419+
files.extend([str(helper_path) for helper_path in original_helpers])
400420

401-
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
421+
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
402422

403423

404424
@server.feature("performFunctionOptimization")
405-
async def perform_function_optimization(
406-
server: CodeflashLanguageServer, params: FunctionOptimizationParams
407-
) -> dict[str, str]:
408-
loop = asyncio.get_running_loop()
409-
try:
410-
result = await loop.run_in_executor(None, sync_perform_optimization, server, params)
411-
except asyncio.CancelledError:
412-
return {"status": "canceled", "message": "Task was canceled"}
413-
else:
414-
return result
415-
finally:
416-
server.cleanup_the_optimizer()
425+
async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]:
426+
with execution_context(task_id=params.task_id):
427+
loop = asyncio.get_running_loop()
428+
cancel_event = threading.Event()
429+
430+
try:
431+
ctx = contextvars.copy_context()
432+
return await loop.run_in_executor(None, ctx.run, sync_perform_optimization, server, cancel_event, params)
433+
except asyncio.CancelledError:
434+
cancel_event.set()
435+
return get_cancelled_reponse()
436+
finally:
437+
server.cleanup_the_optimizer()

0 commit comments

Comments
 (0)