|
16 | 16 | from codeflash.discovery.functions_to_optimize import filter_functions, get_functions_within_git_diff |
17 | 17 | from codeflash.either import is_successful |
18 | 18 | from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol |
19 | | -from codeflash.result.explanation import Explanation |
20 | 19 |
|
21 | 20 | if TYPE_CHECKING: |
22 | 21 | from lsprotocol import types |
23 | 22 |
|
24 | | - from codeflash.models.models import GeneratedTestsList, OptimizationSet |
25 | | - |
26 | 23 |
|
27 | 24 | @dataclass |
28 | 25 | class OptimizableFunctionsParams: |
@@ -219,67 +216,6 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams |
219 | 216 | return {"status": "error", "message": "something went wrong while saving the api key"} |
220 | 217 |
|
221 | 218 |
|
222 | | -@server.feature("prepareOptimization") |
223 | | -def prepare_optimization(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]: |
224 | | - current_function = server.optimizer.current_function_being_optimized |
225 | | - |
226 | | - module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) |
227 | | - validated_original_code, original_module_ast = module_prep_result |
228 | | - |
229 | | - function_optimizer = server.optimizer.create_function_optimizer( |
230 | | - current_function, |
231 | | - function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, |
232 | | - original_module_ast=original_module_ast, |
233 | | - original_module_path=current_function.file_path, |
234 | | - ) |
235 | | - |
236 | | - server.optimizer.current_function_optimizer = function_optimizer |
237 | | - if not function_optimizer: |
238 | | - return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} |
239 | | - |
240 | | - initialization_result = function_optimizer.can_be_optimized() |
241 | | - if not is_successful(initialization_result): |
242 | | - return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} |
243 | | - |
244 | | - return {"functionName": params.functionName, "status": "success", "message": "Optimization preparation completed"} |
245 | | - |
246 | | - |
247 | | -@server.feature("generateTests") |
248 | | -def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]: |
249 | | - function_optimizer = server.optimizer.current_function_optimizer |
250 | | - if not function_optimizer: |
251 | | - return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} |
252 | | - |
253 | | - initialization_result = function_optimizer.can_be_optimized() |
254 | | - if not is_successful(initialization_result): |
255 | | - return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} |
256 | | - |
257 | | - should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() |
258 | | - |
259 | | - test_setup_result = function_optimizer.generate_and_instrument_tests( |
260 | | - code_context, should_run_experiment=should_run_experiment |
261 | | - ) |
262 | | - if not is_successful(test_setup_result): |
263 | | - return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()} |
264 | | - generated_tests_list: GeneratedTestsList |
265 | | - optimizations_set: OptimizationSet |
266 | | - generated_tests_list, _, concolic__test_str, optimizations_set = test_setup_result.unwrap() |
267 | | - |
268 | | - generated_tests: list[str] = [ |
269 | | - generated_test.generated_original_test_source for generated_test in generated_tests_list.generated_tests |
270 | | - ] |
271 | | - optimizations_dict = { |
272 | | - candidate.optimization_id: {"source_code": candidate.source_code.markdown, "explanation": candidate.explanation} |
273 | | - for candidate in optimizations_set.control + optimizations_set.experiment |
274 | | - } |
275 | | - |
276 | | - return { |
277 | | - "functionName": params.functionName, |
278 | | - "status": "success", |
279 | | - "message": {"generated_tests": generated_tests, "optimizations": optimizations_dict}, |
280 | | - } |
281 | | - |
282 | | - |
283 | 219 | @server.feature("performFunctionOptimization") |
284 | 220 | def perform_function_optimization( # noqa: PLR0911 |
285 | 221 | server: CodeflashLanguageServer, params: FunctionOptimizationParams |
@@ -391,16 +327,14 @@ def perform_function_optimization( # noqa: PLR0911 |
391 | 327 |
|
392 | 328 | server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") |
393 | 329 |
|
394 | | - explanation = best_optimization.candidate.explanation |
395 | | - explanation_str = explanation.explanation_message() if isinstance(explanation, Explanation) else explanation |
396 | 330 | return { |
397 | 331 | "functionName": params.functionName, |
398 | 332 | "status": "success", |
399 | 333 | "message": "Optimization completed successfully", |
400 | 334 | "extra": f"Speedup: {speedup:.2f}x faster", |
401 | 335 | "optimization": optimized_source, |
402 | 336 | "patch_file": str(patch_file), |
403 | | - "explanation": explanation_str, |
| 337 | + "explanation": best_optimization.explanation_v2, |
404 | 338 | } |
405 | 339 | finally: |
406 | 340 | cleanup_the_optimizer(server) |
|
0 commit comments