66from pathlib import Path
77from typing import TYPE_CHECKING
88
9+ import git
910from pygls import uris
1011
1112from codeflash .api .cfapi import get_codeflash_api_key , get_user_id
13+ from codeflash .cli_cmds .cli import process_pyproject_config
1214from codeflash .code_utils .git_utils import create_diff_patch_from_worktree
1315from codeflash .code_utils .shell_utils import save_api_key_to_rc
1416from codeflash .discovery .functions_to_optimize import filter_functions , get_functions_within_git_diff
1517from codeflash .either import is_successful
1618from codeflash .lsp .server import CodeflashLanguageServer , CodeflashLanguageServerProtocol
1719
1820if TYPE_CHECKING :
21+ from argparse import Namespace
22+
1923 from lsprotocol import types
2024
2125
@@ -85,9 +89,12 @@ def initialize_function_optimization(
8589) -> dict [str , str ]:
8690 file_path = Path (uris .to_fs_path (params .textDocument .uri ))
8791 server .show_message_log (f"Initializing optimization for function: { params .functionName } in { file_path } " , "Info" )
92+
8893 if server .optimizer is None :
89- _initialize_optimizer_if_valid (server )
94+ _initialize_optimizer_if_api_key_is_valid (server )
95+
9096 server .optimizer .worktree_mode ()
97+
9198 original_args , _ = server .optimizer .original_args_and_test_cfg
9299
93100 server .optimizer .args .function = params .functionName
@@ -99,15 +106,12 @@ def initialize_function_optimization(
99106 f"Args set - function: { server .optimizer .args .function } , file: { server .optimizer .args .file } " , "Info"
100107 )
101108
102- optimizable_funcs , _ , _ = server .optimizer .get_optimizable_functions ()
103- if not optimizable_funcs :
109+ optimizable_funcs , count , _ = server .optimizer .get_optimizable_functions ()
110+
111+ if count == 0 :
104112 server .show_message_log (f"No optimizable functions found for { params .functionName } " , "Warning" )
105- return {
106- "functionName" : params .functionName ,
107- "status" : "error" ,
108- "message" : "function is no found or not optimizable" ,
109- "args" : None ,
110- }
113+ cleanup_the_optimizer (server )
114+ return {"functionName" : params .functionName , "status" : "error" , "message" : "not found" , "args" : None }
111115
112116 fto = optimizable_funcs .popitem ()[1 ][0 ]
113117 server .optimizer .current_function_being_optimized = fto
@@ -129,7 +133,33 @@ def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOpt
129133 return {"functionName" : params .functionName , "status" : "success" , "discovered_tests" : num_discovered_tests }
130134
131135
132- def _initialize_optimizer_if_valid (server : CodeflashLanguageServer ) -> dict [str , str ]:
136+ @server .feature ("validateProject" )
137+ def validate_project (server : CodeflashLanguageServer , _params : FunctionOptimizationParams ) -> dict [str , str ]:
138+ from codeflash .cli_cmds .cmd_init import is_valid_pyproject_toml
139+
140+ server .show_message_log ("Validating project..." , "Info" )
141+ config = is_valid_pyproject_toml (server .args .config_file )
142+ if config is None :
143+ server .show_message_log ("pyproject.toml is not valid" , "Error" )
144+ return {
145+ "status" : "error" ,
146+ "message" : "pyproject.toml is not valid" , # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions
147+ }
148+
149+ args = process_args (server )
150+ repo = git .Repo (args .module_root , search_parent_directories = True )
151+ if repo .bare :
152+ return {"status" : "error" , "message" : "Repository is in bare state" }
153+
154+ try :
155+ _ = repo .head .commit
156+ except Exception :
157+ return {"status" : "error" , "message" : "Repository has no commits (unborn HEAD)" }
158+
159+ return {"status" : "success" }
160+
161+
162+ def _initialize_optimizer_if_api_key_is_valid (server : CodeflashLanguageServer ) -> dict [str , str ]:
133163 user_id = get_user_id ()
134164 if user_id is None :
135165 return {"status" : "error" , "message" : "api key not found or invalid" }
@@ -140,14 +170,24 @@ def _initialize_optimizer_if_valid(server: CodeflashLanguageServer) -> dict[str,
140170
141171 from codeflash .optimization .optimizer import Optimizer
142172
143- server .optimizer = Optimizer (server .args )
173+ new_args = process_args (server )
174+ server .optimizer = Optimizer (new_args )
144175 return {"status" : "success" , "user_id" : user_id }
145176
146177
178+ def process_args (server : CodeflashLanguageServer ) -> Namespace :
179+ if server .args_processed_before :
180+ return server .args
181+ new_args = process_pyproject_config (server .args )
182+ server .args = new_args
183+ server .args_processed_before = True
184+ return new_args
185+
186+
147187@server .feature ("apiKeyExistsAndValid" )
148188def check_api_key (server : CodeflashLanguageServer , _params : any ) -> dict [str , str ]:
149189 try :
150- return _initialize_optimizer_if_valid (server )
190+ return _initialize_optimizer_if_api_key_is_valid (server )
151191 except Exception :
152192 return {"status" : "error" , "message" : "something went wrong while validating the api key" }
153193
@@ -167,7 +207,7 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
167207 get_codeflash_api_key .cache_clear ()
168208 get_user_id .cache_clear ()
169209
170- init_result = _initialize_optimizer_if_valid (server )
210+ init_result = _initialize_optimizer_if_api_key_is_valid (server )
171211 if init_result ["status" ] == "error" :
172212 return {"status" : "error" , "message" : "Api key is not valid" }
173213
0 commit comments