Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.git_utils import git_root_dir
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.version import __version__ as version

Expand Down Expand Up @@ -223,20 +222,18 @@ def process_pyproject_config(args: Namespace) -> Namespace:
args.module_root = Path(args.module_root).resolve()
# If module-root is "." then all imports are relatives to it.
# in this case, the ".." becomes outside project scope, causing issues with un-importable paths
args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path, args.worktree)
args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path)
args.tests_root = Path(args.tests_root).resolve()
if args.benchmarks_root:
args.benchmarks_root = Path(args.benchmarks_root).resolve()
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path, args.worktree)
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path)
if is_LSP_enabled():
args.all = None
return args
return handle_optimize_all_arg_parsing(args)


def project_root_from_module_root(module_root: Path, pyproject_file_path: Path, in_worktree: bool = False) -> Path: # noqa: FBT001, FBT002
if in_worktree:
return git_root_dir()
def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) -> Path:
if pyproject_file_path.parent == module_root:
return module_root
return module_root.parent.resolve()
Expand Down
7 changes: 1 addition & 6 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,24 +115,19 @@ def get_optimizable_functions(
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
) -> dict[str, list[str]]:
file_path = Path(uris.to_fs_path(params.textDocument.uri))
server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info")
if not server.optimizer:
return {"status": "error", "message": "optimizer not initialized"}

server.optimizer.args.file = file_path
server.optimizer.args.function = None # Always get ALL functions, not just one
server.optimizer.args.previous_checkpoint_functions = False

server.show_message_log(f"Calling get_optimizable_functions for {server.optimizer.args.file}...", "Info")
optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()

path_to_qualified_names = {}
for functions in optimizable_funcs.values():
path_to_qualified_names[file_path] = [func.qualified_name for func in functions]

server.show_message_log(
f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info"
)
return path_to_qualified_names


Expand Down Expand Up @@ -177,7 +172,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
else:
return {"status": "error", "message": "No pyproject.toml found in workspace."}

# since we are using worktrees, optimization diffs are generated with respect to the root of the repo, also the args.project_root is set to the root of the repo when creating a worktree
# since we are using worktrees, optimization diffs are generated with respect to the root of the repo.
root = str(git_root_dir())

if getattr(params, "skip_validation", False):
Expand Down
63 changes: 38 additions & 25 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
from codeflash.code_utils.git_utils import check_running_in_git_repo
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
from codeflash.code_utils.git_worktree_utils import (
create_detached_worktree,
create_diff_patch_from_worktree,
Expand Down Expand Up @@ -442,40 +442,53 @@ def worktree_mode(self) -> None:
logger.warning("Failed to create worktree. Skipping optimization.")
return
self.current_worktree = worktree_dir
self.mutate_args_for_worktree_mode(worktree_dir)
self.mirror_paths_for_worktree_mode(worktree_dir)
# make sure the tests dir is created in the worktree, this can happen if the original tests dir is empty
Path(self.args.tests_root).mkdir(parents=True, exist_ok=True)

def mutate_args_for_worktree_mode(self, worktree_dir: Path) -> None:
saved_args = copy.deepcopy(self.args)
saved_test_cfg = copy.deepcopy(self.test_cfg)
self.original_args_and_test_cfg = (saved_args, saved_test_cfg)
def mirror_paths_for_worktree_mode(self, worktree_dir: Path) -> None:
original_args = copy.deepcopy(self.args)
original_test_cfg = copy.deepcopy(self.test_cfg)
self.original_args_and_test_cfg = (original_args, original_test_cfg)

project_root = self.args.project_root
module_root = self.args.module_root
relative_module_root = module_root.relative_to(project_root)
relative_optimized_file = self.args.file.relative_to(project_root) if self.args.file else None
relative_tests_root = self.test_cfg.tests_root.relative_to(project_root)
relative_benchmarks_root = (
self.args.benchmarks_root.relative_to(project_root) if self.args.benchmarks_root else None
)
original_module_root = original_args.module_root
original_git_root = git_root_dir()

self.args.module_root = worktree_dir / relative_module_root
self.args.project_root = worktree_dir
self.args.test_project_root = worktree_dir
self.args.tests_root = worktree_dir / relative_tests_root
if relative_benchmarks_root:
self.args.benchmarks_root = worktree_dir / relative_benchmarks_root
# mutate project_root
relative_project_root = original_args.project_root.relative_to(original_git_root)
# this will be the same as the original project root but in the worktree
new_project_root = worktree_dir / relative_project_root
self.args.project_root = new_project_root
self.test_cfg.project_root_path = new_project_root

self.test_cfg.project_root_path = worktree_dir
self.test_cfg.tests_project_rootdir = worktree_dir
self.test_cfg.tests_root = worktree_dir / relative_tests_root
if relative_benchmarks_root:
self.test_cfg.benchmark_tests_root = worktree_dir / relative_benchmarks_root
# mutate module_root
relative_module_root = original_module_root.relative_to(original_git_root)
self.args.module_root = worktree_dir / relative_module_root

# mute target file
relative_optimized_file = original_args.file.relative_to(original_git_root) if original_args.file else None
if relative_optimized_file is not None:
self.args.file = worktree_dir / relative_optimized_file

# mutate tests root
relative_tests_root = original_test_cfg.tests_root.relative_to(original_git_root)
new_tests_root = worktree_dir / relative_tests_root
self.args.tests_root = new_tests_root
self.test_cfg.tests_root = new_tests_root

# mutate tests project root
relative_tests_project_root = original_args.test_project_root.relative_to(original_git_root)
self.args.test_project_root = worktree_dir / relative_tests_project_root
self.test_cfg.tests_project_rootdir = worktree_dir / relative_tests_project_root

# mutate benchmarks root
relative_benchmarks_root = (
original_args.benchmarks_root.relative_to(original_git_root) if original_args.benchmarks_root else None
)
if relative_benchmarks_root:
self.args.benchmarks_root = worktree_dir / relative_benchmarks_root
self.test_cfg.benchmark_tests_root = worktree_dir / relative_benchmarks_root


def run_with_args(args: Namespace) -> None:
optimizer = None
Expand Down
Loading