Skip to content

Commit ba7e248

Browse files
committed
Make perform_function_optimization asynchronous
1 parent d3e6427 commit ba7e248

File tree

3 files changed

+152
-154
lines changed

3 files changed

+152
-154
lines changed

codeflash/lsp/beta.py

Lines changed: 10 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

3-
import contextlib
4-
import os
3+
import asyncio
54
from dataclasses import dataclass
65
from pathlib import Path
76
from typing import TYPE_CHECKING, Optional
@@ -11,12 +10,7 @@
1110

1211
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
1312
from codeflash.cli_cmds.cli import process_pyproject_config
14-
from codeflash.cli_cmds.console import code_print
15-
from codeflash.code_utils.git_worktree_utils import (
16-
create_diff_patch_from_worktree,
17-
get_patches_metadata,
18-
overwrite_patch_metadata,
19-
)
13+
from codeflash.code_utils.git_worktree_utils import get_patches_metadata, overwrite_patch_metadata
2014
from codeflash.code_utils.shell_utils import save_api_key_to_rc
2115
from codeflash.discovery.functions_to_optimize import (
2216
filter_functions,
@@ -25,6 +19,7 @@
2519
)
2620
from codeflash.either import is_successful
2721
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol
22+
from codeflash.lsp.service.perform_optimization import sync_perform_optimization
2823

