Skip to content
4 changes: 4 additions & 0 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.version import __version__ as version


Expand Down Expand Up @@ -211,6 +212,9 @@ def process_pyproject_config(args: Namespace) -> Namespace:
if args.benchmarks_root:
args.benchmarks_root = Path(args.benchmarks_root).resolve()
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path)
if is_LSP_enabled():
args.all = None
return args
return handle_optimize_all_arg_parsing(args)


Expand Down
32 changes: 23 additions & 9 deletions codeflash/cli_cmds/cmd_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,22 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)


def is_valid_pyproject_toml(pyproject_toml_path: Path) -> dict[str, Any] | None:
    """Return the parsed codeflash config if ``pyproject_toml_path`` is usable.

    The file is considered usable when it exists, ``parse_config_file`` reads
    it without raising, and both ``module_root`` and ``tests_root`` are set to
    paths that are existing directories.

    Args:
        pyproject_toml_path: Location of the ``pyproject.toml`` to inspect.

    Returns:
        The config mapping produced by ``parse_config_file``, or ``None`` when
        any of the checks above fails.
    """
    if not pyproject_toml_path.exists():
        return None

    try:
        config, _config_path = parse_config_file(pyproject_toml_path)
    except Exception:
        # Any parsing failure simply means "not a valid config" to callers.
        return None

    # Both roots must be present, non-null and point at real directories.
    for root_key in ("module_root", "tests_root"):
        if root_key not in config or config[root_key] is None:
            return None
        if not Path(config[root_key]).is_dir():
            return None

    return config


def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
"""Check if the current directory contains a valid pyproject.toml file with codeflash config.

Expand All @@ -163,16 +179,9 @@ def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
from rich.prompt import Confirm

pyproject_toml_path = Path.cwd() / "pyproject.toml"
if not pyproject_toml_path.exists():
return True, None
try:
config, config_file_path = parse_config_file(pyproject_toml_path)
except Exception:
return True, None

if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
return True, None
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
config = is_valid_pyproject_toml(pyproject_toml_path)
if config is None:
return True, None

return Confirm.ask(
Expand Down Expand Up @@ -968,6 +977,11 @@ def install_github_app(git_remote: str) -> None:
except git.InvalidGitRepositoryError:
click.echo("Skipping GitHub app installation because you're not in a git repository.")
return

if git_remote not in get_git_remotes(git_repo):
click.echo(f"Skipping GitHub app installation, remote ({git_remote}) does not exist in this repository.")
return

owner, repo = get_repo_owner_and_name(git_repo, git_remote)

if is_github_app_installed_on_repo(owner, repo, suppress_errors=True):
Expand Down
10 changes: 7 additions & 3 deletions codeflash/code_utils/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:

def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
    """Record the current state of the worktree as one snapshot commit.

    Args:
        worktree_dir: A directory inside the worktree; the owning repository is
            located by searching parent directories.
        commit_message: Message to use for the snapshot commit.
    """
    repository = git.Repo(worktree_dir, search_parent_directories=True)
    # Stage everything explicitly so that untracked files are part of the
    # snapshot too; `commit -a` alone only picks up already-tracked files.
    repository.git.add(".")
    # Commit exactly once. --no-verify bypasses the repository's commit hooks.
    repository.git.commit("-m", commit_message, "--no-verify")


def create_detached_worktree(module_root: Path) -> Optional[Path]:
Expand All @@ -218,7 +219,10 @@ def create_detached_worktree(module_root: Path) -> Optional[Path]:

# Get uncommitted diff from the original repo
repository.git.add("-N", ".") # add the index for untracked files to be included in the diff
uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"] # fmt: off
uni_diff_text = repository.git.diff(
None, "HEAD", "--", *exclude_binary_files, ignore_blank_lines=True, ignore_space_at_eol=True
)

if not uni_diff_text.strip():
logger.info("No uncommitted changes to copy to worktree.")
Expand All @@ -234,7 +238,7 @@ def create_detached_worktree(module_root: Path) -> Optional[Path]:
# Apply the patch inside the worktree
try:
subprocess.run(
["git", "apply", "--ignore-space-change", "--ignore-whitespace", patch_path],
["git", "apply", "--ignore-space-change", "--ignore-whitespace", "--whitespace=nowarn", patch_path],
cwd=worktree_dir,
check=True,
)
Expand Down
9 changes: 9 additions & 0 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.models.models import FunctionParent
from codeflash.telemetry.posthog_cf import ph

Expand Down Expand Up @@ -168,6 +169,7 @@ def get_functions_to_optimize(
)
functions: dict[str, list[FunctionToOptimize]]
trace_file_path: Path | None = None
is_lsp = is_LSP_enabled()
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=SyntaxWarning)
if optimize_all:
Expand All @@ -185,6 +187,8 @@ def get_functions_to_optimize(
if only_get_this_function is not None:
split_function = only_get_this_function.split(".")
if len(split_function) > 2:
if is_lsp:
return functions, 0, None
exit_with_message(
"Function name should be in the format 'function_name' or 'class_name.function_name'"
)
Expand All @@ -200,6 +204,8 @@ def get_functions_to_optimize(
):
found_function = fn
if found_function is None:
if is_lsp:
return functions, 0, None
exit_with_message(
f"Function {only_function_name} not found in file {file}\nor the function does not have a 'return' statement or is a property"
)
Expand Down Expand Up @@ -470,6 +476,9 @@ def was_function_previously_optimized(
Tuple of (filtered_functions_dict, remaining_count)

"""
if is_LSP_enabled():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: could this raise questions about the nondeterministic nature of optimizations if the function hash is the same?

