
Commit f2f394d
Commit message: prep
1 parent: ffefcc9

2 files changed: 73 additions (+), 8 deletions (-)

codeflash/lsp/beta.py

Lines changed: 63 additions & 0 deletions
@@ -12,6 +12,8 @@
 if TYPE_CHECKING:
     from lsprotocol import types
 
+    from codeflash.models.models import GeneratedTestsList, OptimizationSet
+
 
 @dataclass
 class OptimizableFunctionsParams:
@@ -67,6 +69,67 @@ def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOpt
     return {"functionName": params.functionName, "status": "success", "generated_tests": str(num_discovered_tests)}
 
 
+@server.feature("prepareOptimization")
+def prepare_optimization(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
+    current_function = server.optimizer.current_function_being_optimized
+
+    module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
+    validated_original_code, original_module_ast = module_prep_result
+
+    function_optimizer = server.optimizer.create_function_optimizer(
+        current_function,
+        function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
+        original_module_ast=original_module_ast,
+        original_module_path=current_function.file_path,
+    )
+
+    server.optimizer.current_function_optimizer = function_optimizer
+    if not function_optimizer:
+        return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
+
+    initialization_result = function_optimizer.can_be_optimized()
+    if not is_successful(initialization_result):
+        return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
+
+    return {"functionName": params.functionName, "status": "success", "message": "Optimization preparation completed"}
+
+
+@server.feature("generateTests")
+def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
+    function_optimizer = server.optimizer.current_function_optimizer
+    if not function_optimizer:
+        return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
+
+    initialization_result = function_optimizer.can_be_optimized()
+    if not is_successful(initialization_result):
+        return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
+
+    should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
+
+    test_setup_result = function_optimizer.generate_and_instrument_tests(
+        code_context, should_run_experiment=should_run_experiment
+    )
+    if not is_successful(test_setup_result):
+        return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
+    generated_tests_list: GeneratedTestsList
+    optimizations_set: OptimizationSet
+    generated_tests_list, _, concolic_test_str, optimizations_set = test_setup_result.unwrap()
+
+    generated_tests: list[str] = [
+        generated_test.generated_original_test_source for generated_test in generated_tests_list.generated_tests
+    ]
+    optimizations_dict = {
+        candidate.optimization_id: {"source_code": candidate.source_code, "explanation": candidate.explanation}
+        for candidate in optimizations_set.control + optimizations_set.experiment
+    }
+
+    return {
+        "functionName": params.functionName,
+        "status": "success",
+        "message": {"generated_tests": generated_tests, "optimizations": optimizations_dict},
+    }
+
+
 @server.feature("performFunctionOptimization")
 def perform_function_optimization(
     server: CodeflashLanguageServer, params: FunctionOptimizationParams
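
A minimal sketch of how the two new requests are meant to chain, assuming discovery has already populated server.optimizer.current_function_being_optimized as the diff implies. The params object and the run_test_generation helper below are hypothetical stand-ins for illustration, not part of this commit.

# Hypothetical illustration only (not part of this commit): calls the new
# handlers directly, in the order an LSP client would issue the requests.
from types import SimpleNamespace

def run_test_generation(server):
    # Stand-in for FunctionOptimizationParams; only functionName is read
    # by the new handlers.
    params = SimpleNamespace(functionName="bubble_sort")

    prep = prepare_optimization(server, params)  # the "prepareOptimization" feature
    if prep["status"] != "success":
        return prep

    # generate_tests reuses the FunctionOptimizer that prepare_optimization
    # stored on server.optimizer.current_function_optimizer.
    return generate_tests(server, params)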

codeflash/optimization/function_optimizer.py

Lines changed: 10 additions & 8 deletions
@@ -143,6 +143,9 @@ def __init__(
         self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
         self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
         self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None
+        self.generate_and_instrument_tests_results: (
+            tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet] | None
+        ) = None
 
     def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
         should_run_experiment = self.experiment_id is not None
@@ -840,15 +843,14 @@ def generate_tests_and_optimizations(
         logger.info(f"Generated {len(tests)} tests for {self.function_to_optimize.function_name}")
         console.rule()
         generated_tests = GeneratedTestsList(generated_tests=tests)
-
-        return Success(
-            (
-                generated_tests,
-                function_to_concolic_tests,
-                concolic_test_str,
-                OptimizationSet(control=candidates, experiment=candidates_experiment),
-            )
+        result = (
+            generated_tests,
+            function_to_concolic_tests,
+            concolic_test_str,
+            OptimizationSet(control=candidates, experiment=candidates_experiment),
         )
+        self.generate_and_instrument_tests_results = result
+        return Success(result)
 
     def setup_and_establish_baseline(
         self,
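
A minimal sketch of how the newly cached tuple could be read back later, assuming the caller holds the same FunctionOptimizer instance. The get_cached_test_generation helper is hypothetical and not part of this commit.

# Hypothetical helper (not in this commit): reads the cached result instead
# of re-running generate_tests_and_optimizations.
def get_cached_test_generation(function_optimizer):
    cached = function_optimizer.generate_and_instrument_tests_results
    if cached is None:
        # generate_tests_and_optimizations has not produced a result yet.
        return None
    # Same tuple that generate_tests_and_optimizations wrapped in Success().
    return cached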