2924
if TYPE_CHECKING:
3025
from argparse import Namespace
@@ -274,154 +269,15 @@ def on_patch_applied(_server: CodeflashLanguageServer, params: OnPatchAppliedPar
274269

275270

276271
@server.feature("performFunctionOptimization")
277-
@server.thread()
278-
def perform_function_optimization( # noqa: PLR0911
272+
async def perform_function_optimization(
279273
server: CodeflashLanguageServer, params: FunctionOptimizationParams
280274
) -> dict[str, str]:
275+
loop = asyncio.get_running_loop()
281276
try:
282-
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
283-
current_function = server.optimizer.current_function_being_optimized
284-
285-
if not current_function:
286-
server.show_message_log(f"No current function being optimized for {params.functionName}", "Error")
287-
return {
288-
"functionName": params.functionName,
289-
"status": "error",
290-
"message": "No function currently being optimized",
291-
}
292-
293-
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
294-
if not module_prep_result:
295-
return {
296-
"functionName": params.functionName,
297-
"status": "error",
298-
"message": "Failed to prepare module for optimization",
299-
}
300-
301-
validated_original_code, original_module_ast = module_prep_result
302-
303-
function_optimizer = server.optimizer.create_function_optimizer(
304-
current_function,
305-
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
306-
original_module_ast=original_module_ast,
307-
original_module_path=current_function.file_path,
308-
function_to_tests={},
309-
)
310-
311-
server.optimizer.current_function_optimizer = function_optimizer
312-
if not function_optimizer:
313-
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
314-
315-
initialization_result = function_optimizer.can_be_optimized()
316-
if not is_successful(initialization_result):
317-
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
318-
319-
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
320-
321-
code_print(
322-
code_context.read_writable_code.flat,
323-
file_name=current_function.file_path,
324-
function_name=current_function.function_name,
325-
)
326-
327-
optimizable_funcs = {current_function.file_path: [current_function]}
328-
329-
devnull_writer = open(os.devnull, "w") # noqa
330-
with contextlib.redirect_stdout(devnull_writer):
331-
function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
332-
function_optimizer.function_to_tests = function_to_tests
333-
334-
test_setup_result = function_optimizer.generate_and_instrument_tests(
335-
code_context, should_run_experiment=should_run_experiment
336-
)
337-
if not is_successful(test_setup_result):
338-
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
339-
(
340-
generated_tests,
341-
function_to_concolic_tests,
342-
concolic_test_str,
343-
optimizations_set,
344-
generated_test_paths,
345-
generated_perf_test_paths,
346-
instrumented_unittests_created_for_function,
347-
original_conftest_content,
348-
) = test_setup_result.unwrap()
349-
350-
baseline_setup_result = function_optimizer.setup_and_establish_baseline(
351-
code_context=code_context,
352-
original_helper_code=original_helper_code,
353-
function_to_concolic_tests=function_to_concolic_tests,
354-
generated_test_paths=generated_test_paths,
355-
generated_perf_test_paths=generated_perf_test_paths,
356-
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
357-
original_conftest_content=original_conftest_content,
358-
)
359-
360-
if not is_successful(baseline_setup_result):
361-
return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}
362-
363-
(
364-
function_to_optimize_qualified_name,
365-
function_to_all_tests,
366-
original_code_baseline,
367-
test_functions_to_remove,
368-
file_path_to_helper_classes,
369-
) = baseline_setup_result.unwrap()
370-
371-
best_optimization = function_optimizer.find_and_process_best_optimization(
372-
optimizations_set=optimizations_set,
373-
code_context=code_context,
374-
original_code_baseline=original_code_baseline,
375-
original_helper_code=original_helper_code,
376-
file_path_to_helper_classes=file_path_to_helper_classes,
377-
function_to_optimize_qualified_name=function_to_optimize_qualified_name,
378-
function_to_all_tests=function_to_all_tests,
379-
generated_tests=generated_tests,
380-
test_functions_to_remove=test_functions_to_remove,
381-
concolic_test_str=concolic_test_str,
382-
)
383-
384-
if not best_optimization:
385-
server.show_message_log(
386-
f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
387-
)
388-
return {
389-
"functionName": params.functionName,
390-
"status": "error",
391-
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
392-
}
393-
394-
# generate a patch for the optimization
395-
relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings]
396-
397-
speedup = original_code_baseline.runtime / best_optimization.runtime
398-
399-
# get the original file path in the actual project (not in the worktree)
400-
original_args, _ = server.optimizer.original_args_and_test_cfg
401-
relative_file_path = current_function.file_path.relative_to(server.optimizer.current_worktree)
402-
original_file_path = Path(original_args.project_root / relative_file_path).resolve()
403-
404-
metadata = create_diff_patch_from_worktree(
405-
server.optimizer.current_worktree,
406-
relative_file_paths,
407-
metadata_input={
408-
"fto_name": function_to_optimize_qualified_name,
409-
"explanation": best_optimization.explanation_v2,
410-
"file_path": str(original_file_path),
411-
"speedup": speedup,
412-
},
413-
)
414-
415-
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
416-
417-
return {
418-
"functionName": params.functionName,
419-
"status": "success",
420-
"message": "Optimization completed successfully",
421-
"extra": f"Speedup: {speedup:.2f}x faster",
422-
"patch_file": metadata["patch_path"],
423-
"patch_id": metadata["id"],
424-
"explanation": best_optimization.explanation_v2,
425-
}
277+
result = await loop.run_in_executor(None, sync_perform_optimization, server, params)
278+
except asyncio.CancelledError:
279+
return {"status": "info", "message": "Task was forcefully canceled"}
280+
else:
281+
return result
426282
finally:
427283
server.cleanup_the_optimizer()

codeflash/lsp/service/__init__.py

