diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 57a65723a..b249025d4 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -1,6 +1,6 @@ from __future__ import annotations -import contextlib +import asyncio import os from dataclasses import dataclass from pathlib import Path @@ -11,9 +11,7 @@ from codeflash.api.cfapi import get_codeflash_api_key, get_user_id from codeflash.cli_cmds.cli import process_pyproject_config -from codeflash.cli_cmds.console import code_print from codeflash.code_utils.git_utils import git_root_dir -from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree from codeflash.code_utils.shell_utils import save_api_key_to_rc from codeflash.discovery.functions_to_optimize import ( filter_functions, @@ -21,6 +19,7 @@ get_functions_within_git_diff, ) from codeflash.either import is_successful +from codeflash.lsp.features.perform_optimization import sync_perform_optimization from codeflash.lsp.server import CodeflashLanguageServer if TYPE_CHECKING: @@ -71,7 +70,6 @@ class OptimizableFunctionsInCommitParams: commit_hash: str -# server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol) server = CodeflashLanguageServer("codeflash-language-server", "v1.0") @@ -339,115 +337,15 @@ def initialize_function_optimization( @server.feature("performFunctionOptimization") -@server.thread() -def perform_function_optimization( +async def perform_function_optimization( server: CodeflashLanguageServer, params: FunctionOptimizationParams ) -> dict[str, str]: + loop = asyncio.get_running_loop() try: - server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info") - should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result - function_optimizer = server.optimizer.current_function_optimizer - current_function = function_optimizer.function_to_optimize - - code_print( - code_context.read_writable_code.flat, - file_name=current_function.file_path, - function_name=current_function.function_name, - ) - - optimizable_funcs = {current_function.file_path: [current_function]} - - devnull_writer = open(os.devnull, "w") # noqa - with contextlib.redirect_stdout(devnull_writer): - function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs) - function_optimizer.function_to_tests = function_to_tests - - test_setup_result = function_optimizer.generate_and_instrument_tests( - code_context, should_run_experiment=should_run_experiment - ) - if not is_successful(test_setup_result): - return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()} - ( - generated_tests, - function_to_concolic_tests, - concolic_test_str, - optimizations_set, - generated_test_paths, - generated_perf_test_paths, - instrumented_unittests_created_for_function, - original_conftest_content, - ) = test_setup_result.unwrap() - - baseline_setup_result = function_optimizer.setup_and_establish_baseline( - code_context=code_context, - original_helper_code=original_helper_code, - function_to_concolic_tests=function_to_concolic_tests, - generated_test_paths=generated_test_paths, - generated_perf_test_paths=generated_perf_test_paths, - instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, - original_conftest_content=original_conftest_content, - ) - - if not is_successful(baseline_setup_result): - return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()} - - ( - function_to_optimize_qualified_name, - function_to_all_tests, - original_code_baseline, - test_functions_to_remove, - file_path_to_helper_classes, - ) = baseline_setup_result.unwrap() - - best_optimization = function_optimizer.find_and_process_best_optimization( - optimizations_set=optimizations_set, - code_context=code_context, - original_code_baseline=original_code_baseline, - original_helper_code=original_helper_code, - file_path_to_helper_classes=file_path_to_helper_classes, - function_to_optimize_qualified_name=function_to_optimize_qualified_name, - function_to_all_tests=function_to_all_tests, - generated_tests=generated_tests, - test_functions_to_remove=test_functions_to_remove, - concolic_test_str=concolic_test_str, - ) - - if not best_optimization: - server.show_message_log( - f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning" - ) - return { - "functionName": params.functionName, - "status": "error", - "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", - } - - # generate a patch for the optimization - relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings] - - speedup = original_code_baseline.runtime / best_optimization.runtime - - patch_path = create_diff_patch_from_worktree( - server.optimizer.current_worktree, relative_file_paths, function_to_optimize_qualified_name - ) - - if not patch_path: - return { - "functionName": params.functionName, - "status": "error", - "message": "Failed to create a patch for optimization", - } - - server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") - - return { - "functionName": params.functionName, - "status": "success", - "message": "Optimization completed successfully", - "extra": f"Speedup: {speedup:.2f}x faster", - "patch_file": str(patch_path), - "task_id": params.task_id, - "explanation": best_optimization.explanation_v2, - } + result = await loop.run_in_executor(None, sync_perform_optimization, server, params) + except asyncio.CancelledError: + return {"status": "canceled", "message": "Task was canceled"} + else: + return result finally: server.cleanup_the_optimizer() diff --git a/codeflash/lsp/features/__init__.py b/codeflash/lsp/features/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codeflash/lsp/features/perform_optimization.py b/codeflash/lsp/features/perform_optimization.py new file mode 100644 index 000000000..4ca11c5ab --- /dev/null +++ b/codeflash/lsp/features/perform_optimization.py @@ -0,0 +1,152 @@ +import contextlib +import os +from pathlib import Path + +from codeflash.cli_cmds.console import code_print +from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree +from codeflash.either import is_successful +from codeflash.lsp.server import CodeflashLanguageServer + + +# ruff: noqa: PLR0911, ANN001 +def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[str, str]: + server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info") + current_function = server.optimizer.current_function_being_optimized + + if not current_function: + server.show_message_log(f"No current function being optimized for {params.functionName}", "Error") + return { + "functionName": params.functionName, + "status": "error", + "message": "No function currently being optimized", + } + + module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) + if not module_prep_result: + return { + "functionName": params.functionName, + "status": "error", + "message": "Failed to prepare module for optimization", + } + + validated_original_code, original_module_ast = module_prep_result + + function_optimizer = server.optimizer.create_function_optimizer( + current_function, + function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, + original_module_ast=original_module_ast, + original_module_path=current_function.file_path, + function_to_tests={}, + ) + + server.optimizer.current_function_optimizer = function_optimizer + if not function_optimizer: + return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} + + initialization_result = function_optimizer.can_be_optimized() + if not is_successful(initialization_result): + return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} + + should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() + + code_print( + code_context.read_writable_code.flat, + file_name=current_function.file_path, + function_name=current_function.function_name, + ) + + optimizable_funcs = {current_function.file_path: [current_function]} + + devnull_writer = open(os.devnull, "w") # noqa + with contextlib.redirect_stdout(devnull_writer): + function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs) + function_optimizer.function_to_tests = function_to_tests + + test_setup_result = function_optimizer.generate_and_instrument_tests( + code_context, should_run_experiment=should_run_experiment + ) + if not is_successful(test_setup_result): + return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()} + ( + generated_tests, + function_to_concolic_tests, + concolic_test_str, + optimizations_set, + generated_test_paths, + generated_perf_test_paths, + instrumented_unittests_created_for_function, + original_conftest_content, + ) = test_setup_result.unwrap() + + baseline_setup_result = function_optimizer.setup_and_establish_baseline( + code_context=code_context, + original_helper_code=original_helper_code, + function_to_concolic_tests=function_to_concolic_tests, + generated_test_paths=generated_test_paths, + generated_perf_test_paths=generated_perf_test_paths, + instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, + original_conftest_content=original_conftest_content, + ) + + if not is_successful(baseline_setup_result): + return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()} + + ( + function_to_optimize_qualified_name, + function_to_all_tests, + original_code_baseline, + test_functions_to_remove, + file_path_to_helper_classes, + ) = baseline_setup_result.unwrap() + + best_optimization = function_optimizer.find_and_process_best_optimization( + optimizations_set=optimizations_set, + code_context=code_context, + original_code_baseline=original_code_baseline, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + function_to_optimize_qualified_name=function_to_optimize_qualified_name, + function_to_all_tests=function_to_all_tests, + generated_tests=generated_tests, + test_functions_to_remove=test_functions_to_remove, + concolic_test_str=concolic_test_str, + ) + + if not best_optimization: + server.show_message_log( + f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning" + ) + return { + "functionName": params.functionName, + "status": "error", + "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", + } + # generate a patch for the optimization + relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings] + speedup = original_code_baseline.runtime / best_optimization.runtime + # get the original file path in the actual project (not in the worktree) + original_args, _ = server.optimizer.original_args_and_test_cfg + relative_file_path = current_function.file_path.relative_to(server.optimizer.current_worktree) + original_file_path = Path(original_args.project_root / relative_file_path).resolve() + + metadata = create_diff_patch_from_worktree( + server.optimizer.current_worktree, + relative_file_paths, + metadata_input={ + "fto_name": function_to_optimize_qualified_name, + "explanation": best_optimization.explanation_v2, + "file_path": str(original_file_path), + "speedup": speedup, + }, + ) + + server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") + return { + "functionName": params.functionName, + "status": "success", + "message": "Optimization completed successfully", + "extra": f"Speedup: {speedup:.2f}x faster", + "patch_file": metadata["patch_path"], + "patch_id": metadata["id"], + "explanation": best_optimization.explanation_v2, + }