|
12 | 12 | if TYPE_CHECKING: |
13 | 13 | from lsprotocol import types |
14 | 14 |
|
| 15 | + from codeflash.models.models import GeneratedTestsList, OptimizationSet |
| 16 | + |
15 | 17 |
|
16 | 18 | @dataclass |
17 | 19 | class OptimizableFunctionsParams: |
@@ -67,6 +69,67 @@ def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOpt |
67 | 69 | return {"functionName": params.functionName, "status": "success", "generated_tests": str(num_discovered_tests)} |
68 | 70 |
|
69 | 71 |
|
| 72 | +@server.feature("prepareOptimization") |
| 73 | +def prepare_optimization(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]: |
| 74 | + current_function = server.optimizer.current_function_being_optimized |
| 75 | + |
| 76 | + module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) |
| 77 | + validated_original_code, original_module_ast = module_prep_result |
| 78 | + |
| 79 | + function_optimizer = server.optimizer.create_function_optimizer( |
| 80 | + current_function, |
| 81 | + function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, |
| 82 | + original_module_ast=original_module_ast, |
| 83 | + original_module_path=current_function.file_path, |
| 84 | + ) |
| 85 | + |
| 86 | + server.optimizer.current_function_optimizer = function_optimizer |
| 87 | + if not function_optimizer: |
| 88 | + return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} |
| 89 | + |
| 90 | + initialization_result = function_optimizer.can_be_optimized() |
| 91 | + if not is_successful(initialization_result): |
| 92 | + return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} |
| 93 | + |
| 94 | + return {"functionName": params.functionName, "status": "success", "message": "Optimization preparation completed"} |
| 95 | + |
| 96 | + |
| 97 | +@server.feature("generateTests") |
| 98 | +def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]: |
| 99 | + function_optimizer = server.optimizer.current_function_optimizer |
| 100 | + if not function_optimizer: |
| 101 | + return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} |
| 102 | + |
| 103 | + initialization_result = function_optimizer.can_be_optimized() |
| 104 | + if not is_successful(initialization_result): |
| 105 | + return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} |
| 106 | + |
| 107 | + should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() |
| 108 | + |
| 109 | + test_setup_result = function_optimizer.generate_and_instrument_tests( |
| 110 | + code_context, should_run_experiment=should_run_experiment |
| 111 | + ) |
| 112 | + if not is_successful(test_setup_result): |
| 113 | + return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()} |
| 114 | + generated_tests_list: GeneratedTestsList |
| 115 | + optimizations_set: OptimizationSet |
| 116 | + generated_tests_list, _, concolic__test_str, optimizations_set = test_setup_result.unwrap() |
| 117 | + |
| 118 | + generated_tests: list[str] = [ |
| 119 | + generated_test.generated_original_test_source for generated_test in generated_tests_list.generated_tests |
| 120 | + ] |
| 121 | + optimizations_dict = { |
| 122 | + candidate.optimization_id: {"source_code": candidate.source_code, "explanation": candidate.explanation} |
| 123 | + for candidate in optimizations_set.control + optimizations_set.experiment |
| 124 | + } |
| 125 | + |
| 126 | + return { |
| 127 | + "functionName": params.functionName, |
| 128 | + "status": "success", |
| 129 | + "message": {"generated_tests": generated_tests, "optimizations": optimizations_dict}, |
| 130 | + } |
| 131 | + |
| 132 | + |
70 | 133 | @server.feature("performFunctionOptimization") |
71 | 134 | def perform_function_optimization( |
72 | 135 | server: CodeflashLanguageServer, params: FunctionOptimizationParams |
|
0 commit comments