Skip to content
4 changes: 4 additions & 0 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.version import __version__ as version


Expand Down Expand Up @@ -211,6 +212,9 @@ def process_pyproject_config(args: Namespace) -> Namespace:
if args.benchmarks_root:
args.benchmarks_root = Path(args.benchmarks_root).resolve()
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path)
if is_LSP_enabled():
args.all = None
return args
return handle_optimize_all_arg_parsing(args)


Expand Down
32 changes: 23 additions & 9 deletions codeflash/cli_cmds/cmd_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,22 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)


def is_valid_pyproject_toml(pyproject_toml_path: Path) -> dict[str, Any] | None:
if not pyproject_toml_path.exists():
return None
try:
config, _ = parse_config_file(pyproject_toml_path)
except Exception:
return None

if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
return None
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
return None

return config


def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
"""Check if the current directory contains a valid pyproject.toml file with codeflash config.

Expand All @@ -163,16 +179,9 @@ def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
from rich.prompt import Confirm

pyproject_toml_path = Path.cwd() / "pyproject.toml"
if not pyproject_toml_path.exists():
return True, None
try:
config, config_file_path = parse_config_file(pyproject_toml_path)
except Exception:
return True, None

if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
return True, None
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
config = is_valid_pyproject_toml(pyproject_toml_path)
if config is None:
return True, None

return Confirm.ask(
Expand Down Expand Up @@ -968,6 +977,11 @@ def install_github_app(git_remote: str) -> None:
except git.InvalidGitRepositoryError:
click.echo("Skipping GitHub app installation because you're not in a git repository.")
return

if git_remote not in get_git_remotes(git_repo):
click.echo(f"Skipping GitHub app installation, remote ({git_remote}) does not exist in this repository.")
return

owner, repo = get_repo_owner_and_name(git_repo, git_remote)

if is_github_app_installed_on_repo(owner, repo, suppress_errors=True):
Expand Down
14 changes: 12 additions & 2 deletions codeflash/code_utils/git_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
import re
import shutil
import subprocess
import sys
Expand Down Expand Up @@ -201,7 +202,8 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:

def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
repository = git.Repo(worktree_dir, search_parent_directories=True)
repository.git.commit("-am", commit_message, "--no-verify")
repository.git.add(".")
repository.git.commit("-m", commit_message, "--no-verify")


def create_detached_worktree(module_root: Path) -> Optional[Path]:
Expand All @@ -220,6 +222,14 @@ def create_detached_worktree(module_root: Path) -> Optional[Path]:
repository.git.add("-N", ".") # add the index for untracked files to be included in the diff
uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)

# HACK: remove binary files from the diff, find a better way # noqa: FIX004
uni_diff_text = re.sub(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rather than regex we can us git diff exclude and list all the exclude file types. or we can alter the .gitignore or gitattributes file too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or we can also go with --diff-filter=AMT

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I am thinking either we can just change the file type to text rather than excluding them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah exclude sounds like a good option

Also I am thinking either we can just change the file type to text rather than excluding them?

actually this shouldn't happen in the first place, git can't apply patches with binary diffs, so this would be the user mistake if he forgot to ignore them

r"^diff --git a\/.*?\.(?:pyc|class|jar|exe|dll|so|dylib|o|obj|bin|pdf|jpg|jpeg|png|gif|zip|tar|gz) b\/.*?\.\w+.*?\n(?:.*?\n)*?(?=diff --git|\Z)",
"",
uni_diff_text,
flags=re.MULTILINE,
)

if not uni_diff_text.strip():
logger.info("No uncommitted changes to copy to worktree.")
return worktree_dir
Expand All @@ -234,7 +244,7 @@ def create_detached_worktree(module_root: Path) -> Optional[Path]:
# Apply the patch inside the worktree
try:
subprocess.run(
["git", "apply", "--ignore-space-change", "--ignore-whitespace", patch_path],
["git", "apply", "--ignore-space-change", "--ignore-whitespace", "--whitespace=nowarn", patch_path],
cwd=worktree_dir,
check=True,
)
Expand Down
9 changes: 9 additions & 0 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.models.models import FunctionParent
from codeflash.telemetry.posthog_cf import ph

Expand Down Expand Up @@ -168,6 +169,7 @@ def get_functions_to_optimize(
)
functions: dict[str, list[FunctionToOptimize]]
trace_file_path: Path | None = None
is_lsp = is_LSP_enabled()
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=SyntaxWarning)
if optimize_all:
Expand All @@ -185,6 +187,8 @@ def get_functions_to_optimize(
if only_get_this_function is not None:
split_function = only_get_this_function.split(".")
if len(split_function) > 2:
if is_lsp:
return functions, 0, None
exit_with_message(
"Function name should be in the format 'function_name' or 'class_name.function_name'"
)
Expand All @@ -200,6 +204,8 @@ def get_functions_to_optimize(
):
found_function = fn
if found_function is None:
if is_lsp:
return functions, 0, None
exit_with_message(
f"Function {only_function_name} not found in file {file}\nor the function does not have a 'return' statement or is a property"
)
Expand Down Expand Up @@ -470,6 +476,9 @@ def was_function_previously_optimized(
Tuple of (filtered_functions_dict, remaining_count)

"""
if is_LSP_enabled():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: could this raise questions on underterministic nature of optimizations if the function has is same.

