|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import contextlib |
4 | | -import os |
| 3 | +import asyncio |
5 | 4 | from dataclasses import dataclass |
6 | 5 | from pathlib import Path |
7 | 6 | from typing import TYPE_CHECKING, Optional |
|
11 | 10 |
|
12 | 11 | from codeflash.api.cfapi import get_codeflash_api_key, get_user_id |
13 | 12 | from codeflash.cli_cmds.cli import process_pyproject_config |
14 | | -from codeflash.cli_cmds.console import code_print |
15 | | -from codeflash.code_utils.git_worktree_utils import ( |
16 | | - create_diff_patch_from_worktree, |
17 | | - get_patches_metadata, |
18 | | - overwrite_patch_metadata, |
19 | | -) |
| 13 | +from codeflash.code_utils.git_worktree_utils import get_patches_metadata, overwrite_patch_metadata |
20 | 14 | from codeflash.code_utils.shell_utils import save_api_key_to_rc |
21 | 15 | from codeflash.discovery.functions_to_optimize import ( |
22 | 16 | filter_functions, |
|
25 | 19 | ) |
26 | 20 | from codeflash.either import is_successful |
27 | 21 | from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol |
| 22 | +from codeflash.lsp.service.perform_optimization import sync_perform_optimization |
28 | 23 |
|
29 | 24 | if TYPE_CHECKING: |
30 | 25 | from argparse import Namespace |
@@ -274,154 +269,15 @@ def on_patch_applied(_server: CodeflashLanguageServer, params: OnPatchAppliedPar |
274 | 269 |
|
275 | 270 |
|
276 | 271 | @server.feature("performFunctionOptimization") |
277 | | -@server.thread() |
278 | | -def perform_function_optimization( # noqa: PLR0911 |
| 272 | +async def perform_function_optimization( |
279 | 273 | server: CodeflashLanguageServer, params: FunctionOptimizationParams |
280 | 274 | ) -> dict[str, str]: |
| 275 | + loop = asyncio.get_running_loop() |
281 | 276 | try: |
282 | | - server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info") |
283 | | - current_function = server.optimizer.current_function_being_optimized |
284 | | - |
285 | | - if not current_function: |
286 | | - server.show_message_log(f"No current function being optimized for {params.functionName}", "Error") |
287 | | - return { |
288 | | - "functionName": params.functionName, |
289 | | - "status": "error", |
290 | | - "message": "No function currently being optimized", |
291 | | - } |
292 | | - |
293 | | - module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) |
294 | | - if not module_prep_result: |
295 | | - return { |
296 | | - "functionName": params.functionName, |
297 | | - "status": "error", |
298 | | - "message": "Failed to prepare module for optimization", |
299 | | - } |
300 | | - |
301 | | - validated_original_code, original_module_ast = module_prep_result |
302 | | - |
303 | | - function_optimizer = server.optimizer.create_function_optimizer( |
304 | | - current_function, |
305 | | - function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, |
306 | | - original_module_ast=original_module_ast, |
307 | | - original_module_path=current_function.file_path, |
308 | | - function_to_tests={}, |
309 | | - ) |
310 | | - |
311 | | - server.optimizer.current_function_optimizer = function_optimizer |
312 | | - if not function_optimizer: |
313 | | - return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} |
314 | | - |
315 | | - initialization_result = function_optimizer.can_be_optimized() |
316 | | - if not is_successful(initialization_result): |
317 | | - return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} |
318 | | - |
319 | | - should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() |
320 | | - |
321 | | - code_print( |
322 | | - code_context.read_writable_code.flat, |
323 | | - file_name=current_function.file_path, |
324 | | - function_name=current_function.function_name, |
325 | | - ) |
326 | | - |
327 | | - optimizable_funcs = {current_function.file_path: [current_function]} |
328 | | - |
329 | | - devnull_writer = open(os.devnull, "w") # noqa |
330 | | - with contextlib.redirect_stdout(devnull_writer): |
331 | | - function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs) |
332 | | - function_optimizer.function_to_tests = function_to_tests |
333 | | - |
334 | | - test_setup_result = function_optimizer.generate_and_instrument_tests( |
335 | | - code_context, should_run_experiment=should_run_experiment |
336 | | - ) |
337 | | - if not is_successful(test_setup_result): |
338 | | - return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()} |
339 | | - ( |
340 | | - generated_tests, |
341 | | - function_to_concolic_tests, |
342 | | - concolic_test_str, |
343 | | - optimizations_set, |
344 | | - generated_test_paths, |
345 | | - generated_perf_test_paths, |
346 | | - instrumented_unittests_created_for_function, |
347 | | - original_conftest_content, |
348 | | - ) = test_setup_result.unwrap() |
349 | | - |
350 | | - baseline_setup_result = function_optimizer.setup_and_establish_baseline( |
351 | | - code_context=code_context, |
352 | | - original_helper_code=original_helper_code, |
353 | | - function_to_concolic_tests=function_to_concolic_tests, |
354 | | - generated_test_paths=generated_test_paths, |
355 | | - generated_perf_test_paths=generated_perf_test_paths, |
356 | | - instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, |
357 | | - original_conftest_content=original_conftest_content, |
358 | | - ) |
359 | | - |
360 | | - if not is_successful(baseline_setup_result): |
361 | | - return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()} |
362 | | - |
363 | | - ( |
364 | | - function_to_optimize_qualified_name, |
365 | | - function_to_all_tests, |
366 | | - original_code_baseline, |
367 | | - test_functions_to_remove, |
368 | | - file_path_to_helper_classes, |
369 | | - ) = baseline_setup_result.unwrap() |
370 | | - |
371 | | - best_optimization = function_optimizer.find_and_process_best_optimization( |
372 | | - optimizations_set=optimizations_set, |
373 | | - code_context=code_context, |
374 | | - original_code_baseline=original_code_baseline, |
375 | | - original_helper_code=original_helper_code, |
376 | | - file_path_to_helper_classes=file_path_to_helper_classes, |
377 | | - function_to_optimize_qualified_name=function_to_optimize_qualified_name, |
378 | | - function_to_all_tests=function_to_all_tests, |
379 | | - generated_tests=generated_tests, |
380 | | - test_functions_to_remove=test_functions_to_remove, |
381 | | - concolic_test_str=concolic_test_str, |
382 | | - ) |
383 | | - |
384 | | - if not best_optimization: |
385 | | - server.show_message_log( |
386 | | - f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning" |
387 | | - ) |
388 | | - return { |
389 | | - "functionName": params.functionName, |
390 | | - "status": "error", |
391 | | - "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", |
392 | | - } |
393 | | - |
394 | | - # generate a patch for the optimization |
395 | | - relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings] |
396 | | - |
397 | | - speedup = original_code_baseline.runtime / best_optimization.runtime |
398 | | - |
399 | | - # get the original file path in the actual project (not in the worktree) |
400 | | - original_args, _ = server.optimizer.original_args_and_test_cfg |
401 | | - relative_file_path = current_function.file_path.relative_to(server.optimizer.current_worktree) |
402 | | - original_file_path = Path(original_args.project_root / relative_file_path).resolve() |
403 | | - |
404 | | - metadata = create_diff_patch_from_worktree( |
405 | | - server.optimizer.current_worktree, |
406 | | - relative_file_paths, |
407 | | - metadata_input={ |
408 | | - "fto_name": function_to_optimize_qualified_name, |
409 | | - "explanation": best_optimization.explanation_v2, |
410 | | - "file_path": str(original_file_path), |
411 | | - "speedup": speedup, |
412 | | - }, |
413 | | - ) |
414 | | - |
415 | | - server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") |
416 | | - |
417 | | - return { |
418 | | - "functionName": params.functionName, |
419 | | - "status": "success", |
420 | | - "message": "Optimization completed successfully", |
421 | | - "extra": f"Speedup: {speedup:.2f}x faster", |
422 | | - "patch_file": metadata["patch_path"], |
423 | | - "patch_id": metadata["id"], |
424 | | - "explanation": best_optimization.explanation_v2, |
425 | | - } |
| 277 | + result = await loop.run_in_executor(None, sync_perform_optimization, server, params) |
| 278 | + except asyncio.CancelledError: |
| 279 | + return {"status": "info", "message": "Task was forcefully canceled"} |
| 280 | + else: |
| 281 | + return result |
426 | 282 | finally: |
427 | 283 | server.cleanup_the_optimizer() |
0 commit comments