diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9a9af7cd8..7756e481c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,4 +5,4 @@ repos: # Run the linter. - id: ruff-check # Run the formatter. - - id: ruff-format \ No newline at end of file + - id: ruff-format diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 4ed0b7a62..9611d4e11 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -9,6 +9,7 @@ from pygls import uris from codeflash.api.cfapi import get_codeflash_api_key, get_user_id +from codeflash.code_utils.git_utils import create_git_worktrees, create_worktree_root_dir, remove_git_worktrees from codeflash.code_utils.shell_utils import save_api_key_to_rc from codeflash.either import is_successful from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol @@ -238,111 +239,281 @@ def perform_function_optimization( # noqa: PLR0911 server: CodeflashLanguageServer, params: FunctionOptimizationParams ) -> dict[str, str]: server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info") - current_function = server.optimizer.current_function_being_optimized + + try: + current_function = server.optimizer.current_function_being_optimized + + if not current_function: + server.show_message_log(f"No current function being optimized for {params.functionName}", "Error") + return { + "functionName": params.functionName, + "status": "error", + "message": "No function currently being optimized", + } + + server.show_message_log(f"Preparing module for optimization: {current_function.file_path}", "Info") + module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) + + validated_original_code, original_module_ast = module_prep_result + + server.show_message_log(f"Creating function optimizer for: {params.functionName}", "Info") + function_optimizer = server.optimizer.create_function_optimizer( + current_function, + function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, + original_module_ast=original_module_ast, + original_module_path=current_function.file_path, + function_to_tests=server.optimizer.discovered_tests or {}, + ) + + server.optimizer.current_function_optimizer = function_optimizer + if not function_optimizer: + return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} + + server.show_message_log(f"Checking if {params.functionName} can be optimized", "Info") + initialization_result = function_optimizer.can_be_optimized() + if not is_successful(initialization_result): + return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} + + should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() + + server.show_message_log(f"Generating and instrumenting tests for: {params.functionName}", "Info") + test_setup_result = function_optimizer.generate_and_instrument_tests( + code_context, should_run_experiment=should_run_experiment + ) + if not is_successful(test_setup_result): + return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()} + ( + generated_tests, + function_to_concolic_tests, + concolic_test_str, + optimizations_set, + generated_test_paths, + generated_perf_test_paths, + instrumented_unittests_created_for_function, + original_conftest_content, + ) = test_setup_result.unwrap() + + server.show_message_log(f"Setting up baseline for: {params.functionName}", "Info") + baseline_setup_result = function_optimizer.setup_and_establish_baseline( + code_context=code_context, + original_helper_code=original_helper_code, + function_to_concolic_tests=function_to_concolic_tests, + generated_test_paths=generated_test_paths, + generated_perf_test_paths=generated_perf_test_paths, + instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, + original_conftest_content=original_conftest_content, + ) + + if not is_successful(baseline_setup_result): + return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()} + + ( + function_to_optimize_qualified_name, + function_to_all_tests, + original_code_baseline, + test_functions_to_remove, + file_path_to_helper_classes, + ) = baseline_setup_result.unwrap() + + # Create worktrees for optimization candidates if in git repo + server.show_message_log(f"Creating worktrees for optimization candidates of: {params.functionName}", "Info") + git_root, worktree_root_dir = None, None + worktree_root, worktrees = None, [] + + try: + from pathlib import Path + module_root = Path(current_function.file_path).parent + git_root, worktree_root_dir = create_worktree_root_dir(module_root) + + if git_root and worktree_root_dir: + worktree_root, worktrees = create_git_worktrees(git_root, worktree_root_dir, module_root) + server.show_message_log(f"Created {len(worktrees)} worktrees for {params.functionName}", "Info") + else: + server.show_message_log(f"Not in git repo, skipping worktree creation for {params.functionName}", "Info") + except Exception as e: + server.show_message_log(f"Failed to create worktrees for {params.functionName}: {e}", "Warning") + + server.show_message_log(f"Finding best optimization for: {params.functionName}", "Info") + best_optimization = function_optimizer.find_and_process_best_optimization( + optimizations_set=optimizations_set, + code_context=code_context, + original_code_baseline=original_code_baseline, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + function_to_optimize_qualified_name=function_to_optimize_qualified_name, + function_to_all_tests=function_to_all_tests, + generated_tests=generated_tests, + test_functions_to_remove=test_functions_to_remove, + concolic_test_str=concolic_test_str, + ) + + # Clean up worktrees after optimization + try: + if worktrees and worktree_root: + server.show_message_log(f"Cleaning up worktrees for {params.functionName}", "Info") + remove_git_worktrees(worktree_root, worktrees) + except Exception as e: + server.show_message_log(f"Failed to cleanup worktrees for {params.functionName}: {e}", "Warning") + + if not best_optimization: + server.show_message_log( + f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning" + ) + return { + "functionName": params.functionName, + "status": "error", + "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", + } + + optimized_source = best_optimization.candidate.source_code.markdown + speedup = original_code_baseline.runtime / best_optimization.runtime + + server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") + + # CRITICAL: Clear the function filter after optimization to prevent state corruption + server.optimizer.args.function = None + server.show_message_log("Cleared function filter to prevent state corruption", "Info") - if not current_function: - server.show_message_log(f"No current function being optimized for {params.functionName}", "Error") + return { + "functionName": params.functionName, + "status": "success", + "message": "Optimization completed successfully", + "extra": f"Speedup: {speedup:.2f}x faster", + "optimization": optimized_source, + } + + except Exception as e: + server.show_message_log(f"Error during optimization of {params.functionName}: {str(e)}", "Error") + + # Clean up worktrees in case of error + try: + if 'worktrees' in locals() and 'worktree_root' in locals() and worktrees and worktree_root: + server.show_message_log(f"Cleaning up worktrees after error for {params.functionName}", "Info") + remove_git_worktrees(worktree_root, worktrees) + except Exception as cleanup_e: + server.show_message_log(f"Failed to cleanup worktrees after error: {cleanup_e}", "Warning") + + # Still clear the function filter to prevent state corruption + try: + server.optimizer.args.function = None + except: + pass + return { "functionName": params.functionName, "status": "error", - "message": "No function currently being optimized", + "message": f"Optimization failed: {str(e)}", } - module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) - - validated_original_code, original_module_ast = module_prep_result - - function_optimizer = server.optimizer.create_function_optimizer( - current_function, - function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, - original_module_ast=original_module_ast, - original_module_path=current_function.file_path, - function_to_tests=server.optimizer.discovered_tests or {}, - ) - - server.optimizer.current_function_optimizer = function_optimizer - if not function_optimizer: - return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} - initialization_result = function_optimizer.can_be_optimized() - if not is_successful(initialization_result): - return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} +@dataclass +class WorktreeParams: + functionName: str # noqa: N815 + candidateId: str # noqa: N815 + gitRoot: str # noqa: N815 - should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() - test_setup_result = function_optimizer.generate_and_instrument_tests( - code_context, should_run_experiment=should_run_experiment - ) - if not is_successful(test_setup_result): - return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()} - ( - generated_tests, - function_to_concolic_tests, - concolic_test_str, - optimizations_set, - generated_test_paths, - generated_perf_test_paths, - instrumented_unittests_created_for_function, - original_conftest_content, - ) = test_setup_result.unwrap() - - baseline_setup_result = function_optimizer.setup_and_establish_baseline( - code_context=code_context, - original_helper_code=original_helper_code, - function_to_concolic_tests=function_to_concolic_tests, - generated_test_paths=generated_test_paths, - generated_perf_test_paths=generated_perf_test_paths, - instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, - original_conftest_content=original_conftest_content, +@server.feature("codeflash/createWorktree") +def create_worktree(server: CodeflashLanguageServer, params: WorktreeParams) -> dict[str, str]: + """Create git worktrees for optimization suggestions using CLI's existing infrastructure.""" + server.show_message_log( + f"Creating worktree for function: {params.functionName}, candidate: {params.candidateId}", "Info" ) - if not is_successful(baseline_setup_result): - return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()} - - ( - function_to_optimize_qualified_name, - function_to_all_tests, - original_code_baseline, - test_functions_to_remove, - file_path_to_helper_classes, - ) = baseline_setup_result.unwrap() - - best_optimization = function_optimizer.find_and_process_best_optimization( - optimizations_set=optimizations_set, - code_context=code_context, - original_code_baseline=original_code_baseline, - original_helper_code=original_helper_code, - file_path_to_helper_classes=file_path_to_helper_classes, - function_to_optimize_qualified_name=function_to_optimize_qualified_name, - function_to_all_tests=function_to_all_tests, - generated_tests=generated_tests, - test_functions_to_remove=test_functions_to_remove, - concolic_test_str=concolic_test_str, - ) + try: + module_root = Path(params.gitRoot) + + # Create worktree root directory + git_root, worktree_root_dir = create_worktree_root_dir(module_root) + + if not git_root or not worktree_root_dir: + server.show_message_log("Not in a git repository, worktree creation skipped", "Warning") + return { + "functionName": params.functionName, + "candidateId": params.candidateId, + "status": "error", + "message": "Not in a git repository", + } + + # Create git worktrees (creates N_CANDIDATES + 1 worktrees) + worktree_root, worktrees = create_git_worktrees(git_root, worktree_root_dir, module_root) + + if not worktrees: + server.show_message_log("Failed to create git worktrees", "Error") + return { + "functionName": params.functionName, + "candidateId": params.candidateId, + "status": "error", + "message": "Failed to create git worktrees", + } + + # Store worktree info for later cleanup (use public attribute instead of private) + if not hasattr(server, "worktree_registry"): + server.worktree_registry = {} + + server.worktree_registry[params.candidateId] = { + "worktree_root": worktree_root, + "worktrees": worktrees, + "function_name": params.functionName, + } + + # For now, return the first worktree (original) - in a full implementation, + # you'd assign specific worktrees to specific optimization candidates + primary_worktree_path = str(worktrees[0]) if worktrees else str(worktree_root) - if not best_optimization: server.show_message_log( - f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning" + f"Successfully created worktrees for {params.functionName}, primary at: {primary_worktree_path}", "Info" ) + + return { + "functionName": params.functionName, + "candidateId": params.candidateId, + "status": "success", + "worktreePath": primary_worktree_path, + "message": f"Created {len(worktrees)} worktrees", + } + + except Exception as e: + server.show_message_log(f"Error creating worktree: {e!s}", "Error") return { "functionName": params.functionName, + "candidateId": params.candidateId, "status": "error", - "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", + "message": f"Error creating worktree: {e!s}", } - optimized_source = best_optimization.candidate.source_code.markdown - speedup = original_code_baseline.runtime / best_optimization.runtime - server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") +@server.feature("codeflash/removeWorktree") +def remove_worktree(server: CodeflashLanguageServer, params: WorktreeParams) -> dict[str, str]: + """Remove git worktrees for a specific optimization candidate.""" + server.show_message_log(f"Removing worktree for candidate: {params.candidateId}", "Info") - # CRITICAL: Clear the function filter after optimization to prevent state corruption - server.optimizer.args.function = None - server.show_message_log("Cleared function filter to prevent state corruption", "Info") + if not hasattr(server, "worktree_registry") or params.candidateId not in server.worktree_registry: + server.show_message_log(f"No worktree found for candidate: {params.candidateId}", "Warning") + return {"candidateId": params.candidateId, "status": "warning", "message": "No worktree found for candidate"} - return { - "functionName": params.functionName, - "status": "success", - "message": "Optimization completed successfully", - "extra": f"Speedup: {speedup:.2f}x faster", - "optimization": optimized_source, - } + try: + worktree_info = server.worktree_registry[params.candidateId] + worktree_root = worktree_info["worktree_root"] + worktrees = worktree_info["worktrees"] + function_name = worktree_info["function_name"] + + # Use CLI's existing cleanup function + remove_git_worktrees(worktree_root, worktrees) + + # Remove from registry + del server.worktree_registry[params.candidateId] + + server.show_message_log( + f"Successfully removed worktrees for {function_name} (candidate: {params.candidateId})", "Info" + ) + + except Exception as e: + server.show_message_log(f"Error removing worktree: {e!s}", "Error") + return {"candidateId": params.candidateId, "status": "error", "message": f"Error removing worktree: {e!s}"} + else: + return { + "candidateId": params.candidateId, + "status": "success", + "message": f"Successfully removed worktrees for {function_name}", + }