 import threading
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union

 from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
 from codeflash.cli_cmds.cli import process_pyproject_config
@@ -16,12 +16,14 @@
     VsCodeSetupInfo,
     configure_pyproject_toml,
     create_empty_pyproject_toml,
+    create_find_common_tags_file,
     get_formatter_cmds,
     get_suggestions,
     get_valid_subdirs,
     is_valid_pyproject_toml,
 )
 from codeflash.code_utils.git_utils import git_root_dir
+from codeflash.code_utils.git_worktree_utils import create_worktree_snapshot_commit
 from codeflash.code_utils.shell_utils import save_api_key_to_rc
 from codeflash.discovery.functions_to_optimize import (
     filter_functions,
@@ -38,6 +40,7 @@
     from lsprotocol import types

     from codeflash.discovery.functions_to_optimize import FunctionToOptimize
+    from codeflash.lsp.server import WrappedInitializationResultT


 @dataclass
@@ -54,11 +57,15 @@ class FunctionOptimizationInitParams:

 @dataclass
 class FunctionOptimizationParams:
-    textDocument: types.TextDocumentIdentifier  # noqa: N815
     functionName: str  # noqa: N815
     task_id: str


+@dataclass
+class DemoOptimizationParams:
+    functionName: str  # noqa: N815
+
+
 @dataclass
 class ProvideApiKeyParams:
     api_key: str
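Aside (not part of the diff): a minimal sketch of constructing these request params, assuming the dataclasses above are importable from the module under change; the values are hypothetical. Note that FunctionOptimizationParams no longer carries a textDocument field.

# Hypothetical values; field names match the dataclasses defined above.
opt_params = FunctionOptimizationParams(functionName="process_items", task_id="task-42")
demo_params = DemoOptimizationParams(functionName="find_common_tags")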
@@ -256,10 +263,8 @@ def init_project(params: ValidateProjectParams) -> dict[str, str]:

 def _initialize_optimizer_if_api_key_is_valid(api_key: Optional[str] = None) -> dict[str, str]:
     key_check_result = _check_api_key_validity(api_key)
-    if key_check_result.get("status") != "success":
-        return key_check_result
-
-    _init()
+    if key_check_result.get("status") == "success":
+        _init()
     return key_check_result


@@ -351,6 +356,56 @@ def cleanup_optimizer(_params: any) -> dict[str, str]:
     return {"status": "success"}


+def _initialize_current_function_optimizer() -> Union[dict[str, str], WrappedInitializationResultT]:
+    """Initialize the current function optimizer.
+
+    Returns:
+        Union[dict[str, str], WrappedInitializationResultT]:
+            an error dict with status "error" on failure,
+            or a wrapped initialization result if the optimizer is initialized.
+
+    """
+    if not server.optimizer:
+        return {"status": "error", "message": "Optimizer not initialized yet."}
+
+    function_name = server.optimizer.args.function
+    optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
+
+    if count == 0:
+        server.show_message_log(f"No optimizable functions found for {function_name}", "Warning")
+        server.cleanup_the_optimizer()
+        return {"functionName": function_name, "status": "error", "message": "not found", "args": None}
+
+    fto = optimizable_funcs.popitem()[1][0]
+
+    module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
+    if not module_prep_result:
+        return {
+            "functionName": function_name,
+            "status": "error",
+            "message": "Failed to prepare module for optimization",
+        }
+
+    validated_original_code, original_module_ast = module_prep_result
+
+    function_optimizer = server.optimizer.create_function_optimizer(
+        fto,
+        function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
+        original_module_ast=original_module_ast,
+        original_module_path=fto.file_path,
+        function_to_tests={},
+    )
+
+    server.optimizer.current_function_optimizer = function_optimizer
+    if not function_optimizer:
+        return {"functionName": function_name, "status": "error", "message": "No function optimizer found"}
+
+    initialization_result = function_optimizer.can_be_optimized()
+    if not is_successful(initialization_result):
+        return {"functionName": function_name, "status": "error", "message": initialization_result.failure()}
+    return initialization_result
+
+
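Aside (not part of the diff): the helper above returns either a plain error dict or a result container that callers handle with is_successful()/.unwrap()/.failure(), which is consistent with the returns library. A minimal, self-contained sketch of that pattern, using a hypothetical stand-in for can_be_optimized():

from returns.pipeline import is_successful
from returns.result import Failure, Result, Success


def can_be_optimized_stub(ok: bool) -> Result[str, str]:
    # Hypothetical stand-in: Success carries the initialization payload,
    # Failure carries an error message, mirroring can_be_optimized() above.
    return Success("init-payload") if ok else Failure("function cannot be optimized")


result = can_be_optimized_stub(ok=True)
if is_successful(result):
    payload = result.unwrap()   # success path: stored as current_optimization_init_result
else:
    message = result.failure()  # failure path: reported back in an error dict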
 @server.feature("initializeFunctionOptimization")
 def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]:
     with execution_context(task_id=getattr(params, "task_id", None)):
@@ -375,52 +430,47 @@ def initialize_function_optimization(params: FunctionOptimizationInitParams) ->
             f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
         )

-        optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
-
-        if count == 0:
-            server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
-            server.cleanup_the_optimizer()
-            return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
-
-        fto = optimizable_funcs.popitem()[1][0]
-
-        module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
-        if not module_prep_result:
-            return {
-                "functionName": params.functionName,
-                "status": "error",
-                "message": "Failed to prepare module for optimization",
-            }
-
-        validated_original_code, original_module_ast = module_prep_result
-
-        function_optimizer = server.optimizer.create_function_optimizer(
-            fto,
-            function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
-            original_module_ast=original_module_ast,
-            original_module_path=fto.file_path,
-            function_to_tests={},
-        )
-
-        server.optimizer.current_function_optimizer = function_optimizer
-        if not function_optimizer:
-            return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
-
-        initialization_result = function_optimizer.can_be_optimized()
-        if not is_successful(initialization_result):
-            return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
+        initialization_result = _initialize_current_function_optimizer()
+        if isinstance(initialization_result, dict):
+            return initialization_result

         server.current_optimization_init_result = initialization_result.unwrap()
         server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")

-        files = [function_optimizer.function_to_optimize.file_path]
+        files = [document.path]

         _, _, original_helpers = server.current_optimization_init_result
         files.extend([str(helper_path) for helper_path in original_helpers])

         return {"functionName": params.functionName, "status": "success", "files_inside_context": files}


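Aside (not part of the diff): on success, initializeFunctionOptimization replies with a dict shaped like the hypothetical example below; files_inside_context lists the file being optimized plus any helper files discovered during initialization.

# Hypothetical success response; the function name and paths are made up.
example_response = {
    "functionName": "process_items",
    "status": "success",
    "files_inside_context": ["/workspace/src/module.py", "/workspace/src/helpers.py"],
}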
+@server.feature("startDemoOptimization")
+async def start_demo_optimization(params: DemoOptimizationParams) -> dict[str, str]:
+    try:
+        _init()
+        # start by creating the worktree so that the demo file is not created in the user's workspace
+        server.optimizer.worktree_mode()
+        file_path = create_find_common_tags_file(server.args, params.functionName + ".py")
+        # commit the new file for diff generation later
+        create_worktree_snapshot_commit(server.optimizer.current_worktree, "added sample optimization file")
+
+        server.optimizer.args.file = file_path
+        server.optimizer.args.function = params.functionName
+        server.optimizer.args.previous_checkpoint_functions = False
+
+        initialization_result = _initialize_current_function_optimizer()
+        if isinstance(initialization_result, dict):
+            return initialization_result
+
+        server.current_optimization_init_result = initialization_result.unwrap()
+        return await perform_function_optimization(
+            FunctionOptimizationParams(functionName=params.functionName, task_id=None)
+        )
+    finally:
+        server.cleanup_the_optimizer()
+
+
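Aside (not part of the diff): startDemoOptimization is registered as a custom LSP request, so an editor client would trigger the demo flow with a JSON-RPC message along these lines. The method name comes from the @server.feature registration above; the id and functionName values are hypothetical.

# Hypothetical JSON-RPC payload a client could send to the codeflash LSP server.
demo_request = {
    "jsonrpc": "2.0",
    "id": 7,
    "method": "startDemoOptimization",
    "params": {"functionName": "find_common_tags"},
}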
 @server.feature("performFunctionOptimization")
 async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]:
     with execution_context(task_id=getattr(params, "task_id", None)):