Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions codeflash/code_utils/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@
from git import Repo


def get_git_diff(repo_directory: Path | None = None, *, uncommitted_changes: bool = False) -> dict[str, list[int]]:
def get_git_diff(
repo_directory: Path | None = None, *, only_this_commit: Optional[str] = None, uncommitted_changes: bool = False
) -> dict[str, list[int]]:
if repo_directory is None:
repo_directory = Path.cwd()
repository = git.Repo(repo_directory, search_parent_directories=True)
commit = repository.head.commit
if uncommitted_changes:
if only_this_commit:
uni_diff_text = repository.git.diff(
only_this_commit + "^1", only_this_commit, ignore_blank_lines=True, ignore_space_at_eol=True
)
elif uncommitted_changes:
uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
else:
uni_diff_text = repository.git.diff(
Expand Down
15 changes: 12 additions & 3 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,16 @@ def get_functions_to_optimize(

def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]: # noqa: FBT001
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
modified_functions: dict[str, list[FunctionToOptimize]] = {}
return get_functions_within_lines(modified_lines)


def get_functions_inside_a_commit(commit_hash: str) -> dict[str, list[FunctionToOptimize]]:
modified_lines: dict[str, list[int]] = get_git_diff(only_this_commit=commit_hash)
return get_functions_within_lines(modified_lines)


def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str, list[FunctionToOptimize]]:
functions: dict[str, list[FunctionToOptimize]] = {}
for path_str, lines_in_file in modified_lines.items():
path = Path(path_str)
if not path.exists():
Expand All @@ -246,14 +255,14 @@ def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[F
continue
function_lines = FunctionVisitor(file_path=str(path))
wrapper.visit(function_lines)
modified_functions[str(path)] = [
functions[str(path)] = [
function_to_optimize
for function_to_optimize in function_lines.functions
if (start_line := function_to_optimize.starting_line) is not None
and (end_line := function_to_optimize.ending_line) is not None
and any(start_line <= line <= end_line for line in lines_in_file)
]
return modified_functions
return functions


def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[FunctionToOptimize]]:
Expand Down
31 changes: 29 additions & 2 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from codeflash.cli_cmds.cli import process_pyproject_config
from codeflash.code_utils.git_utils import create_diff_patch_from_worktree
from codeflash.code_utils.shell_utils import save_api_key_to_rc
from codeflash.discovery.functions_to_optimize import filter_functions, get_functions_within_git_diff
from codeflash.discovery.functions_to_optimize import (
filter_functions,
get_functions_inside_a_commit,
get_functions_within_git_diff,
)
from codeflash.either import is_successful
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol

Expand All @@ -22,6 +26,8 @@

from lsprotocol import types

from codeflash.discovery.functions_to_optimize import FunctionToOptimize


@dataclass
class OptimizableFunctionsParams:
Expand All @@ -39,6 +45,11 @@ class ProvideApiKeyParams:
api_key: str


@dataclass
class OptimizableFunctionsInCommitParams:
commit_hash: str


server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)


Expand All @@ -47,6 +58,22 @@ def get_functions_in_current_git_diff(
server: CodeflashLanguageServer, _params: OptimizableFunctionsParams
) -> dict[str, str | dict[str, list[str]]]:
functions = get_functions_within_git_diff(uncommitted_changes=True)
file_to_qualified_names = _group_functions_by_file(server, functions)
return {"functions": file_to_qualified_names, "status": "success"}


@server.feature("getOptimizableFunctionsInCommit")
def get_functions_in_commit(
server: CodeflashLanguageServer, params: OptimizableFunctionsInCommitParams
) -> dict[str, str | dict[str, list[str]]]:
functions = get_functions_inside_a_commit(params.commit_hash)
file_to_qualified_names = _group_functions_by_file(server, functions)
return {"functions": file_to_qualified_names, "status": "success"}


def _group_functions_by_file(
server: CodeflashLanguageServer, functions: dict[str, list[FunctionToOptimize]]
) -> dict[str, list[str]]:
file_to_funcs_to_optimize, _ = filter_functions(
modified_functions=functions,
tests_root=server.optimizer.test_cfg.tests_root,
Expand All @@ -58,7 +85,7 @@ def get_functions_in_current_git_diff(
file_to_qualified_names: dict[str, list[str]] = {
str(path): [f.qualified_name for f in funcs] for path, funcs in file_to_funcs_to_optimize.items()
}
return {"functions": file_to_qualified_names, "status": "success"}
return file_to_qualified_names


@server.feature("getOptimizableFunctions")
Expand Down
Loading