diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 176d928c5..9eb5b7c20 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -10,8 +10,9 @@ from pydantic.json import pydantic_encoder from codeflash.cli_cmds.console import console, logger -from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled +from codeflash.code_utils.env_utils import get_codeflash_api_key from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name +from codeflash.lsp.helpers import is_LSP_enabled from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate from codeflash.telemetry.posthog_cf import ph diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index 2aff42a1a..25b155ddf 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -14,8 +14,9 @@ from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number -from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name, git_root_dir +from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name from codeflash.github.PrComment import FileDiffContent, PrComment +from codeflash.lsp.helpers import is_LSP_enabled from codeflash.version import __version__ if TYPE_CHECKING: @@ -101,7 +102,7 @@ def get_user_id() -> Optional[str]: if min_version and version.parse(min_version) > version.parse(__version__): msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`." console.print(f"[bold red]{msg}[/bold red]") - if console.quiet: # lsp + if is_LSP_enabled(): logger.debug(msg) return f"Error: {msg}" sys.exit(1) @@ -203,8 +204,9 @@ def create_staging( generated_original_test_source: str, function_trace_id: str, coverage_message: str, - replay_tests: str = "", - concolic_tests: str = "", + replay_tests: str, + concolic_tests: str, + root_dir: Path, ) -> Response: """Create a staging pull request, targeting the specified branch. (usually 'staging'). @@ -217,12 +219,10 @@ def create_staging( :param coverage_message: Coverage report or summary. :return: The response object from the backend. """ - relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix() + relative_path = explanation.file_path.relative_to(root_dir).as_posix() build_file_changes = { - Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent( - oldContent=original_code[p], newContent=new_code[p] - ) + Path(p).relative_to(root_dir).as_posix(): FileDiffContent(oldContent=original_code[p], newContent=new_code[p]) for p in original_code } diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 34e5ad223..5aa331f36 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -94,6 +94,7 @@ def parse_args() -> Namespace: help="Path to the directory of the project, where all the pytest-benchmark tests are located.", ) parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs") + parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization") args, unknown_args = parser.parse_known_args() sys.argv[:] = [sys.argv[0], *unknown_args] diff --git a/codeflash/cli_cmds/console.py b/codeflash/cli_cmds/console.py index 34d50f268..ca8e29e26 100644 --- a/codeflash/cli_cmds/console.py +++ b/codeflash/cli_cmds/console.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os from contextlib import contextmanager from itertools import cycle from typing import TYPE_CHECKING @@ -28,6 +29,10 @@ DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG console = Console() + +if os.getenv("CODEFLASH_LSP"): + console.quiet = True + logging.basicConfig( level=logging.INFO, handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)], diff --git a/codeflash/code_utils/coverage_utils.py b/codeflash/code_utils/coverage_utils.py index 21aa06ad9..966910630 100644 --- a/codeflash/code_utils/coverage_utils.py +++ b/codeflash/code_utils/coverage_utils.py @@ -39,16 +39,20 @@ def build_fully_qualified_name(function_name: str, code_context: CodeOptimizatio return full_name -def generate_candidates(source_code_path: Path) -> list[str]: +def generate_candidates(source_code_path: Path) -> set[str]: """Generate all the possible candidates for coverage data based on the source code path.""" - candidates = [source_code_path.name] + candidates = set() + candidates.add(source_code_path.name) current_path = source_code_path.parent + last_added = source_code_path.name while current_path != current_path.parent: - candidate_path = str(Path(current_path.name) / candidates[-1]) - candidates.append(candidate_path) + candidate_path = str(Path(current_path.name) / last_added) + candidates.add(candidate_path) + last_added = candidate_path current_path = current_path.parent + candidates.add(str(source_code_path)) return candidates diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index ab0cb16e6..eca59bfa8 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -7,10 +7,11 @@ from pathlib import Path from typing import Any, Optional -from codeflash.cli_cmds.console import console, logger +from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import exit_with_message from codeflash.code_utils.formatter import format_code from codeflash.code_utils.shell_utils import read_api_key_from_shell_config +from codeflash.lsp.helpers import is_LSP_enabled def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa @@ -34,11 +35,12 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = @lru_cache(maxsize=1) def get_codeflash_api_key() -> str: - if console.quiet: # lsp - # prefer shell config over env var in lsp mode - api_key = read_api_key_from_shell_config() - else: - api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config() + # prefer shell config over env var in lsp mode + api_key = ( + read_api_key_from_shell_config() + if is_LSP_enabled() + else os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config() + ) api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/getting-started/codeflash-github-actions#add-your-api-key-to-your-repository-secrets]." # noqa if not api_key: @@ -125,11 +127,6 @@ def is_ci() -> bool: return bool(os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS")) -@lru_cache(maxsize=1) -def is_LSP_enabled() -> bool: - return console.quiet - - def is_pr_draft() -> bool: """Check if the PR is draft. in the github action context.""" event = get_cached_gh_event_data() diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index e3a412734..eff9f4ed4 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -13,6 +13,7 @@ import isort from codeflash.cli_cmds.console import console, logger +from codeflash.lsp.helpers import is_LSP_enabled def generate_unified_diff(original: str, modified: str, from_file: str, to_file: str) -> str: @@ -106,8 +107,7 @@ def format_code( print_status: bool = True, # noqa exit_on_failure: bool = True, # noqa ) -> str: - if console.quiet: - # lsp mode + if is_LSP_enabled(): exit_on_failure = False if isinstance(path, str): diff --git a/codeflash/code_utils/git_utils.py b/codeflash/code_utils/git_utils.py index 00f9f5e28..b4f658097 100644 --- a/codeflash/code_utils/git_utils.py +++ b/codeflash/code_utils/git_utils.py @@ -9,13 +9,14 @@ from functools import cache from io import StringIO from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import git from rich.prompt import Confirm from unidiff import PatchSet from codeflash.cli_cmds.console import logger +from codeflash.code_utils.compat import codeflash_cache_dir from codeflash.code_utils.config_consts import N_CANDIDATES if TYPE_CHECKING: @@ -192,3 +193,80 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None: return None else: return last_commit.author.name + + +worktree_dirs = codeflash_cache_dir / "worktrees" +patches_dir = codeflash_cache_dir / "patches" + + +def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None: + repository = git.Repo(worktree_dir, search_parent_directories=True) + repository.git.commit("-am", commit_message, "--no-verify") + + +def create_detached_worktree(module_root: Path) -> Optional[Path]: + if not check_running_in_git_repo(module_root): + logger.warning("Module is not in a git repository. Skipping worktree creation.") + return None + git_root = git_root_dir() + current_time_str = time.strftime("%Y%m%d-%H%M%S") + worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}" + + repository = git.Repo(git_root, search_parent_directories=True) + + repository.git.worktree("add", "-d", str(worktree_dir)) + + # Get uncommitted diff from the original repo + repository.git.add("-N", ".") # add the index for untracked files to be included in the diff + uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True) + + if not uni_diff_text.strip(): + logger.info("No uncommitted changes to copy to worktree.") + return worktree_dir + + # Write the diff to a temporary file + with tempfile.NamedTemporaryFile(mode="w", suffix=".codeflash.patch", delete=False) as tmp_patch_file: + tmp_patch_file.write(uni_diff_text + "\n") # the new line here is a must otherwise the last hunk won't be valid + tmp_patch_file.flush() + + patch_path = Path(tmp_patch_file.name).resolve() + + # Apply the patch inside the worktree + try: + subprocess.run( + ["git", "apply", "--ignore-space-change", "--ignore-whitespace", patch_path], + cwd=worktree_dir, + check=True, + ) + create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot") + except subprocess.CalledProcessError as e: + logger.error(f"Failed to apply patch to worktree: {e}") + + return worktree_dir + + +def remove_worktree(worktree_dir: Path) -> None: + try: + repository = git.Repo(worktree_dir, search_parent_directories=True) + repository.git.worktree("remove", "--force", worktree_dir) + except Exception: + logger.exception(f"Failed to remove worktree: {worktree_dir}") + + +def create_diff_patch_from_worktree(worktree_dir: Path, files: list[str], fto_name: str) -> Path: + repository = git.Repo(worktree_dir, search_parent_directories=True) + uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True) + + if not uni_diff_text: + logger.warning("No changes found in worktree.") + return None + + if not uni_diff_text.endswith("\n"): + uni_diff_text += "\n" + + # write to patches_dir + patches_dir.mkdir(parents=True, exist_ok=True) + patch_path = patches_dir / f"{worktree_dir.name}.{fto_name}.patch" + with patch_path.open("w", encoding="utf8") as f: + f.write(uni_diff_text) + return patch_path diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 3c2f4f633..db5d5e5d5 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -208,7 +208,7 @@ def get_functions_to_optimize( logger.info("Finding all functions modified in the current git diff ...") console.rule() ph("cli-optimizing-git-diff") - functions = get_functions_within_git_diff() + functions = get_functions_within_git_diff(uncommitted_changes=False) filtered_modified_functions, functions_count = filter_functions( functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions ) @@ -224,8 +224,8 @@ def get_functions_to_optimize( return filtered_modified_functions, functions_count, trace_file_path -def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]: - modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=False) +def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]: # noqa: FBT001 + modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes) modified_functions: dict[str, list[FunctionToOptimize]] = {} for path_str, lines_in_file in modified_lines.items(): path = Path(path_str) diff --git a/codeflash/lsp/__init__.py b/codeflash/lsp/__init__.py index 9d75096ef..e69de29bb 100644 --- a/codeflash/lsp/__init__.py +++ b/codeflash/lsp/__init__.py @@ -1,4 +0,0 @@ -# Silence the console module to prevent stdout pollution -from codeflash.cli_cmds.console import console - -console.quiet = True diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 4ed0b7a62..f83be319f 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -9,9 +9,12 @@ from pygls import uris from codeflash.api.cfapi import get_codeflash_api_key, get_user_id +from codeflash.code_utils.git_utils import create_diff_patch_from_worktree from codeflash.code_utils.shell_utils import save_api_key_to_rc +from codeflash.discovery.functions_to_optimize import filter_functions, get_functions_within_git_diff from codeflash.either import is_successful from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol +from codeflash.result.explanation import Explanation if TYPE_CHECKING: from lsprotocol import types @@ -38,6 +41,23 @@ class ProvideApiKeyParams: server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol) +@server.feature("getOptimizableFunctionsInCurrentDiff") +def get_functions_in_current_git_diff( + server: CodeflashLanguageServer, _params: OptimizableFunctionsParams +) -> dict[str, str | list[str]]: + functions = get_functions_within_git_diff(uncommitted_changes=True) + file_to_funcs_to_optimize, _ = filter_functions( + modified_functions=functions, + tests_root=server.optimizer.test_cfg.tests_root, + ignore_paths=[], + project_root=server.optimizer.args.project_root, + module_root=server.optimizer.args.module_root, + previous_checkpoint_functions={}, + ) + qualified_names: list[str] = [func.qualified_name for funcs in file_to_funcs_to_optimize.values() for func in funcs] + return {"functions": qualified_names, "status": "success"} + + @server.feature("getOptimizableFunctions") def get_optimizable_functions( server: CodeflashLanguageServer, params: OptimizableFunctionsParams @@ -45,44 +65,21 @@ def get_optimizable_functions( file_path = Path(uris.to_fs_path(params.textDocument.uri)) server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info") - # Save original args to restore later - original_file = getattr(server.optimizer.args, "file", None) - original_function = getattr(server.optimizer.args, "function", None) - original_checkpoint = getattr(server.optimizer.args, "previous_checkpoint_functions", None) - - server.show_message_log(f"Original args - file: {original_file}, function: {original_function}", "Info") - - try: - # Set temporary args for this request only - server.optimizer.args.file = file_path - server.optimizer.args.function = None # Always get ALL functions, not just one - server.optimizer.args.previous_checkpoint_functions = False + server.optimizer.args.file = file_path + server.optimizer.args.function = None # Always get ALL functions, not just one + server.optimizer.args.previous_checkpoint_functions = False - server.show_message_log("Calling get_optimizable_functions...", "Info") - optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions() + server.show_message_log(f"Calling get_optimizable_functions for {server.optimizer.args.file}...", "Info") + optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions() - path_to_qualified_names = {} - for path, functions in optimizable_funcs.items(): - path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions] + path_to_qualified_names = {} + for functions in optimizable_funcs.values(): + path_to_qualified_names[file_path] = [func.qualified_name for func in functions] - server.show_message_log( - f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info" - ) - return path_to_qualified_names - finally: - # Restore original args to prevent state corruption - if original_file is not None: - server.optimizer.args.file = original_file - if original_function is not None: - server.optimizer.args.function = original_function - else: - server.optimizer.args.function = None - if original_checkpoint is not None: - server.optimizer.args.previous_checkpoint_functions = original_checkpoint - - server.show_message_log( - f"Restored args - file: {server.optimizer.args.file}, function: {server.optimizer.args.function}", "Info" - ) + server.show_message_log( + f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info" + ) + return path_to_qualified_names @server.feature("initializeFunctionOptimization") @@ -91,10 +88,15 @@ def initialize_function_optimization( ) -> dict[str, str]: file_path = Path(uris.to_fs_path(params.textDocument.uri)) server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info") + if server.optimizer is None: + _initialize_optimizer_if_valid(server) + server.optimizer.worktree_mode() + original_args, _ = server.optimizer.original_args_and_test_cfg - # IMPORTANT: Store the specific function for optimization, but don't corrupt global state server.optimizer.args.function = params.functionName - server.optimizer.args.file = file_path + original_relative_file_path = file_path.relative_to(original_args.project_root) + server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path + server.optimizer.args.previous_checkpoint_functions = False server.show_message_log( f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info" @@ -103,7 +105,12 @@ def initialize_function_optimization( optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions() if not optimizable_funcs: server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning") - return {"functionName": params.functionName, "status": "not found", "args": None} + return { + "functionName": params.functionName, + "status": "error", + "message": "function is no found or not optimizable", + "args": None, + } fto = optimizable_funcs.popitem()[1][0] server.optimizer.current_function_being_optimized = fto @@ -237,112 +244,133 @@ def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimization def perform_function_optimization( # noqa: PLR0911 server: CodeflashLanguageServer, params: FunctionOptimizationParams ) -> dict[str, str]: - server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info") - current_function = server.optimizer.current_function_being_optimized + try: + server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info") + current_function = server.optimizer.current_function_being_optimized + + if not current_function: + server.show_message_log(f"No current function being optimized for {params.functionName}", "Error") + return { + "functionName": params.functionName, + "status": "error", + "message": "No function currently being optimized", + } + + module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) + + validated_original_code, original_module_ast = module_prep_result + + function_optimizer = server.optimizer.create_function_optimizer( + current_function, + function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, + original_module_ast=original_module_ast, + original_module_path=current_function.file_path, + function_to_tests=server.optimizer.discovered_tests or {}, + ) - if not current_function: - server.show_message_log(f"No current function being optimized for {params.functionName}", "Error") - return { - "functionName": params.functionName, - "status": "error", - "message": "No function currently being optimized", - } + server.optimizer.current_function_optimizer = function_optimizer + if not function_optimizer: + return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} - module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) + initialization_result = function_optimizer.can_be_optimized() + if not is_successful(initialization_result): + return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} - validated_original_code, original_module_ast = module_prep_result + should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() - function_optimizer = server.optimizer.create_function_optimizer( - current_function, - function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, - original_module_ast=original_module_ast, - original_module_path=current_function.file_path, - function_to_tests=server.optimizer.discovered_tests or {}, - ) - - server.optimizer.current_function_optimizer = function_optimizer - if not function_optimizer: - return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} + test_setup_result = function_optimizer.generate_and_instrument_tests( + code_context, should_run_experiment=should_run_experiment + ) + if not is_successful(test_setup_result): + return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()} + ( + generated_tests, + function_to_concolic_tests, + concolic_test_str, + optimizations_set, + generated_test_paths, + generated_perf_test_paths, + instrumented_unittests_created_for_function, + original_conftest_content, + ) = test_setup_result.unwrap() + + baseline_setup_result = function_optimizer.setup_and_establish_baseline( + code_context=code_context, + original_helper_code=original_helper_code, + function_to_concolic_tests=function_to_concolic_tests, + generated_test_paths=generated_test_paths, + generated_perf_test_paths=generated_perf_test_paths, + instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, + original_conftest_content=original_conftest_content, + ) - initialization_result = function_optimizer.can_be_optimized() - if not is_successful(initialization_result): - return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} + if not is_successful(baseline_setup_result): + return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()} + + ( + function_to_optimize_qualified_name, + function_to_all_tests, + original_code_baseline, + test_functions_to_remove, + file_path_to_helper_classes, + ) = baseline_setup_result.unwrap() + + best_optimization = function_optimizer.find_and_process_best_optimization( + optimizations_set=optimizations_set, + code_context=code_context, + original_code_baseline=original_code_baseline, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + function_to_optimize_qualified_name=function_to_optimize_qualified_name, + function_to_all_tests=function_to_all_tests, + generated_tests=generated_tests, + test_functions_to_remove=test_functions_to_remove, + concolic_test_str=concolic_test_str, + ) - should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() + if not best_optimization: + server.show_message_log( + f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning" + ) + return { + "functionName": params.functionName, + "status": "error", + "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", + } + + # generate a patch for the optimization + relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings] + patch_file = create_diff_patch_from_worktree( + server.optimizer.current_worktree, + relative_file_paths, + server.optimizer.current_function_optimizer.function_to_optimize.qualified_name, + ) - test_setup_result = function_optimizer.generate_and_instrument_tests( - code_context, should_run_experiment=should_run_experiment - ) - if not is_successful(test_setup_result): - return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()} - ( - generated_tests, - function_to_concolic_tests, - concolic_test_str, - optimizations_set, - generated_test_paths, - generated_perf_test_paths, - instrumented_unittests_created_for_function, - original_conftest_content, - ) = test_setup_result.unwrap() - - baseline_setup_result = function_optimizer.setup_and_establish_baseline( - code_context=code_context, - original_helper_code=original_helper_code, - function_to_concolic_tests=function_to_concolic_tests, - generated_test_paths=generated_test_paths, - generated_perf_test_paths=generated_perf_test_paths, - instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, - original_conftest_content=original_conftest_content, - ) + optimized_source = best_optimization.candidate.source_code.markdown + speedup = original_code_baseline.runtime / best_optimization.runtime - if not is_successful(baseline_setup_result): - return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()} - - ( - function_to_optimize_qualified_name, - function_to_all_tests, - original_code_baseline, - test_functions_to_remove, - file_path_to_helper_classes, - ) = baseline_setup_result.unwrap() - - best_optimization = function_optimizer.find_and_process_best_optimization( - optimizations_set=optimizations_set, - code_context=code_context, - original_code_baseline=original_code_baseline, - original_helper_code=original_helper_code, - file_path_to_helper_classes=file_path_to_helper_classes, - function_to_optimize_qualified_name=function_to_optimize_qualified_name, - function_to_all_tests=function_to_all_tests, - generated_tests=generated_tests, - test_functions_to_remove=test_functions_to_remove, - concolic_test_str=concolic_test_str, - ) + server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") - if not best_optimization: - server.show_message_log( - f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning" - ) + explanation = best_optimization.candidate.explanation + explanation_str = explanation.explanation_message() if isinstance(explanation, Explanation) else explanation return { "functionName": params.functionName, - "status": "error", - "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", + "status": "success", + "message": "Optimization completed successfully", + "extra": f"Speedup: {speedup:.2f}x faster", + "optimization": optimized_source, + "patch_file": str(patch_file), + "explanation": explanation_str, } + finally: + cleanup_the_optimizer(server) - optimized_source = best_optimization.candidate.source_code.markdown - speedup = original_code_baseline.runtime / best_optimization.runtime - - server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") - # CRITICAL: Clear the function filter after optimization to prevent state corruption +def cleanup_the_optimizer(server: CodeflashLanguageServer) -> None: + server.optimizer.cleanup_temporary_paths() + # restore args and test cfg + if server.optimizer.original_args_and_test_cfg: + server.optimizer.args, server.optimizer.test_cfg = server.optimizer.original_args_and_test_cfg server.optimizer.args.function = None - server.show_message_log("Cleared function filter to prevent state corruption", "Info") - - return { - "functionName": params.functionName, - "status": "success", - "message": "Optimization completed successfully", - "extra": f"Speedup: {speedup:.2f}x faster", - "optimization": optimized_source, - } + server.optimizer.current_worktree = None + server.optimizer.current_function_optimizer = None diff --git a/codeflash/lsp/helpers.py b/codeflash/lsp/helpers.py new file mode 100644 index 000000000..dc8f8c5d6 --- /dev/null +++ b/codeflash/lsp/helpers.py @@ -0,0 +1,7 @@ +import os +from functools import lru_cache + + +@lru_cache(maxsize=1) +def is_LSP_enabled() -> bool: + return os.getenv("CODEFLASH_LSP", default="false").lower() == "true" diff --git a/codeflash/lsp/server.py b/codeflash/lsp/server.py index 030574a8d..8114c94dc 100644 --- a/codeflash/lsp/server.py +++ b/codeflash/lsp/server.py @@ -11,6 +11,8 @@ if TYPE_CHECKING: from lsprotocol.types import InitializeParams, InitializeResult + from codeflash.optimization.optimizer import Optimizer + class CodeflashLanguageServerProtocol(LanguageServerProtocol): _server: CodeflashLanguageServer @@ -26,7 +28,6 @@ def lsp_initialize(self, params: InitializeParams) -> InitializeResult: pyproject_toml_path = self._find_pyproject_toml(workspace_path) if pyproject_toml_path: server.prepare_optimizer_arguments(pyproject_toml_path) - server.show_message(f"Found pyproject.toml at: {pyproject_toml_path}") else: server.show_message("No pyproject.toml found in workspace.") else: @@ -44,7 +45,7 @@ def _find_pyproject_toml(self, workspace_path: str) -> Path | None: class CodeflashLanguageServer(LanguageServer): def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 super().__init__(*args, **kwargs) - self.optimizer = None + self.optimizer: Optimizer | None = None self.args = None def prepare_optimizer_arguments(self, config_file: Path) -> None: @@ -53,6 +54,7 @@ def prepare_optimizer_arguments(self, config_file: Path) -> None: args = parse_args() args.config_file = config_file args.no_pr = True # LSP server should not create PRs + args.worktree = True args = process_pyproject_config(args) self.args = args # avoid initializing the optimizer during initialization, because it can cause an error if the api key is invalid diff --git a/codeflash/lsp/server_entry.py b/codeflash/lsp/server_entry.py index dfb8dd5a3..5ea30cde8 100644 --- a/codeflash/lsp/server_entry.py +++ b/codeflash/lsp/server_entry.py @@ -21,7 +21,8 @@ def setup_logging() -> logging.Logger: # Set up stderr handler for VS Code output channel with [LSP-Server] prefix handler = logging.StreamHandler(sys.stderr) - handler.setFormatter(logging.Formatter("[LSP-Server] %(asctime)s [%(levelname)s]: %(message)s")) + # adding the :::: here for the client to easily extract the message from the log + handler.setFormatter(logging.Formatter("[LSP-Server] %(asctime)s [%(levelname)s]::::%(message)s")) # Configure root logger root_logger.addHandler(handler) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 1636d6889..1d47fb7d4 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -343,7 +343,7 @@ class TestsInFile: test_type: TestType -@dataclass(frozen=True) +@dataclass class OptimizedCandidate: source_code: CodeStringsMarkdown explanation: str diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 0253c9ac3..313f15d06 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1232,7 +1232,11 @@ def process_review( file_path=explanation.file_path, benchmark_details=explanation.benchmark_details, ) + + best_optimization.candidate.explanation = new_explanation + console.print(Panel(new_explanation_raw_str, title="Best Candidate Explanation", border_style="blue")) + data = { "original_code": original_code_combined, "new_code": new_code_combined, @@ -1245,6 +1249,7 @@ def process_review( "coverage_message": coverage_message, "replay_tests": replay_tests, "concolic_tests": concolic_tests, + "root_dir": self.project_root, } raise_pr = not self.args.no_pr @@ -1260,6 +1265,10 @@ def process_review( trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None ) + # If worktree mode, do not revert code and helpers,, otherwise we would have an empty diff when writing the patch in the lsp + if self.args.worktree: + return + if raise_pr and ( self.args.all or env_utils.get_pr_number() diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 83073d6d9..ce857f201 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +import copy import os import tempfile import time @@ -14,6 +15,13 @@ from codeflash.code_utils import env_utils from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft +from codeflash.code_utils.git_utils import ( + check_running_in_git_repo, + create_detached_worktree, + create_diff_patch_from_worktree, + create_worktree_snapshot_commit, + remove_worktree, +) from codeflash.either import is_successful from codeflash.models.models import ValidCode from codeflash.telemetry.posthog_cf import ph @@ -48,6 +56,9 @@ def __init__(self, args: Namespace) -> None: self.functions_checkpoint: CodeflashRunCheckpoint | None = None self.current_function_being_optimized: FunctionToOptimize | None = None # current only for the LSP self.current_function_optimizer: FunctionOptimizer | None = None + self.current_worktree: Path | None = None + self.original_args_and_test_cfg: tuple[Namespace, TestConfig] | None = None + self.patch_files: list[Path] = [] def run_benchmarks( self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int @@ -252,6 +263,10 @@ def run(self) -> None: if self.args.no_draft and is_pr_draft(): logger.warning("PR is in draft mode, skipping optimization") return + + if self.args.worktree: + self.worktree_mode() + cleanup_paths(Optimizer.find_leftover_instrumented_test_files(self.test_cfg.tests_root)) function_optimizer = None @@ -260,7 +275,6 @@ def run(self) -> None: file_to_funcs_to_optimize, num_optimizable_functions ) optimizations_found: int = 0 - function_iterator_count: int = 0 if self.args.test_framework == "pytest": self.test_cfg.concolic_test_root_dir = Path( tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_") @@ -296,8 +310,8 @@ def run(self) -> None: except Exception as e: logger.debug(f"Could not rank functions in {original_module_path}: {e}") - for function_to_optimize in functions_to_optimize: - function_iterator_count += 1 + for i, function_to_optimize in enumerate(functions_to_optimize): + function_iterator_count = i + 1 logger.info( f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: " f"{function_to_optimize.qualified_name}" @@ -327,6 +341,23 @@ def run(self) -> None: ) if is_successful(best_optimization): optimizations_found += 1 + # create a diff patch for successful optimization + if self.current_worktree: + read_writable_code = best_optimization.unwrap().code_context.read_writable_code + relative_file_paths = [ + code_string.file_path for code_string in read_writable_code.code_strings + ] + patch_path = create_diff_patch_from_worktree( + self.current_worktree, + relative_file_paths, + self.current_function_optimizer.function_to_optimize.qualified_name, + ) + self.patch_files.append(patch_path) + if i < len(functions_to_optimize) - 1: + create_worktree_snapshot_commit( + self.current_worktree, + f"Optimizing {functions_to_optimize[i + 1].qualified_name}", + ) else: logger.warning(best_optimization.failure()) console.rule() @@ -337,6 +368,10 @@ def run(self) -> None: function_optimizer.cleanup_generated_files() ph("cli-optimize-run-finished", {"optimizations_found": optimizations_found}) + if len(self.patch_files) > 0: + logger.info( + f"Created {len(self.patch_files)} patch(es) ({[str(patch_path) for patch_path in self.patch_files]})" + ) if self.functions_checkpoint: self.functions_checkpoint.cleanup() if hasattr(self.args, "command") and self.args.command == "optimize": @@ -382,14 +417,60 @@ def cleanup_replay_tests(self) -> None: cleanup_paths([self.replay_tests_dir]) def cleanup_temporary_paths(self) -> None: - if self.current_function_optimizer: - self.current_function_optimizer.cleanup_generated_files() - if hasattr(get_run_tmp_file, "tmpdir"): get_run_tmp_file.tmpdir.cleanup() del get_run_tmp_file.tmpdir + + if self.current_worktree: + remove_worktree(self.current_worktree) + return + + if self.current_function_optimizer: + self.current_function_optimizer.cleanup_generated_files() cleanup_paths([self.test_cfg.concolic_test_root_dir, self.replay_tests_dir]) + def worktree_mode(self) -> None: + if self.current_worktree: + return + + if check_running_in_git_repo(self.args.module_root): + worktree_dir = create_detached_worktree(self.args.module_root) + if worktree_dir is None: + logger.warning("Failed to create worktree. Skipping optimization.") + return + self.current_worktree = worktree_dir + self.mutate_args_for_worktree_mode(worktree_dir) + + def mutate_args_for_worktree_mode(self, worktree_dir: Path) -> None: + saved_args = copy.deepcopy(self.args) + saved_test_cfg = copy.deepcopy(self.test_cfg) + self.original_args_and_test_cfg = (saved_args, saved_test_cfg) + + project_root = self.args.project_root + module_root = self.args.module_root + relative_module_root = module_root.relative_to(project_root) + relative_optimized_file = self.args.file.relative_to(project_root) if self.args.file else None + relative_tests_root = self.test_cfg.tests_root.relative_to(project_root) + relative_benchmarks_root = ( + self.args.benchmarks_root.relative_to(project_root) if self.args.benchmarks_root else None + ) + + self.args.module_root = worktree_dir / relative_module_root + self.args.project_root = worktree_dir + self.args.test_project_root = worktree_dir + self.args.tests_root = worktree_dir / relative_tests_root + if relative_benchmarks_root: + self.args.benchmarks_root = worktree_dir / relative_benchmarks_root + + self.test_cfg.project_root_path = worktree_dir + self.test_cfg.tests_project_rootdir = worktree_dir + self.test_cfg.tests_root = worktree_dir / relative_tests_root + if relative_benchmarks_root: + self.test_cfg.benchmark_tests_root = worktree_dir / relative_benchmarks_root + + if relative_optimized_file is not None: + self.args.file = worktree_dir / relative_optimized_file + def run_with_args(args: Namespace) -> None: optimizer = None diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index 51f6b48af..cf842125e 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -10,12 +10,7 @@ from codeflash.cli_cmds.console import console, logger from codeflash.code_utils import env_utils from codeflash.code_utils.code_replacer import is_zero_diff -from codeflash.code_utils.git_utils import ( - check_and_push_branch, - get_current_branch, - get_repo_owner_and_name, - git_root_dir, -) +from codeflash.code_utils.git_utils import check_and_push_branch, get_current_branch, get_repo_owner_and_name from codeflash.code_utils.github_utils import github_pr_url from codeflash.code_utils.tabulate import tabulate from codeflash.code_utils.time_utils import format_perf, format_time @@ -188,6 +183,7 @@ def check_create_pr( coverage_message: str, replay_tests: str, concolic_tests: str, + root_dir: Path, git_remote: Optional[str] = None, ) -> None: pr_number: Optional[int] = env_utils.get_pr_number() @@ -196,9 +192,9 @@ def check_create_pr( if pr_number is not None: logger.info(f"Suggesting changes to PR #{pr_number} ...") owner, repo = get_repo_owner_and_name(git_repo) - relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix() + relative_path = explanation.file_path.relative_to(root_dir).as_posix() build_file_changes = { - Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent( + Path(p).relative_to(root_dir).as_posix(): FileDiffContent( oldContent=original_code[p], newContent=new_code[p] ) for p in original_code @@ -247,10 +243,10 @@ def check_create_pr( if not check_and_push_branch(git_repo, git_remote, wait_for_push=True): logger.warning("⏭️ Branch is not pushed, skipping PR creation...") return - relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix() + relative_path = explanation.file_path.relative_to(root_dir).as_posix() base_branch = get_current_branch() build_file_changes = { - Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent( + Path(p).relative_to(root_dir).as_posix(): FileDiffContent( oldContent=original_code[p], newContent=new_code[p] ) for p in original_code diff --git a/codeflash/verification/concolic_testing.py b/codeflash/verification/concolic_testing.py index 3fe72a764..8f30a1562 100644 --- a/codeflash/verification/concolic_testing.py +++ b/codeflash/verification/concolic_testing.py @@ -10,9 +10,9 @@ from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE from codeflash.code_utils.concolic_utils import clean_concolic_tests -from codeflash.code_utils.env_utils import is_LSP_enabled from codeflash.code_utils.static_analysis import has_typed_parameters from codeflash.discovery.discover_unit_tests import discover_unit_tests +from codeflash.lsp.helpers import is_LSP_enabled from codeflash.telemetry.posthog_cf import ph from codeflash.verification.verification_utils import TestConfig diff --git a/tests/test_code_utils.py b/tests/test_code_utils.py index a10f50a56..0ab78d2ef 100644 --- a/tests/test_code_utils.py +++ b/tests/test_code_utils.py @@ -357,7 +357,7 @@ def test_cleanup_paths(multiple_existing_and_non_existing_files: list[Path]) -> def test_generate_candidates() -> None: source_code_path = Path("/Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py") - expected_candidates = [ + expected_candidates = { "coverage_utils.py", "code_utils/coverage_utils.py", "codeflash/code_utils/coverage_utils.py", @@ -367,7 +367,8 @@ def test_generate_candidates() -> None: "Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py", "krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py", "Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py", - ] + "/Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py" + } assert generate_candidates(source_code_path) == expected_candidates