1212from codeflash .api .cfapi import get_codeflash_api_key , get_user_id
1313from codeflash .cli_cmds .cli import process_pyproject_config
1414from codeflash .cli_cmds .console import code_print
15- from codeflash .code_utils .git_worktree_utils import (
16- create_diff_patch_from_worktree ,
17- get_patches_metadata ,
18- overwrite_patch_metadata ,
19- )
15+ from codeflash .code_utils .git_utils import git_root_dir
16+ from codeflash .code_utils .git_worktree_utils import create_diff_patch_from_worktree
2017from codeflash .code_utils .shell_utils import save_api_key_to_rc
2118from codeflash .discovery .functions_to_optimize import (
2219 filter_functions ,
@@ -39,10 +36,17 @@ class OptimizableFunctionsParams:
3936 textDocument : types .TextDocumentIdentifier # noqa: N815
4037
4138
@dataclass
class FunctionOptimizationInitParams:
    """Parameters for the "initializeFunctionOptimization" LSP request.

    Field names are camelCase to match the LSP wire format (hence the
    N815 suppressions).
    """

    # URI of the document containing the function to optimize.
    textDocument: types.TextDocumentIdentifier  # noqa: N815
    # Name of the function to prepare for optimization.
    functionName: str  # noqa: N815
44+
@dataclass
class FunctionOptimizationParams:
    """Parameters for the "performFunctionOptimization" LSP request.

    Field names are camelCase to match the LSP wire format (hence the
    N815 suppressions).
    """

    # URI of the document containing the function being optimized.
    textDocument: types.TextDocumentIdentifier  # noqa: N815
    # Name of the function being optimized.
    functionName: str  # noqa: N815
    # Client-supplied task identifier; echoed back in the success response.
    task_id: str
4650
4751
4852@dataclass
@@ -59,7 +63,7 @@ class ValidateProjectParams:
5963
@dataclass
class OnPatchAppliedParams:
    """Parameters for the "onPatchApplied" request.

    Carries only the identifier of the optimization task whose patch the
    client applied.
    """

    task_id: str
6367
6468
6569@dataclass
@@ -132,42 +136,6 @@ def get_optimizable_functions(
132136 return path_to_qualified_names
133137
134138
135- @server .feature ("initializeFunctionOptimization" )
136- def initialize_function_optimization (
137- server : CodeflashLanguageServer , params : FunctionOptimizationParams
138- ) -> dict [str , str ]:
139- file_path = Path (uris .to_fs_path (params .textDocument .uri ))
140- server .show_message_log (f"Initializing optimization for function: { params .functionName } in { file_path } " , "Info" )
141-
142- if server .optimizer is None :
143- _initialize_optimizer_if_api_key_is_valid (server )
144-
145- server .optimizer .worktree_mode ()
146-
147- original_args , _ = server .optimizer .original_args_and_test_cfg
148-
149- server .optimizer .args .function = params .functionName
150- original_relative_file_path = file_path .relative_to (original_args .project_root )
151- server .optimizer .args .file = server .optimizer .current_worktree / original_relative_file_path
152- server .optimizer .args .previous_checkpoint_functions = False
153-
154- server .show_message_log (
155- f"Args set - function: { server .optimizer .args .function } , file: { server .optimizer .args .file } " , "Info"
156- )
157-
158- optimizable_funcs , count , _ = server .optimizer .get_optimizable_functions ()
159-
160- if count == 0 :
161- server .show_message_log (f"No optimizable functions found for { params .functionName } " , "Warning" )
162- server .cleanup_the_optimizer ()
163- return {"functionName" : params .functionName , "status" : "error" , "message" : "not found" , "args" : None }
164-
165- fto = optimizable_funcs .popitem ()[1 ][0 ]
166- server .optimizer .current_function_being_optimized = fto
167- server .show_message_log (f"Successfully initialized optimization for { params .functionName } " , "Info" )
168- return {"functionName" : params .functionName , "status" : "success" }
169-
170-
171139def _find_pyproject_toml (workspace_path : str ) -> Path | None :
172140 workspace_path_obj = Path (workspace_path )
173141 max_depth = 2
@@ -207,13 +175,18 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
207175 if pyproject_toml_path :
208176 server .prepare_optimizer_arguments (pyproject_toml_path )
209177 else :
210- return {
211- "status" : "error" ,
212- "message" : "No pyproject.toml found in workspace." ,
213- } # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth
178+ return {"status" : "error" , "message" : "No pyproject.toml found in workspace." }
179+
180+ # since we are using worktrees, optimization diffs are generated with respect to the root of the repo, also the args.project_root is set to the root of the repo when creating a worktree
181+ root = str ( git_root_dir ())
214182
215183 if getattr (params , "skip_validation" , False ):
216- return {"status" : "success" , "moduleRoot" : server .args .module_root , "pyprojectPath" : pyproject_toml_path }
184+ return {
185+ "status" : "success" ,
186+ "moduleRoot" : server .args .module_root ,
187+ "pyprojectPath" : pyproject_toml_path ,
188+ "root" : root ,
189+ }
217190
218191 server .show_message_log ("Validating project..." , "Info" )
219192 config = is_valid_pyproject_toml (pyproject_toml_path )
@@ -234,7 +207,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
234207 except Exception :
235208 return {"status" : "error" , "message" : "Repository has no commits (unborn HEAD)" }
236209
237- return {"status" : "success" , "moduleRoot" : args .module_root , "pyprojectPath" : pyproject_toml_path }
210+ return {"status" : "success" , "moduleRoot" : args .module_root , "pyprojectPath" : pyproject_toml_path , "root" : root }
238211
239212
240213def _initialize_optimizer_if_api_key_is_valid (
@@ -296,78 +269,85 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
296269 return {"status" : "error" , "message" : "something went wrong while saving the api key" }
297270
298271
299- @server .feature ("retrieveSuccessfulOptimizations" )
300- def retrieve_successful_optimizations (_server : CodeflashLanguageServer , _params : any ) -> dict [str , str ]:
301- metadata = get_patches_metadata ()
302- return {"status" : "success" , "patches" : metadata ["patches" ]}
@server.feature("initializeFunctionOptimization")
def initialize_function_optimization(
    server: CodeflashLanguageServer, params: FunctionOptimizationInitParams
) -> dict[str, str]:
    """Prepare a function for optimization inside a git worktree.

    Resolves the target function from the request, switches the optimizer
    into worktree mode, validates that the function is optimizable, and
    caches the initialization result on the server for the subsequent
    "performFunctionOptimization" request.

    Returns a status dict; on success it includes ``files_inside_context``,
    the (stringified) paths of the function's file and its helper files.
    """
    file_path = Path(uris.to_fs_path(params.textDocument.uri))
    server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")

    if server.optimizer is None:
        _initialize_optimizer_if_api_key_is_valid(server)
    if server.optimizer is None:
        # API-key validation failed and no optimizer could be created; bail out
        # with an error instead of raising AttributeError below.
        return {"functionName": params.functionName, "status": "error", "message": "Optimizer is not initialized"}

    server.optimizer.worktree_mode()

    original_args, _ = server.optimizer.original_args_and_test_cfg

    # Point the optimizer at the worktree copy of the file, not the original.
    server.optimizer.args.function = params.functionName
    original_relative_file_path = file_path.relative_to(original_args.project_root)
    server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
    server.optimizer.args.previous_checkpoint_functions = False

    server.show_message_log(
        f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
    )

    optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()

    if count == 0:
        server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
        server.cleanup_the_optimizer()
        return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}

    fto = optimizable_funcs.popitem()[1][0]

    module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
    if not module_prep_result:
        return {
            "functionName": params.functionName,
            "status": "error",
            "message": "Failed to prepare module for optimization",
        }

    validated_original_code, original_module_ast = module_prep_result

    function_optimizer = server.optimizer.create_function_optimizer(
        fto,
        function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
        original_module_ast=original_module_ast,
        original_module_path=fto.file_path,
        function_to_tests={},
    )

    # Check before storing, so a falsy optimizer is never cached on the server
    # (the original assigned first and then checked).
    if not function_optimizer:
        return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
    server.optimizer.current_function_optimizer = function_optimizer

    initialization_result = function_optimizer.can_be_optimized()
    if not is_successful(initialization_result):
        return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}

    # Cached for the later "performFunctionOptimization" request.
    server.current_optimization_init_result = initialization_result.unwrap()
    server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")

    # Stringify every path (the original left the first entry as a Path,
    # which is not JSON-serializable over the LSP transport).
    files = [str(function_optimizer.function_to_optimize.file_path)]
    _, _, original_helpers = server.current_optimization_init_result
    files.extend(str(helper_path) for helper_path in original_helpers)

    return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
365339
366- initialization_result = function_optimizer .can_be_optimized ()
367- if not is_successful (initialization_result ):
368- return {"functionName" : params .functionName , "status" : "error" , "message" : initialization_result .failure ()}
369340
370- should_run_experiment , code_context , original_helper_code = initialization_result .unwrap ()
341+ @server .feature ("performFunctionOptimization" )
342+ @server .thread ()
343+ def perform_function_optimization (
344+ server : CodeflashLanguageServer , params : FunctionOptimizationParams
345+ ) -> dict [str , str ]:
346+ try :
347+ server .show_message_log (f"Starting optimization for function: { params .functionName } " , "Info" )
348+ should_run_experiment , code_context , original_helper_code = server .current_optimization_init_result
349+ function_optimizer = server .optimizer .current_function_optimizer
350+ current_function = function_optimizer .function_to_optimize
371351
372352 code_print (
373353 code_context .read_writable_code .flat ,
@@ -447,20 +427,8 @@ def perform_function_optimization( # noqa: PLR0911
447427
448428 speedup = original_code_baseline .runtime / best_optimization .runtime
449429
450- # get the original file path in the actual project (not in the worktree)
451- original_args , _ = server .optimizer .original_args_and_test_cfg
452- relative_file_path = current_function .file_path .relative_to (server .optimizer .current_worktree )
453- original_file_path = Path (original_args .project_root / relative_file_path ).resolve ()
454-
455- metadata = create_diff_patch_from_worktree (
456- server .optimizer .current_worktree ,
457- relative_file_paths ,
458- metadata_input = {
459- "fto_name" : function_to_optimize_qualified_name ,
460- "explanation" : best_optimization .explanation_v2 ,
461- "file_path" : str (original_file_path ),
462- "speedup" : speedup ,
463- },
430+ patch_path = create_diff_patch_from_worktree (
431+ server .optimizer .current_worktree , relative_file_paths , function_to_optimize_qualified_name
464432 )
465433
466434 server .show_message_log (f"Optimization completed for { params .functionName } with { speedup :.2f} x speedup" , "Info" )
@@ -470,8 +438,8 @@ def perform_function_optimization( # noqa: PLR0911
470438 "status" : "success" ,
471439 "message" : "Optimization completed successfully" ,
472440 "extra" : f"Speedup: { speedup :.2f} x faster" ,
473- "patch_file" : metadata [ " patch_path" ] ,
474- "patch_id " : metadata [ "id" ] ,
441+ "patch_file" : patch_path ,
442+ "task_id " : params . task_id ,
475443 "explanation" : best_optimization .explanation_v2 ,
476444 }
477445 finally :
0 commit comments