return False

# Check optimization status if repository info is provided
# already_optimized_count = 0
try:
Expand Down
66 changes: 53 additions & 13 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,20 @@
from pathlib import Path
from typing import TYPE_CHECKING

import git
from pygls import uris

from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
from codeflash.cli_cmds.cli import process_pyproject_config
from codeflash.code_utils.git_utils import create_diff_patch_from_worktree
from codeflash.code_utils.shell_utils import save_api_key_to_rc
from codeflash.discovery.functions_to_optimize import filter_functions, get_functions_within_git_diff
from codeflash.either import is_successful
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol

if TYPE_CHECKING:
from argparse import Namespace

from lsprotocol import types


Expand Down Expand Up @@ -85,9 +89,12 @@ def initialize_function_optimization(
) -> dict[str, str]:
file_path = Path(uris.to_fs_path(params.textDocument.uri))
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")

if server.optimizer is None:
_initialize_optimizer_if_valid(server)
_initialize_optimizer_if_api_key_is_valid(server)

server.optimizer.worktree_mode()

original_args, _ = server.optimizer.original_args_and_test_cfg

server.optimizer.args.function = params.functionName
Expand All @@ -99,15 +106,12 @@ def initialize_function_optimization(
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
)

optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()
if not optimizable_funcs:
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()

if count == 0:
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
return {
"functionName": params.functionName,
"status": "error",
"message": "function is no found or not optimizable",
"args": None,
}
cleanup_the_optimizer(server)
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}

fto = optimizable_funcs.popitem()[1][0]
server.optimizer.current_function_being_optimized = fto
Expand All @@ -129,7 +133,33 @@ def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOpt
return {"functionName": params.functionName, "status": "success", "discovered_tests": num_discovered_tests}


def _initialize_optimizer_if_valid(server: CodeflashLanguageServer) -> dict[str, str]:
@server.feature("validateProject")
def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizationParams) -> dict[str, str]:
    """Check that the project is ready to be optimized.

    Validation steps, in order:
      1. ``pyproject.toml`` (``server.args.config_file``) passes
         ``is_valid_pyproject_toml``.
      2. The module root lives inside a git repository.
      3. The repository is not bare.
      4. The repository has at least one commit.

    Args:
        server: The language server handling the request.
        _params: Request parameters; not used.

    Returns:
        ``{"status": "success"}`` when every check passes, otherwise
        ``{"status": "error", "message": ...}`` describing the first failure.
    """
    from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml

    server.show_message_log("Validating project...", "Info")
    config = is_valid_pyproject_toml(server.args.config_file)
    if config is None:
        server.show_message_log("pyproject.toml is not valid", "Error")
        return {
            "status": "error",
            "message": "pyproject.toml is not valid",  # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions
        }

    args = process_args(server)
    try:
        repo = git.Repo(args.module_root, search_parent_directories=True)
    except (git.InvalidGitRepositoryError, git.NoSuchPathError):
        # Report a regular error status instead of letting the exception
        # escape the request handler when the project is not under git.
        return {"status": "error", "message": "Project is not inside a git repository"}

    if repo.bare:
        return {"status": "error", "message": "Repository is in bare state"}

    try:
        # Resolving HEAD fails when the repository has no commits yet.
        _ = repo.head.commit
    except Exception:
        return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}

    return {"status": "success"}


