Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ def get_optimizable_functions(
return path_to_qualified_names


def _find_pyproject_toml(workspace_path: str) -> Path | None:
def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:
workspace_path_obj = Path(workspace_path)
max_depth = 2
base_depth = len(workspace_path_obj.parts)
top_level_pyproject = None

for root, dirs, files in os.walk(workspace_path_obj):
depth = len(Path(root).parts) - base_depth
Expand All @@ -145,32 +146,39 @@ def _find_pyproject_toml(workspace_path: str) -> Path | None:

if "pyproject.toml" in files:
file_path = Path(root) / "pyproject.toml"
if depth == 0:
top_level_pyproject = file_path
with file_path.open("r", encoding="utf-8", errors="ignore") as f:
for line in f:
if line.strip() == "[tool.codeflash]":
return file_path.resolve()
return None
return file_path.resolve(), True
return top_level_pyproject, False


# should be called the first thing to initialize and validate the project
@server.feature("initProject")
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]: # noqa: PLR0911
from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml

pyproject_toml_path: Path | None = getattr(params, "config_file", None)
pyproject_toml_path: Path | None = getattr(params, "config_file", None) or getattr(server.args, "config_file", None)

if server.args is None:
if pyproject_toml_path is not None:
# if there is a config file provided use it
if pyproject_toml_path is not None:
# if there is a config file provided use it
server.prepare_optimizer_arguments(pyproject_toml_path)
else:
# otherwise look for it
pyproject_toml_path, has_codeflash_config = _find_pyproject_toml(params.root_path_abs)
if pyproject_toml_path and has_codeflash_config:
server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info")
server.prepare_optimizer_arguments(pyproject_toml_path)
elif pyproject_toml_path and not has_codeflash_config:
return {
"status": "error",
"message": "pyproject.toml found in workspace, but no codeflash config.",
"pyprojectPath": pyproject_toml_path,
}
else:
# otherwise look for it
pyproject_toml_path = _find_pyproject_toml(params.root_path_abs)
server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info")
if pyproject_toml_path:
server.prepare_optimizer_arguments(pyproject_toml_path)
else:
return {"status": "error", "message": "No pyproject.toml found in workspace."}
return {"status": "error", "message": "No pyproject.toml found in workspace."}

# since we are using worktrees, optimization diffs are generated with respect to the root of the repo.
root = str(git_root_dir())
Expand All @@ -187,10 +195,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
config = is_valid_pyproject_toml(pyproject_toml_path)
if config is None:
server.show_message_log("pyproject.toml is not valid", "Error")
return {
"status": "error",
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions,
}
return {"status": "error", "message": "not valid", "pyprojectPath": pyproject_toml_path}

args = process_args(server)
repo = git.Repo(args.module_root, search_parent_directories=True)
Expand Down
Loading