77import threading
88from dataclasses import dataclass
99from pathlib import Path
10- from typing import TYPE_CHECKING , Optional
10+ from typing import TYPE_CHECKING , Optional , Union
1111
1212from codeflash .api .cfapi import get_codeflash_api_key , get_user_id
1313from codeflash .cli_cmds .cli import process_pyproject_config
1616 VsCodeSetupInfo ,
1717 configure_pyproject_toml ,
1818 create_empty_pyproject_toml ,
19+ create_find_common_tags_file ,
1920 get_formatter_cmds ,
2021 get_suggestions ,
2122 get_valid_subdirs ,
2223 is_valid_pyproject_toml ,
2324)
2425from codeflash .code_utils .git_utils import git_root_dir
26+ from codeflash .code_utils .git_worktree_utils import create_worktree_snapshot_commit
2527from codeflash .code_utils .shell_utils import save_api_key_to_rc
2628from codeflash .discovery .functions_to_optimize import (
2729 filter_functions ,
3941 from lsprotocol import types
4042
4143 from codeflash .discovery .functions_to_optimize import FunctionToOptimize
44+ from codeflash .lsp .server import WrappedInitializationResultT
4245
4346
4447@dataclass
@@ -55,11 +58,15 @@ class FunctionOptimizationInitParams:
5558
5659@dataclass
5760class FunctionOptimizationParams :
58- textDocument : types .TextDocumentIdentifier # noqa: N815
5961 functionName : str # noqa: N815
6062 task_id : str
6163
6264
65+ @dataclass
66+ class DemoOptimizationParams :
67+ functionName : str # noqa: N815
68+
69+
6370@dataclass
6471class ProvideApiKeyParams :
6572 api_key : str
@@ -257,10 +264,8 @@ def init_project(params: ValidateProjectParams) -> dict[str, str]:
257264
258265def _initialize_optimizer_if_api_key_is_valid (api_key : Optional [str ] = None ) -> dict [str , str ]:
259266 key_check_result = _check_api_key_validity (api_key )
260- if key_check_result .get ("status" ) != "success" :
261- return key_check_result
262-
263- _init ()
267+ if key_check_result .get ("status" ) == "success" :
268+ _init ()
264269 return key_check_result
265270
266271
@@ -303,8 +308,8 @@ def _init() -> Namespace:
303308def check_api_key (_params : any ) -> dict [str , str ]:
304309 try :
305310 return _initialize_optimizer_if_api_key_is_valid ()
306- except Exception :
307- return {"status" : "error" , "message" : "something went wrong while validating the api key" }
311+ except Exception as ex :
312+ return {"status" : "error" , "message" : "something went wrong while validating the api key " + str ( ex ) }
308313
309314
310315@server .feature ("provideApiKey" )
@@ -353,6 +358,56 @@ def cleanup_optimizer(_params: any) -> dict[str, str]:
353358 return {"status" : "success" }
354359
355360
361+ def _initialize_current_function_optimizer () -> Union [dict [str , str ], WrappedInitializationResultT ]:
362+ """Initialize the current function optimizer.
363+
364+ Returns:
365+ Union[dict[str, str], WrappedInitializationResultT]:
366+ error dict with status error,
367+ or a wrapped initializationresult if the optimizer is initialized.
368+
369+ """
370+ if not server .optimizer :
371+ return {"status" : "error" , "message" : "Optimizer not initialized yet." }
372+
373+ function_name = server .optimizer .args .function
374+ optimizable_funcs , count , _ = server .optimizer .get_optimizable_functions ()
375+
376+ if count == 0 :
377+ server .show_message_log (f"No optimizable functions found for { function_name } " , "Warning" )
378+ server .cleanup_the_optimizer ()
379+ return {"functionName" : function_name , "status" : "error" , "message" : "not found" , "args" : None }
380+
381+ fto = optimizable_funcs .popitem ()[1 ][0 ]
382+
383+ module_prep_result = server .optimizer .prepare_module_for_optimization (fto .file_path )
384+ if not module_prep_result :
385+ return {
386+ "functionName" : function_name ,
387+ "status" : "error" ,
388+ "message" : "Failed to prepare module for optimization" ,
389+ }
390+
391+ validated_original_code , original_module_ast = module_prep_result
392+
393+ function_optimizer = server .optimizer .create_function_optimizer (
394+ fto ,
395+ function_to_optimize_source_code = validated_original_code [fto .file_path ].source_code ,
396+ original_module_ast = original_module_ast ,
397+ original_module_path = fto .file_path ,
398+ function_to_tests = {},
399+ )
400+
401+ server .optimizer .current_function_optimizer = function_optimizer
402+ if not function_optimizer :
403+ return {"functionName" : function_name , "status" : "error" , "message" : "No function optimizer found" }
404+
405+ initialization_result = function_optimizer .can_be_optimized ()
406+ if not is_successful (initialization_result ):
407+ return {"functionName" : function_name , "status" : "error" , "message" : initialization_result .failure ()}
408+ return initialization_result
409+
410+
356411@server .feature ("initializeFunctionOptimization" )
357412def initialize_function_optimization (params : FunctionOptimizationInitParams ) -> dict [str , str ]:
358413 with execution_context (task_id = getattr (params , "task_id" , None )):
@@ -377,52 +432,47 @@ def initialize_function_optimization(params: FunctionOptimizationInitParams) ->
377432 f"Args set - function: { server .optimizer .args .function } , file: { server .optimizer .args .file } " , "Info"
378433 )
379434
380- optimizable_funcs , count , _ = server .optimizer .get_optimizable_functions ()
381-
382- if count == 0 :
383- server .show_message_log (f"No optimizable functions found for { params .functionName } " , "Warning" )
384- server .cleanup_the_optimizer ()
385- return {"functionName" : params .functionName , "status" : "error" , "message" : "not found" , "args" : None }
386-
387- fto = optimizable_funcs .popitem ()[1 ][0 ]
388-
389- module_prep_result = server .optimizer .prepare_module_for_optimization (fto .file_path )
390- if not module_prep_result :
391- return {
392- "functionName" : params .functionName ,
393- "status" : "error" ,
394- "message" : "Failed to prepare module for optimization" ,
395- }
396-
397- validated_original_code , original_module_ast = module_prep_result
398-
399- function_optimizer = server .optimizer .create_function_optimizer (
400- fto ,
401- function_to_optimize_source_code = validated_original_code [fto .file_path ].source_code ,
402- original_module_ast = original_module_ast ,
403- original_module_path = fto .file_path ,
404- function_to_tests = {},
405- )
406-
407- server .optimizer .current_function_optimizer = function_optimizer
408- if not function_optimizer :
409- return {"functionName" : params .functionName , "status" : "error" , "message" : "No function optimizer found" }
410-
411- initialization_result = function_optimizer .can_be_optimized ()
412- if not is_successful (initialization_result ):
413- return {"functionName" : params .functionName , "status" : "error" , "message" : initialization_result .failure ()}
435+ initialization_result = _initialize_current_function_optimizer ()
436+ if isinstance (initialization_result , dict ):
437+ return initialization_result
414438
415439 server .current_optimization_init_result = initialization_result .unwrap ()
416440 server .show_message_log (f"Successfully initialized optimization for { params .functionName } " , "Info" )
417441
418- files = [function_optimizer . function_to_optimize . file_path ]
442+ files = [document . path ]
419443
420444 _ , _ , original_helpers = server .current_optimization_init_result
421445 files .extend ([str (helper_path ) for helper_path in original_helpers ])
422446
423447 return {"functionName" : params .functionName , "status" : "success" , "files_inside_context" : files }
424448
425449
450+ @server .feature ("startDemoOptimization" )
451+ async def start_demo_optimization (params : DemoOptimizationParams ) -> dict [str , str ]:
452+ try :
453+ _init ()
454+ # start by creating the worktree so that the demo file is not created in user workspace
455+ server .optimizer .worktree_mode ()
456+ file_path = create_find_common_tags_file (server .args , params .functionName + ".py" )
457+ # commit the new file for diff generation later
458+ create_worktree_snapshot_commit (server .optimizer .current_worktree , "added sample optimization file" )
459+
460+ server .optimizer .args .file = file_path
461+ server .optimizer .args .function = params .functionName
462+ server .optimizer .args .previous_checkpoint_functions = False
463+
464+ initialization_result = _initialize_current_function_optimizer ()
465+ if isinstance (initialization_result , dict ):
466+ return initialization_result
467+
468+ server .current_optimization_init_result = initialization_result .unwrap ()
469+ return await perform_function_optimization (
470+ FunctionOptimizationParams (functionName = params .functionName , task_id = None )
471+ )
472+ finally :
473+ server .cleanup_the_optimizer ()
474+
475+
426476@server .feature ("performFunctionOptimization" )
427477async def perform_function_optimization (params : FunctionOptimizationParams ) -> dict [str , str ]:
428478 with execution_context (task_id = getattr (params , "task_id" , None )):
0 commit comments