def _initialize_optimizer_if_api_key_is_valid(server: CodeflashLanguageServer) -> dict[str, str]:
user_id = get_user_id()
if user_id is None:
return {"status": "error", "message": "api key not found or invalid"}
Expand All @@ -140,14 +170,24 @@ def _initialize_optimizer_if_valid(server: CodeflashLanguageServer) -> dict[str,

from codeflash.optimization.optimizer import Optimizer

server.optimizer = Optimizer(server.args)
new_args = process_args(server)
server.optimizer = Optimizer(new_args)
return {"status": "success", "user_id": user_id}


def process_args(server: CodeflashLanguageServer) -> Namespace:
    """Return the server's args, running pyproject processing at most once.

    On the first call ``server.args`` is passed through
    ``process_pyproject_config``, the result is stored back on the server and
    ``server.args_processed_before`` is set. Later calls return the stored
    args unchanged.

    Args:
        server: The language server whose ``args`` should be processed.

    Returns:
        The processed argument namespace stored on ``server``.
    """
    if not server.args_processed_before:
        server.args = process_pyproject_config(server.args)
        server.args_processed_before = True
    return server.args


@server.feature("apiKeyExistsAndValid")
def check_api_key(server: CodeflashLanguageServer, _params: object) -> dict[str, str]:
    """Report whether a valid API key is available.

    Delegates to ``_initialize_optimizer_if_api_key_is_valid`` and returns its
    result. Any exception raised during that call is converted into an error
    status so the request always gets a well-formed response.

    Args:
        server: The language server handling the request.
        _params: Request parameters; not used.

    Returns:
        The status dict from ``_initialize_optimizer_if_api_key_is_valid``, or
        an error status dict if that call raised.
    """
    try:
        return _initialize_optimizer_if_api_key_is_valid(server)
    except Exception:
        # Request-handler boundary: answer with an error status rather than
        # letting the exception propagate.
        return {"status": "error", "message": "something went wrong while validating the api key"}

Expand All @@ -167,7 +207,7 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
get_codeflash_api_key.cache_clear()
get_user_id.cache_clear()

init_result = _initialize_optimizer_if_valid(server)
init_result = _initialize_optimizer_if_api_key_is_valid(server)
if init_result["status"] == "error":
return {"status": "error", "message": "Api key is not valid"}

Expand Down
4 changes: 2 additions & 2 deletions codeflash/lsp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ class CodeflashLanguageServer(LanguageServer):
def __init__(self, *args: Any, **kwargs: Any) -> None:  # noqa: ANN401
    """Set up the language server with no optimizer and no args yet.

    All positional and keyword arguments are forwarded to the base class.
    """
    super().__init__(*args, **kwargs)
    # Populated later by prepare_optimizer_arguments.
    self.args = None
    # False until the args have been processed for the first time.
    self.args_processed_before: bool = False
    # Stays None until an optimizer is created and assigned elsewhere.
    self.optimizer: Optimizer | None = None

def prepare_optimizer_arguments(self, config_file: Path) -> None:
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
from codeflash.cli_cmds.cli import parse_args

args = parse_args()
args.config_file = config_file
args.no_pr = True # LSP server should not create PRs
args.worktree = True
args = process_pyproject_config(args)
self.args = args
# avoid initializing the optimizer during initialization, because it can cause an error if the api key is invalid

Expand Down
2 changes: 2 additions & 0 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ def worktree_mode(self) -> None:
return
self.current_worktree = worktree_dir
self.mutate_args_for_worktree_mode(worktree_dir)
# make sure the tests dir is created in the worktree, this can happen if the original tests dir is empty
Path(self.args.tests_root).mkdir(parents=True, exist_ok=True)

def mutate_args_for_worktree_mode(self, worktree_dir: Path) -> None:
saved_args = copy.deepcopy(self.args)
Expand Down
Loading