return False

# Check optimization status if repository info is provided
# already_optimized_count = 0
try:
Expand Down
68 changes: 55 additions & 13 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,20 @@
from pathlib import Path
from typing import TYPE_CHECKING

import git
from pygls import uris

from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
from codeflash.cli_cmds.cli import process_pyproject_config
from codeflash.code_utils.git_utils import create_diff_patch_from_worktree
from codeflash.code_utils.shell_utils import save_api_key_to_rc
from codeflash.discovery.functions_to_optimize import filter_functions, get_functions_within_git_diff
from codeflash.either import is_successful
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol

if TYPE_CHECKING:
from argparse import Namespace

from lsprotocol import types


Expand Down Expand Up @@ -85,9 +89,14 @@ def initialize_function_optimization(
) -> dict[str, str]:
file_path = Path(uris.to_fs_path(params.textDocument.uri))
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")

if server.optimizer is None:
_initialize_optimizer_if_valid(server)
_initialize_optimizer_if_api_key_is_valid(server)

server.optimizer.worktree_mode()
# make sure the tests dir is created in the worktree, this can happen if the original tests dir is empty
Path(server.optimizer.args.tests_root).mkdir(parents=True, exist_ok=True)

original_args, _ = server.optimizer.original_args_and_test_cfg

server.optimizer.args.function = params.functionName
Expand All @@ -99,15 +108,12 @@ def initialize_function_optimization(
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
)

optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()
if not optimizable_funcs:
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()

if count == 0:
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
return {
"functionName": params.functionName,
"status": "error",
"message": "function is no found or not optimizable",
"args": None,
}
cleanup_the_optimizer(server)
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}

fto = optimizable_funcs.popitem()[1][0]
server.optimizer.current_function_being_optimized = fto
Expand All @@ -129,7 +135,33 @@ def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOpt
return {"functionName": params.functionName, "status": "success", "discovered_tests": num_discovered_tests}


def _initialize_optimizer_if_valid(server: CodeflashLanguageServer) -> dict[str, str]:
@server.feature("validateProject")
def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizationParams) -> dict[str, str]:
from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml

server.show_message_log("Validating project...", "Info")
config = is_valid_pyproject_toml(server.args.config_file)
if config is None:
server.show_message_log("pyproject.toml is not valid", "Error")
return {
"status": "error",
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions
}

args = process_args(server)
repo = git.Repo(args.module_root, search_parent_directories=True)
if repo.bare:
return {"status": "error", "message": "Repository is in bare state"}

try:
_ = repo.head.commit
except Exception:
return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}

return {"status": "success"}


def _initialize_optimizer_if_api_key_is_valid(server: CodeflashLanguageServer) -> dict[str, str]:
user_id = get_user_id()
if user_id is None:
return {"status": "error", "message": "api key not found or invalid"}
Expand All @@ -140,14 +172,24 @@ def _initialize_optimizer_if_valid(server: CodeflashLanguageServer) -> dict[str,

from codeflash.optimization.optimizer import Optimizer

server.optimizer = Optimizer(server.args)
new_args = process_args(server)
server.optimizer = Optimizer(new_args)
return {"status": "success", "user_id": user_id}


def process_args(server: CodeflashLanguageServer) -> Namespace:
if server.args_processed_before:
return server.args
new_args = process_pyproject_config(server.args)
server.args = new_args
server.args_processed_before = True
return new_args


@server.feature("apiKeyExistsAndValid")
def check_api_key(server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
try:
return _initialize_optimizer_if_valid(server)
return _initialize_optimizer_if_api_key_is_valid(server)
except Exception:
return {"status": "error", "message": "something went wrong while validating the api key"}

Expand All @@ -167,7 +209,7 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
get_codeflash_api_key.cache_clear()
get_user_id.cache_clear()

init_result = _initialize_optimizer_if_valid(server)
init_result = _initialize_optimizer_if_api_key_is_valid(server)
if init_result["status"] == "error":
return {"status": "error", "message": "Api key is not valid"}

Expand Down
4 changes: 2 additions & 2 deletions codeflash/lsp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ class CodeflashLanguageServer(LanguageServer):
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
super().__init__(*args, **kwargs)
self.optimizer: Optimizer | None = None
self.args_processed_before: bool = False
self.args = None

def prepare_optimizer_arguments(self, config_file: Path) -> None:
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
from codeflash.cli_cmds.cli import parse_args

args = parse_args()
args.config_file = config_file
args.no_pr = True # LSP server should not create PRs
args.worktree = True
args = process_pyproject_config(args)
self.args = args
# avoid initializing the optimizer during initialization, because it can cause an error if the api key is invalid

Expand Down
Loading