diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 28a6ec655..57a65723a 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -131,10 +131,11 @@ def get_optimizable_functions( return path_to_qualified_names -def _find_pyproject_toml(workspace_path: str) -> Path | None: +def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]: workspace_path_obj = Path(workspace_path) max_depth = 2 base_depth = len(workspace_path_obj.parts) + top_level_pyproject = None for root, dirs, files in os.walk(workspace_path_obj): depth = len(Path(root).parts) - base_depth @@ -145,32 +146,39 @@ def _find_pyproject_toml(workspace_path: str) -> Path | None: if "pyproject.toml" in files: file_path = Path(root) / "pyproject.toml" + if depth == 0: + top_level_pyproject = file_path with file_path.open("r", encoding="utf-8", errors="ignore") as f: for line in f: if line.strip() == "[tool.codeflash]": - return file_path.resolve() - return None + return file_path.resolve(), True + return top_level_pyproject, False # should be called the first thing to initialize and validate the project @server.feature("initProject") -def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]: +def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]: # noqa: PLR0911 from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml - pyproject_toml_path: Path | None = getattr(params, "config_file", None) + pyproject_toml_path: Path | None = getattr(params, "config_file", None) or getattr(server.args, "config_file", None) - if server.args is None: - if pyproject_toml_path is not None: - # if there is a config file provided use it + if pyproject_toml_path is not None: + # if there is a config file provided use it + server.prepare_optimizer_arguments(pyproject_toml_path) + else: + # otherwise look for it + pyproject_toml_path, has_codeflash_config = _find_pyproject_toml(params.root_path_abs) + if pyproject_toml_path and has_codeflash_config: + server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info") server.prepare_optimizer_arguments(pyproject_toml_path) + elif pyproject_toml_path and not has_codeflash_config: + return { + "status": "error", + "message": "pyproject.toml found in workspace, but no codeflash config.", + "pyprojectPath": pyproject_toml_path, + } else: - # otherwise look for it - pyproject_toml_path = _find_pyproject_toml(params.root_path_abs) - server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info") - if pyproject_toml_path: - server.prepare_optimizer_arguments(pyproject_toml_path) - else: - return {"status": "error", "message": "No pyproject.toml found in workspace."} + return {"status": "error", "message": "No pyproject.toml found in workspace."} # since we are using worktrees, optimization diffs are generated with respect to the root of the repo. root = str(git_root_dir()) @@ -187,10 +195,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) config = is_valid_pyproject_toml(pyproject_toml_path) if config is None: server.show_message_log("pyproject.toml is not valid", "Error") - return { - "status": "error", - "message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions, - } + return {"status": "error", "message": "not valid", "pyprojectPath": pyproject_toml_path} args = process_args(server) repo = git.Repo(args.module_root, search_parent_directories=True)