22
33import asyncio
44import contextlib
5+ import contextvars
56import os
7+ import threading
68from dataclasses import dataclass
79from pathlib import Path
810from typing import TYPE_CHECKING , Optional
2729 get_functions_within_git_diff ,
2830)
2931from codeflash .either import is_successful
30- from codeflash .lsp .features .perform_optimization import sync_perform_optimization
31- from codeflash .lsp .server import CodeflashLanguageServer
32+ from codeflash .lsp .features .perform_optimization import get_cancelled_reponse , sync_perform_optimization
33+ from codeflash .lsp .server import CodeflashServerSingleton
3234
3335if TYPE_CHECKING :
3436 from argparse import Namespace
@@ -47,6 +49,7 @@ class OptimizableFunctionsParams:
4749class FunctionOptimizationInitParams :
4850 textDocument : types .TextDocumentIdentifier # noqa: N815
4951 functionName : str # noqa: N815
52+ task_id : str
5053
5154
5255@dataclass
@@ -84,30 +87,24 @@ class WriteConfigParams:
8487 config : any
8588
8689
87- server = CodeflashLanguageServer ( "codeflash-language-server" , "v1.0" )
90+ server = CodeflashServerSingleton . get ( )
8891
8992
9093@server .feature ("getOptimizableFunctionsInCurrentDiff" )
91- def get_functions_in_current_git_diff (
92- server : CodeflashLanguageServer , _params : OptimizableFunctionsParams
93- ) -> dict [str , str | dict [str , list [str ]]]:
94+ def get_functions_in_current_git_diff (_params : OptimizableFunctionsParams ) -> dict [str , str | dict [str , list [str ]]]:
9495 functions = get_functions_within_git_diff (uncommitted_changes = True )
95- file_to_qualified_names = _group_functions_by_file (server , functions )
96+ file_to_qualified_names = _group_functions_by_file (functions )
9697 return {"functions" : file_to_qualified_names , "status" : "success" }
9798
9899
99100@server .feature ("getOptimizableFunctionsInCommit" )
100- def get_functions_in_commit (
101- server : CodeflashLanguageServer , params : OptimizableFunctionsInCommitParams
102- ) -> dict [str , str | dict [str , list [str ]]]:
101+ def get_functions_in_commit (params : OptimizableFunctionsInCommitParams ) -> dict [str , str | dict [str , list [str ]]]:
103102 functions = get_functions_inside_a_commit (params .commit_hash )
104- file_to_qualified_names = _group_functions_by_file (server , functions )
103+ file_to_qualified_names = _group_functions_by_file (functions )
105104 return {"functions" : file_to_qualified_names , "status" : "success" }
106105
107106
108- def _group_functions_by_file (
109- server : CodeflashLanguageServer , functions : dict [str , list [FunctionToOptimize ]]
110- ) -> dict [str , list [str ]]:
107+ def _group_functions_by_file (functions : dict [str , list [FunctionToOptimize ]]) -> dict [str , list [str ]]:
111108 file_to_funcs_to_optimize , _ = filter_functions (
112109 modified_functions = functions ,
113110 tests_root = server .optimizer .test_cfg .tests_root ,
@@ -123,9 +120,7 @@ def _group_functions_by_file(
123120
124121
125122@server .feature ("getOptimizableFunctions" )
126- def get_optimizable_functions (
127- server : CodeflashLanguageServer , params : OptimizableFunctionsParams
128- ) -> dict [str , list [str ]]:
123+ def get_optimizable_functions (params : OptimizableFunctionsParams ) -> dict [str , list [str ]]:
129124 document_uri = params .textDocument .uri
130125 document = server .workspace .get_text_document (document_uri )
131126
@@ -172,7 +167,7 @@ def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:
172167
173168
174169@server .feature ("writeConfig" )
175- def write_config (_server : CodeflashLanguageServer , params : WriteConfigParams ) -> dict [str , any ]:
170+ def write_config (params : WriteConfigParams ) -> dict [str , any ]:
176171 cfg = params .config
177172 cfg_file = Path (params .config_file ) if params .config_file else None
178173
@@ -196,7 +191,7 @@ def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) ->
196191
197192
198193@server .feature ("getConfigSuggestions" )
199- def get_config_suggestions (_server : CodeflashLanguageServer , _params : any ) -> dict [str , any ]:
194+ def get_config_suggestions (_params : any ) -> dict [str , any ]:
200195 module_root_suggestions , default_module_root = get_suggestions (CommonSections .module_root )
201196 tests_root_suggestions , default_tests_root = get_suggestions (CommonSections .tests_root )
202197 test_framework_suggestions , default_test_framework = get_suggestions (CommonSections .test_framework )
@@ -212,7 +207,7 @@ def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> di
212207
213208# should be called the first thing to initialize and validate the project
214209@server .feature ("initProject" )
215- def init_project (server : CodeflashLanguageServer , params : ValidateProjectParams ) -> dict [str , str ]:
210+ def init_project (params : ValidateProjectParams ) -> dict [str , str ]:
216211 # Always process args in the init project, the extension can call
217212 server .args_processed_before = False
218213
@@ -255,14 +250,12 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
255250 "existingConfig" : config ,
256251 }
257252
258- args = process_args (server )
253+ args = process_args ()
259254
260255 return {"status" : "success" , "moduleRoot" : args .module_root , "pyprojectPath" : pyproject_toml_path , "root" : root }
261256
262257
263- def _initialize_optimizer_if_api_key_is_valid (
264- server : CodeflashLanguageServer , api_key : Optional [str ] = None
265- ) -> dict [str , str ]:
258+ def _initialize_optimizer_if_api_key_is_valid (api_key : Optional [str ] = None ) -> dict [str , str ]:
266259 user_id = get_user_id (api_key = api_key )
267260 if user_id is None :
268261 return {"status" : "error" , "message" : "api key not found or invalid" }
@@ -273,12 +266,12 @@ def _initialize_optimizer_if_api_key_is_valid(
273266
274267 from codeflash .optimization .optimizer import Optimizer
275268
276- new_args = process_args (server )
269+ new_args = process_args ()
277270 server .optimizer = Optimizer (new_args )
278271 return {"status" : "success" , "user_id" : user_id }
279272
280273
281- def process_args (server : CodeflashLanguageServer ) -> Namespace :
274+ def process_args () -> Namespace :
282275 if server .args_processed_before :
283276 return server .args
284277 new_args = process_pyproject_config (server .args )
@@ -288,15 +281,15 @@ def process_args(server: CodeflashLanguageServer) -> Namespace:
288281
289282
290283@server .feature ("apiKeyExistsAndValid" )
291- def check_api_key (server : CodeflashLanguageServer , _params : any ) -> dict [str , str ]:
284+ def check_api_key (_params : any ) -> dict [str , str ]:
292285 try :
293- return _initialize_optimizer_if_api_key_is_valid (server )
286+ return _initialize_optimizer_if_api_key_is_valid ()
294287 except Exception :
295288 return {"status" : "error" , "message" : "something went wrong while validating the api key" }
296289
297290
298291@server .feature ("provideApiKey" )
299- def provide_api_key (server : CodeflashLanguageServer , params : ProvideApiKeyParams ) -> dict [str , str ]:
292+ def provide_api_key (params : ProvideApiKeyParams ) -> dict [str , str ]:
300293 try :
301294 api_key = params .api_key
302295 if not api_key .startswith ("cf-" ):
@@ -306,7 +299,7 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
306299 get_codeflash_api_key .cache_clear ()
307300 get_user_id .cache_clear ()
308301
309- init_result = _initialize_optimizer_if_api_key_is_valid (server , api_key )
302+ init_result = _initialize_optimizer_if_api_key_is_valid (api_key )
310303 if init_result ["status" ] == "error" :
311304 return {"status" : "error" , "message" : "Api key is not valid" }
312305
@@ -319,87 +312,101 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
319312 return {"status" : "error" , "message" : "something went wrong while saving the api key" }
320313
321314
315+ @contextlib .contextmanager
316+ def execution_context (** kwargs : str ) -> None :
317+ """Temporarily set context values for the current async task."""
318+ # Create a fresh copy per use
319+ current = {** server .execution_context_vars .get (), ** kwargs }
320+ token = server .execution_context_vars .set (current )
321+ try :
322+ yield
323+ finally :
324+ server .execution_context_vars .reset (token )
325+
326+
322327@server .feature ("initializeFunctionOptimization" )
323- def initialize_function_optimization (
324- server : CodeflashLanguageServer , params : FunctionOptimizationInitParams
325- ) -> dict [str , str ]:
326- document_uri = params .textDocument .uri
327- document = server .workspace .get_text_document (document_uri )
328+ def initialize_function_optimization (params : FunctionOptimizationInitParams ) -> dict [str , str ]:
329+ with execution_context (task_id = params .task_id ):
330+ document_uri = params .textDocument .uri
331+ document = server .workspace .get_text_document (document_uri )
328332
329- server .show_message_log (f"Initializing optimization for function: { params .functionName } in { document_uri } " , "Info" )
333+ server .show_message_log (
334+ f"Initializing optimization for function: { params .functionName } in { document_uri } " , "Info"
335+ )
330336
331- if server .optimizer is None :
332- _initialize_optimizer_if_api_key_is_valid (server )
337+ if server .optimizer is None :
338+ _initialize_optimizer_if_api_key_is_valid ()
333339
334- server .optimizer .worktree_mode ()
340+ server .optimizer .worktree_mode ()
335341
336- original_args , _ = server .optimizer .original_args_and_test_cfg
342+ original_args , _ = server .optimizer .original_args_and_test_cfg
337343
338- server .optimizer .args .function = params .functionName
339- original_relative_file_path = Path (document .path ).relative_to (original_args .project_root )
340- server .optimizer .args .file = server .optimizer .current_worktree / original_relative_file_path
341- server .optimizer .args .previous_checkpoint_functions = False
344+ server .optimizer .args .function = params .functionName
345+ original_relative_file_path = Path (document .path ).relative_to (original_args .project_root )
346+ server .optimizer .args .file = server .optimizer .current_worktree / original_relative_file_path
347+ server .optimizer .args .previous_checkpoint_functions = False
342348
343- server .show_message_log (
344- f"Args set - function: { server .optimizer .args .function } , file: { server .optimizer .args .file } " , "Info"
345- )
349+ server .show_message_log (
350+ f"Args set - function: { server .optimizer .args .function } , file: { server .optimizer .args .file } " , "Info"
351+ )
346352
347- optimizable_funcs , count , _ = server .optimizer .get_optimizable_functions ()
353+ optimizable_funcs , count , _ = server .optimizer .get_optimizable_functions ()
348354
349- if count == 0 :
350- server .show_message_log (f"No optimizable functions found for { params .functionName } " , "Warning" )
351- server .cleanup_the_optimizer ()
352- return {"functionName" : params .functionName , "status" : "error" , "message" : "not found" , "args" : None }
355+ if count == 0 :
356+ server .show_message_log (f"No optimizable functions found for { params .functionName } " , "Warning" )
357+ server .cleanup_the_optimizer ()
358+ return {"functionName" : params .functionName , "status" : "error" , "message" : "not found" , "args" : None }
353359
354- fto = optimizable_funcs .popitem ()[1 ][0 ]
360+ fto = optimizable_funcs .popitem ()[1 ][0 ]
355361
356- module_prep_result = server .optimizer .prepare_module_for_optimization (fto .file_path )
357- if not module_prep_result :
358- return {
359- "functionName" : params .functionName ,
360- "status" : "error" ,
361- "message" : "Failed to prepare module for optimization" ,
362- }
362+ module_prep_result = server .optimizer .prepare_module_for_optimization (fto .file_path )
363+ if not module_prep_result :
364+ return {
365+ "functionName" : params .functionName ,
366+ "status" : "error" ,
367+ "message" : "Failed to prepare module for optimization" ,
368+ }
363369
364- validated_original_code , original_module_ast = module_prep_result
370+ validated_original_code , original_module_ast = module_prep_result
365371
366- function_optimizer = server .optimizer .create_function_optimizer (
367- fto ,
368- function_to_optimize_source_code = validated_original_code [fto .file_path ].source_code ,
369- original_module_ast = original_module_ast ,
370- original_module_path = fto .file_path ,
371- function_to_tests = {},
372- )
372+ function_optimizer = server .optimizer .create_function_optimizer (
373+ fto ,
374+ function_to_optimize_source_code = validated_original_code [fto .file_path ].source_code ,
375+ original_module_ast = original_module_ast ,
376+ original_module_path = fto .file_path ,
377+ function_to_tests = {},
378+ )
373379
374- server .optimizer .current_function_optimizer = function_optimizer
375- if not function_optimizer :
376- return {"functionName" : params .functionName , "status" : "error" , "message" : "No function optimizer found" }
380+ server .optimizer .current_function_optimizer = function_optimizer
381+ if not function_optimizer :
382+ return {"functionName" : params .functionName , "status" : "error" , "message" : "No function optimizer found" }
377383
378- initialization_result = function_optimizer .can_be_optimized ()
379- if not is_successful (initialization_result ):
380- return {"functionName" : params .functionName , "status" : "error" , "message" : initialization_result .failure ()}
384+ initialization_result = function_optimizer .can_be_optimized ()
385+ if not is_successful (initialization_result ):
386+ return {"functionName" : params .functionName , "status" : "error" , "message" : initialization_result .failure ()}
381387
382- server .current_optimization_init_result = initialization_result .unwrap ()
383- server .show_message_log (f"Successfully initialized optimization for { params .functionName } " , "Info" )
388+ server .current_optimization_init_result = initialization_result .unwrap ()
389+ server .show_message_log (f"Successfully initialized optimization for { params .functionName } " , "Info" )
384390
385- files = [function_optimizer .function_to_optimize .file_path ]
391+ files = [function_optimizer .function_to_optimize .file_path ]
386392
387- _ , _ , original_helpers = server .current_optimization_init_result
388- files .extend ([str (helper_path ) for helper_path in original_helpers ])
393+ _ , _ , original_helpers = server .current_optimization_init_result
394+ files .extend ([str (helper_path ) for helper_path in original_helpers ])
389395
390- return {"functionName" : params .functionName , "status" : "success" , "files_inside_context" : files }
396+ return {"functionName" : params .functionName , "status" : "success" , "files_inside_context" : files }
391397
392398
393399@server .feature ("performFunctionOptimization" )
394- async def perform_function_optimization (
395- server : CodeflashLanguageServer , params : FunctionOptimizationParams
396- ) -> dict [str , str ]:
397- loop = asyncio .get_running_loop ()
398- try :
399- result = await loop .run_in_executor (None , sync_perform_optimization , server , params )
400- except asyncio .CancelledError :
401- return {"status" : "canceled" , "message" : "Task was canceled" }
402- else :
403- return result
404- finally :
405- server .cleanup_the_optimizer ()
400+ async def perform_function_optimization (params : FunctionOptimizationParams ) -> dict [str , str ]:
401+ with execution_context (task_id = params .task_id ):
402+ loop = asyncio .get_running_loop ()
403+ server .cancel_event = threading .Event ()
404+
405+ try :
406+ ctx = contextvars .copy_context ()
407+ return await loop .run_in_executor (None , ctx .run , sync_perform_optimization , params )
408+ except asyncio .CancelledError :
409+ server .cancel_event .set ()
410+ return get_cancelled_reponse ()
411+ finally :
412+ server .cleanup_the_optimizer ()
0 commit comments