@@ -50,6 +50,13 @@ class ProvideApiKeyParams:
5050 api_key : str
5151
5252
53+ @dataclass
54+ class ValidateProjectParams :
55+ root_path_abs : str
56+ config_file : Optional [str ] = None
57+ skip_validation : bool = False
58+
59+
5360@dataclass
5461class OnPatchAppliedParams :
5562 patch_id : str
@@ -160,17 +167,60 @@ def initialize_function_optimization(
160167 return {"functionName" : params .functionName , "status" : "success" }
161168
162169
163- @server .feature ("validateProject" )
164- def validate_project (server : CodeflashLanguageServer , _params : FunctionOptimizationParams ) -> dict [str , str ]:
170+ def _find_pyproject_toml (workspace_path : str ) -> Path | None :
171+ workspace_path_obj = Path (workspace_path )
172+ max_depth = 2
173+ base_depth = len (workspace_path_obj .parts )
174+
175+ for root , dirs , files in os .walk (workspace_path_obj ):
176+ depth = len (Path (root ).parts ) - base_depth
177+ if depth > max_depth :
178+ # stop going deeper into this branch
179+ dirs .clear ()
180+ continue
181+
182+ if "pyproject.toml" in files :
183+ file_path = Path (root ) / "pyproject.toml"
184+ with file_path .open ("r" , encoding = "utf-8" , errors = "ignore" ) as f :
185+ for line in f :
186+ if line .strip () == "[tool.codeflash]" :
187+ return file_path .resolve ()
188+ return None
189+
190+
191+ # should be called the first thing to initialize and validate the project
192+ @server .feature ("initProject" )
193+ def init_project (server : CodeflashLanguageServer , params : ValidateProjectParams ) -> dict [str , str ]:
165194 from codeflash .cli_cmds .cmd_init import is_valid_pyproject_toml
166195
196+ pyproject_toml_path : Path | None = getattr (params , "config_file" , None )
197+
198+ if server .args is None :
199+ if pyproject_toml_path is not None :
200+ # if there is a config file provided use it
201+ server .prepare_optimizer_arguments (pyproject_toml_path )
202+ else :
203+ # otherwise look for it
204+ pyproject_toml_path = _find_pyproject_toml (params .root_path_abs )
205+ server .show_message_log (f"Found pyproject.toml at: { pyproject_toml_path } " , "Info" )
206+ if pyproject_toml_path :
207+ server .prepare_optimizer_arguments (pyproject_toml_path )
208+ else :
209+ return {
210+ "status" : "error" ,
211+ "message" : "No pyproject.toml found in workspace." ,
212+ } # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth
213+
214+ if getattr (params , "skip_validation" , False ):
215+ return {"status" : "success" , "moduleRoot" : server .args .module_root , "pyprojectPath" : pyproject_toml_path }
216+
167217 server .show_message_log ("Validating project..." , "Info" )
168- config = is_valid_pyproject_toml (server . args . config_file )
218+ config = is_valid_pyproject_toml (pyproject_toml_path )
169219 if config is None :
170220 server .show_message_log ("pyproject.toml is not valid" , "Error" )
171221 return {
172222 "status" : "error" ,
173- "message" : "pyproject.toml is not valid" , # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions
223+ "message" : "pyproject.toml is not valid" , # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions,
174224 }
175225
176226 args = process_args (server )
@@ -183,7 +233,7 @@ def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizat
183233 except Exception :
184234 return {"status" : "error" , "message" : "Repository has no commits (unborn HEAD)" }
185235
186- return {"status" : "success" , "moduleRoot" : args .module_root }
236+ return {"status" : "success" , "moduleRoot" : args .module_root , "pyprojectPath" : pyproject_toml_path }
187237
188238
189239def _initialize_optimizer_if_api_key_is_valid (
0 commit comments