22
33import asyncio
44import contextlib
5+ import contextvars
56import os
7+ import threading
68from dataclasses import dataclass
79from pathlib import Path
810from typing import TYPE_CHECKING , Optional
2729 get_functions_within_git_diff ,
2830)
2931from codeflash .either import is_successful
30- from codeflash .lsp .features .perform_optimization import sync_perform_optimization
31- from codeflash .lsp .server import CodeflashLanguageServer
32+ from codeflash .lsp .features .perform_optimization import get_cancelled_reponse , sync_perform_optimization
33+ from codeflash .lsp .server import CodeflashLanguageServer , CodeflashLanguageServerProtocol
3234
3335if TYPE_CHECKING :
3436 from argparse import Namespace
@@ -47,6 +49,7 @@ class OptimizableFunctionsParams:
4749class FunctionOptimizationInitParams :
4850 textDocument : types .TextDocumentIdentifier # noqa: N815
4951 functionName : str # noqa: N815
52+ task_id : str
5053
5154
5255@dataclass
@@ -84,30 +87,24 @@ class WriteConfigParams:
8487 config : any
8588
8689
87- server = CodeflashLanguageServer ("codeflash-language-server" , "v1.0" )
90+ server = CodeflashLanguageServer ("codeflash-language-server" , "v1.0" , protocol_cls = CodeflashLanguageServerProtocol )
8891
8992
9093@server .feature ("getOptimizableFunctionsInCurrentDiff" )
91- def get_functions_in_current_git_diff (
92- server : CodeflashLanguageServer , _params : OptimizableFunctionsParams
93- ) -> dict [str , str | dict [str , list [str ]]]:
94+ def get_functions_in_current_git_diff (_params : OptimizableFunctionsParams ) -> dict [str , str | dict [str , list [str ]]]:
9495 functions = get_functions_within_git_diff (uncommitted_changes = True )
95- file_to_qualified_names = _group_functions_by_file (server , functions )
96+ file_to_qualified_names = _group_functions_by_file (functions )
9697 return {"functions" : file_to_qualified_names , "status" : "success" }
9798
9899
99100@server .feature ("getOptimizableFunctionsInCommit" )
100- def get_functions_in_commit (
101- server : CodeflashLanguageServer , params : OptimizableFunctionsInCommitParams
102- ) -> dict [str , str | dict [str , list [str ]]]:
101+ def get_functions_in_commit (params : OptimizableFunctionsInCommitParams ) -> dict [str , str | dict [str , list [str ]]]:
103102 functions = get_functions_inside_a_commit (params .commit_hash )
104- file_to_qualified_names = _group_functions_by_file (server , functions )
103+ file_to_qualified_names = _group_functions_by_file (functions )
105104 return {"functions" : file_to_qualified_names , "status" : "success" }
106105
107106
108- def _group_functions_by_file (
109- server : CodeflashLanguageServer , functions : dict [str , list [FunctionToOptimize ]]
110- ) -> dict [str , list [str ]]:
107+ def _group_functions_by_file (functions : dict [str , list [FunctionToOptimize ]]) -> dict [str , list [str ]]:
111108 file_to_funcs_to_optimize , _ = filter_functions (
112109 modified_functions = functions ,
113110 tests_root = server .optimizer .test_cfg .tests_root ,
@@ -123,9 +120,7 @@ def _group_functions_by_file(
123120
124121
125122@server .feature ("getOptimizableFunctions" )
126- def get_optimizable_functions (
127- server : CodeflashLanguageServer , params : OptimizableFunctionsParams
128- ) -> dict [str , list [str ]]:
123+ def get_optimizable_functions (params : OptimizableFunctionsParams ) -> dict [str , list [str ]]:
129124 document_uri = params .textDocument .uri
130125 document = server .workspace .get_text_document (document_uri )
131126
@@ -172,7 +167,7 @@ def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:
172167
173168
174169@server .feature ("writeConfig" )
175- def write_config (_server : CodeflashLanguageServer , params : WriteConfigParams ) -> dict [str , any ]:
170+ def write_config (params : WriteConfigParams ) -> dict [str , any ]:
176171 cfg = params .config
177172 cfg_file = Path (params .config_file ) if params .config_file else None
178173
@@ -196,7 +191,7 @@ def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) ->
196191
197192
198193@server .feature ("getConfigSuggestions" )
199- def get_config_suggestions (_server : CodeflashLanguageServer , _params : any ) -> dict [str , any ]:
194+ def get_config_suggestions (_params : any ) -> dict [str , any ]:
200195 module_root_suggestions , default_module_root = get_suggestions (CommonSections .module_root )
201196 tests_root_suggestions , default_tests_root = get_suggestions (CommonSections .tests_root )
202197 test_framework_suggestions , default_test_framework = get_suggestions (CommonSections .test_framework )
@@ -212,9 +207,9 @@ def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> di
212207
213208# should be called the first thing to initialize and validate the project
214209@server .feature ("initProject" )
215- def init_project (server : CodeflashLanguageServer , params : ValidateProjectParams ) -> dict [str , str ]:
210+ def init_project (params : ValidateProjectParams ) -> dict [str , str ]:
216211 # Always process args in the init project, the extension can call
217- server .args_processed_before = False
212+ server .initialized = False
218213
219214 pyproject_toml_path : Path | None = getattr (params , "config_file" , None ) or getattr (server .args , "config_file" , None )
220215 if pyproject_toml_path is not None :
@@ -255,19 +250,16 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
255250 "existingConfig" : config ,
256251 }
257252
258- args = process_args (server )
259-
253+ args = _init ()
260254 return {"status" : "success" , "moduleRoot" : args .module_root , "pyprojectPath" : pyproject_toml_path , "root" : root }
261255
262256
263- def _initialize_optimizer_if_api_key_is_valid (
264- server : CodeflashLanguageServer , api_key : Optional [str ] = None
265- ) -> dict [str , str ]:
257+ def _initialize_optimizer_if_api_key_is_valid (api_key : Optional [str ] = None ) -> dict [str , str ]:
266258 key_check_result = _check_api_key_validity (api_key )
267259 if key_check_result .get ("status" ) != "success" :
268260 return key_check_result
269261
270- _initialize_optimizer ( server )
262+ _init ( )
271263 return key_check_result
272264
273265
@@ -283,134 +275,163 @@ def _check_api_key_validity(api_key: Optional[str]) -> dict[str, str]:
283275 return {"status" : "success" , "user_id" : user_id }
284276
285277
286- def _initialize_optimizer (server : CodeflashLanguageServer ) -> None :
278+ def _initialize_optimizer (args : Namespace ) -> None :
287279 from codeflash .optimization .optimizer import Optimizer
288280
289- new_args = process_args (server )
290281 if not server .optimizer :
291- server .optimizer = Optimizer (new_args )
282+ server .optimizer = Optimizer (args )
292283
293284
294- def process_args (server : CodeflashLanguageServer ) -> Namespace :
295- if server .args_processed_before :
296- return server .args
285+ def process_args () -> Namespace :
297286 new_args = process_pyproject_config (server .args )
298287 server .args = new_args
299- server .args_processed_before = True
288+ return new_args
289+
290+
291+ def _init () -> Namespace :
292+ if server .initialized :
293+ return server .args
294+ new_args = process_args ()
295+ _initialize_optimizer (new_args )
296+ server .initialized = True
300297 return new_args
301298
302299
303300@server .feature ("apiKeyExistsAndValid" )
304- def check_api_key (server : CodeflashLanguageServer , _params : any ) -> dict [str , str ]:
301+ def check_api_key (_params : any ) -> dict [str , str ]:
305302 try :
306- return _initialize_optimizer_if_api_key_is_valid (server )
303+ return _initialize_optimizer_if_api_key_is_valid ()
307304 except Exception :
308305 return {"status" : "error" , "message" : "something went wrong while validating the api key" }
309306
310307
311308@server .feature ("provideApiKey" )
312- def provide_api_key (server : CodeflashLanguageServer , params : ProvideApiKeyParams ) -> dict [str , str ]:
309+ def provide_api_key (params : ProvideApiKeyParams ) -> dict [str , str ]:
313310 try :
314311 api_key = params .api_key
315312 if not api_key .startswith ("cf-" ):
316313 return {"status" : "error" , "message" : "Api key is not valid" }
317314
318- # # clear cache to ensure the new api key is used
315+ # clear cache to ensure the new api key is used
319316 get_codeflash_api_key .cache_clear ()
320317 get_user_id .cache_clear ()
318+
321319 key_check_result = _check_api_key_validity (api_key )
322320 if key_check_result .get ("status" ) != "success" :
323321 return key_check_result
322+
324323 user_id = key_check_result ["user_id" ]
325324 result = save_api_key_to_rc (api_key )
325+
326326 # initialize optimizer with the new api key
327- _initialize_optimizer ( server )
327+ _init ( )
328328 if not is_successful (result ):
329329 return {"status" : "error" , "message" : result .failure ()}
330330 return {"status" : "success" , "message" : "Api key saved successfully" , "user_id" : user_id } # noqa: TRY300
331331 except Exception :
332332 return {"status" : "error" , "message" : "something went wrong while saving the api key" }
333333
334334
335+ @contextlib .contextmanager
336+ def execution_context (** kwargs : str ) -> None :
337+ """Temporarily set context values for the current async task."""
338+ # Create a fresh copy per use
339+ current = {** server .execution_context_vars .get (), ** kwargs }
340+ token = server .execution_context_vars .set (current )
341+ try :
342+ yield
343+ finally :
344+ server .execution_context_vars .reset (token )
345+
346+
347+ @server .feature ("cleanupCurrentOptimizerSession" )
348+ def cleanup_optimizer (_params : any ) -> dict [str , str ]:
349+ if not server .cleanup_the_optimizer ():
350+ return {"status" : "error" , "message" : "Failed to cleanup optimizer" }
351+ return {"status" : "success" }
352+
353+
335354@server .feature ("initializeFunctionOptimization" )
336- def initialize_function_optimization (
337- server : CodeflashLanguageServer , params : FunctionOptimizationInitParams
338- ) -> dict [str , str ]:
339- document_uri = params .textDocument .uri
340- document = server .workspace .get_text_document (document_uri )
341- file_path = Path (document .path )
355+ def initialize_function_optimization (params : FunctionOptimizationInitParams ) -> dict [str , str ]:
356+ with execution_context (task_id = params .task_id ):
357+ document_uri = params .textDocument .uri
358+ document = server .workspace .get_text_document (document_uri )
359+ file_path = Path (document .path )
342360
343- server .show_message_log (f"Initializing optimization for function: { params .functionName } in { document_uri } " , "Info" )
361+ server .show_message_log (
362+ f"Initializing optimization for function: { params .functionName } in { document_uri } " , "Info"
363+ )
344364
345- if server .optimizer is None :
346- _initialize_optimizer_if_api_key_is_valid (server )
365+ if server .optimizer is None :
366+ _initialize_optimizer_if_api_key_is_valid ()
347367
348- server .optimizer .args .file = file_path
349- server .optimizer .args .function = params .functionName
350- server .optimizer .args .previous_checkpoint_functions = False
368+ server .optimizer .args .file = file_path
369+ server .optimizer .args .function = params .functionName
370+ server .optimizer .args .previous_checkpoint_functions = False
351371
352- server .optimizer .worktree_mode ()
372+ server .optimizer .worktree_mode ()
353373
354- server .show_message_log (
355- f"Args set - function: { server .optimizer .args .function } , file: { server .optimizer .args .file } " , "Info"
356- )
374+ server .show_message_log (
375+ f"Args set - function: { server .optimizer .args .function } , file: { server .optimizer .args .file } " , "Info"
376+ )
357377
358- optimizable_funcs , count , _ = server .optimizer .get_optimizable_functions ()
378+ optimizable_funcs , count , _ = server .optimizer .get_optimizable_functions ()
359379
360- if count == 0 :
361- server .show_message_log (f"No optimizable functions found for { params .functionName } " , "Warning" )
362- server .cleanup_the_optimizer ()
363- return {"functionName" : params .functionName , "status" : "error" , "message" : "not found" , "args" : None }
380+ if count == 0 :
381+ server .show_message_log (f"No optimizable functions found for { params .functionName } " , "Warning" )
382+ server .cleanup_the_optimizer ()
383+ return {"functionName" : params .functionName , "status" : "error" , "message" : "not found" , "args" : None }
364384
365- fto = optimizable_funcs .popitem ()[1 ][0 ]
385+ fto = optimizable_funcs .popitem ()[1 ][0 ]
366386
367- module_prep_result = server .optimizer .prepare_module_for_optimization (fto .file_path )
368- if not module_prep_result :
369- return {
370- "functionName" : params .functionName ,
371- "status" : "error" ,
372- "message" : "Failed to prepare module for optimization" ,
373- }
387+ module_prep_result = server .optimizer .prepare_module_for_optimization (fto .file_path )
388+ if not module_prep_result :
389+ return {
390+ "functionName" : params .functionName ,
391+ "status" : "error" ,
392+ "message" : "Failed to prepare module for optimization" ,
393+ }
374394
375- validated_original_code , original_module_ast = module_prep_result
395+ validated_original_code , original_module_ast = module_prep_result
376396
377- function_optimizer = server .optimizer .create_function_optimizer (
378- fto ,
379- function_to_optimize_source_code = validated_original_code [fto .file_path ].source_code ,
380- original_module_ast = original_module_ast ,
381- original_module_path = fto .file_path ,
382- function_to_tests = {},
383- )
397+ function_optimizer = server .optimizer .create_function_optimizer (
398+ fto ,
399+ function_to_optimize_source_code = validated_original_code [fto .file_path ].source_code ,
400+ original_module_ast = original_module_ast ,
401+ original_module_path = fto .file_path ,
402+ function_to_tests = {},
403+ )
384404
385- server .optimizer .current_function_optimizer = function_optimizer
386- if not function_optimizer :
387- return {"functionName" : params .functionName , "status" : "error" , "message" : "No function optimizer found" }
405+ server .optimizer .current_function_optimizer = function_optimizer
406+ if not function_optimizer :
407+ return {"functionName" : params .functionName , "status" : "error" , "message" : "No function optimizer found" }
388408
389- initialization_result = function_optimizer .can_be_optimized ()
390- if not is_successful (initialization_result ):
391- return {"functionName" : params .functionName , "status" : "error" , "message" : initialization_result .failure ()}
409+ initialization_result = function_optimizer .can_be_optimized ()
410+ if not is_successful (initialization_result ):
411+ return {"functionName" : params .functionName , "status" : "error" , "message" : initialization_result .failure ()}
392412
393- server .current_optimization_init_result = initialization_result .unwrap ()
394- server .show_message_log (f"Successfully initialized optimization for { params .functionName } " , "Info" )
413+ server .current_optimization_init_result = initialization_result .unwrap ()
414+ server .show_message_log (f"Successfully initialized optimization for { params .functionName } " , "Info" )
395415
396- files = [function_optimizer .function_to_optimize .file_path ]
416+ files = [function_optimizer .function_to_optimize .file_path ]
397417
398- _ , _ , original_helpers = server .current_optimization_init_result
399- files .extend ([str (helper_path ) for helper_path in original_helpers ])
418+ _ , _ , original_helpers = server .current_optimization_init_result
419+ files .extend ([str (helper_path ) for helper_path in original_helpers ])
400420
401- return {"functionName" : params .functionName , "status" : "success" , "files_inside_context" : files }
421+ return {"functionName" : params .functionName , "status" : "success" , "files_inside_context" : files }
402422
403423
404424@server .feature ("performFunctionOptimization" )
405- async def perform_function_optimization (
406- server : CodeflashLanguageServer , params : FunctionOptimizationParams
407- ) -> dict [str , str ]:
408- loop = asyncio .get_running_loop ()
409- try :
410- result = await loop .run_in_executor (None , sync_perform_optimization , server , params )
411- except asyncio .CancelledError :
412- return {"status" : "canceled" , "message" : "Task was canceled" }
413- else :
414- return result
415- finally :
416- server .cleanup_the_optimizer ()
425+ async def perform_function_optimization (params : FunctionOptimizationParams ) -> dict [str , str ]:
426+ with execution_context (task_id = params .task_id ):
427+ loop = asyncio .get_running_loop ()
428+ cancel_event = threading .Event ()
429+
430+ try :
431+ ctx = contextvars .copy_context ()
432+ return await loop .run_in_executor (None , ctx .run , sync_perform_optimization , server , cancel_event , params )
433+ except asyncio .CancelledError :
434+ cancel_event .set ()
435+ return get_cancelled_reponse ()
436+ finally :
437+ server .cleanup_the_optimizer ()
0 commit comments