Whitespace-only changes.
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import contextlib
2+
import os
3+
from pathlib import Path
4+
5+
from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree
6+
from codeflash.either import is_successful
7+
from codeflash.lsp.server import CodeflashLanguageServer
8+
9+
10+
# ruff: noqa: PLR0911, ANN001
11+
def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[str, str]:
12+
current_function = server.optimizer.current_function_being_optimized
13+
if not current_function:
14+
server.show_message_log(f"No current function being optimized for {params.functionName}", "Error")
15+
return {
16+
"functionName": params.functionName,
17+
"status": "error",
18+
"message": "No function currently being optimized",
19+
}
20+
21+
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
22+
if not module_prep_result:
23+
return {
24+
"functionName": params.functionName,
25+
"status": "error",
26+
"message": "Failed to prepare module for optimization",
27+
}
28+
29+
validated_original_code, original_module_ast = module_prep_result
30+
31+
function_optimizer = server.optimizer.create_function_optimizer(
32+
current_function,
33+
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
34+
original_module_ast=original_module_ast,
35+
original_module_path=current_function.file_path,
36+
function_to_tests={},
37+
)
38+
server.optimizer.current_function_optimizer = function_optimizer
39+
if not function_optimizer:
40+
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
41+
42+
initialization_result = function_optimizer.can_be_optimized()
43+
if not is_successful(initialization_result):
44+
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
45+
46+
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
47+
48+
# All the synchronous, potentially blocking calls
49+
optimizable_funcs = {current_function.file_path: [current_function]}
50+
devnull_writer = open(os.devnull, "w") # noqa
51+
with contextlib.redirect_stdout(devnull_writer):
52+
function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
53+
function_optimizer.function_to_tests = function_to_tests
54+
55+
test_setup_result = function_optimizer.generate_and_instrument_tests(
56+
code_context, should_run_experiment=should_run_experiment
57+
)
58+
if not is_successful(test_setup_result):
59+
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
60+
61+
(
62+
generated_tests,
63+
function_to_concolic_tests,
64+
concolic_test_str,
65+
optimizations_set,
66+
generated_test_paths,
67+
generated_perf_test_paths,
68+
instrumented_unittests_created_for_function,
69+
original_conftest_content,
70+
) = test_setup_result.unwrap()
71+
72+
baseline_setup_result = function_optimizer.setup_and_establish_baseline(
73+
code_context=code_context,
74+
original_helper_code=original_helper_code,
75+
function_to_concolic_tests=function_to_concolic_tests,
76+
generated_test_paths=generated_test_paths,
77+
generated_perf_test_paths=generated_perf_test_paths,
78+
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
79+
original_conftest_content=original_conftest_content,
80+
)
81+
82+
if not is_successful(baseline_setup_result):
83+
return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}
84+
85+
(
86+
function_to_optimize_qualified_name,
87+
function_to_all_tests,
88+
original_code_baseline,
89+
test_functions_to_remove,
90+
file_path_to_helper_classes,
91+
) = baseline_setup_result.unwrap()
92+
93+
best_optimization = function_optimizer.find_and_process_best_optimization(
94+
optimizations_set=optimizations_set,
95+
code_context=code_context,
96+
original_code_baseline=original_code_baseline,
97+
original_helper_code=original_helper_code,
98+
file_path_to_helper_classes=file_path_to_helper_classes,
99+
function_to_optimize_qualified_name=function_to_optimize_qualified_name,
100+
function_to_all_tests=function_to_all_tests,
101+
generated_tests=generated_tests,
102+
test_functions_to_remove=test_functions_to_remove,
103+
concolic_test_str=concolic_test_str,
104+
)
105+
106+
if not best_optimization:
107+
server.show_message_log(
108+
f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
109+
)
110+
return {
111+
"functionName": params.functionName,
112+
"status": "error",
113+
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
114+
}
115+
116+
relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings]
117+
speedup = original_code_baseline.runtime / best_optimization.runtime
118+
original_args, _ = server.optimizer.original_args_and_test_cfg
119+
relative_file_path = current_function.file_path.relative_to(server.optimizer.current_worktree)
120+
original_file_path = Path(original_args.project_root / relative_file_path).resolve()
121+
122+
metadata = create_diff_patch_from_worktree(
123+
server.optimizer.current_worktree,
124+
relative_file_paths,
125+
metadata_input={
126+
"fto_name": function_to_optimize_qualified_name,
127+
"explanation": best_optimization.explanation_v2,
128+
"file_path": str(original_file_path),
129+
"speedup": speedup,
130+
},
131+
)
132+
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
133+
134+
return {
135+
"functionName": params.functionName,
136+
"status": "success",
137+
"message": "Optimization completed successfully",
138+
"extra": f"Speedup: {speedup:.2f}x faster",
139+
"patch_file": metadata["patch_path"],
140+
"patch_id": metadata["id"],
141+
"explanation": best_optimization.explanation_v2,
142+
}

0 commit comments

Comments
 (0)