Skip to content

Commit b83bd97

Browse files
committed
relevant context for new explanations
1 parent f3aaab4 commit b83bd97

File tree

3 files changed

+32
-31
lines changed

3 files changed

+32
-31
lines changed

codeflash/api/aiservice.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -207,11 +207,10 @@ def optimize_python_code_line_profiler(
207207
def get_new_explanation(
208208
self,
209209
source_code: str,
210+
optimized_code: str,
210211
dependency_code: str,
211212
trace_id: str,
212-
num_candidates: int = 10,
213-
experiment_metadata: ExperimentMetadata | None = None,
214-
existing_explanation: str = "",
213+
existing_explanation: str,
215214
) -> str:
216215
"""Optimize the given python code for performance by making a request to the Django endpoint.
217216
@@ -230,16 +229,12 @@ def get_new_explanation(
230229
231230
"""
232231
payload = {
233-
"source_code": source_code,
234-
"dependency_code": dependency_code,
235-
"num_variants": num_candidates,
236232
"trace_id": trace_id,
237-
"python_version": platform.python_version(),
238-
"experiment_metadata": experiment_metadata,
239-
"codeflash_version": codeflash_version,
233+
"source_code": source_code,
234+
"optimized_code":optimized_code,
240235
"existing_explanation": existing_explanation,
236+
"dependency_code": dependency_code,
241237
}
242-
243238
logger.info("Generating optimized candidates…")
244239
console.rule()
245240
try:

codeflash/optimization/function_optimizer.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import time
99
import uuid
1010
from collections import defaultdict, deque
11+
from dataclasses import replace
1112
from pathlib import Path
1213
from typing import TYPE_CHECKING
1314

@@ -254,29 +255,28 @@ def optimize_function(self) -> Result[BestOptimization, str]:
254255
)
255256

256257
if best_optimization:
257-
logger.info("Best candidate:")
258-
code_print(best_optimization.candidate.source_code)
259-
console.print(
260-
Panel(
261-
best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue"
262-
)
263-
)
264-
#could possibly have it in the best optimization dataclass
265258
new_explanation = self.aiservice_client.get_new_explanation(source_code=code_context.read_writable_code,
266259
dependency_code=code_context.read_only_context_code,
267260
trace_id=self.function_trace_id,
268261
num_candidates=1,
269262
experiment_metadata=None, existing_explanation=best_optimization.candidate.explanation)
263+
best_optimization.candidate = replace(best_optimization.candidate, explanation=new_explanation if new_explanation!="" else best_optimization.candidate.explanation)
270264
explanation = Explanation(
271-
raw_explanation_message=new_explanation if new_explanation!="" else best_optimization.candidate.explanation,
265+
raw_explanation_message=best_optimization.candidate.explanation,
272266
winning_behavioral_test_results=best_optimization.winning_behavioral_test_results,
273267
winning_benchmarking_test_results=best_optimization.winning_benchmarking_test_results,
274268
original_runtime_ns=original_code_baseline.runtime,
275269
best_runtime_ns=best_optimization.runtime,
276270
function_name=function_to_optimize_qualified_name,
277271
file_path=self.function_to_optimize.file_path,
278272
)
279-
273+
logger.info("Best candidate:")
274+
code_print(best_optimization.candidate.source_code)
275+
console.print(
276+
Panel(
277+
best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue"
278+
)
279+
)
280280
self.log_successful_optimization(explanation, generated_tests)
281281

282282
self.replace_function_and_helpers_with_optimized_code(

tests/test_explain_api.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from codeflash.api.aiservice import AiServiceClient
2-
from codeflash.models.ExperimentMetadata import ExperimentMetadata
32
def test_explain_api():
43
aiservice = AiServiceClient()
5-
source_code: str = "a"
6-
dependency_code: str = "b"
4+
source_code: str = """def bubble_sort(arr):
5+
n = len(arr)
6+
for i in range(n):
7+
for j in range(0, n-i-1):
8+
if arr[j] > arr[j+1]:
9+
arr[j], arr[j+1] = arr[j+1], arr[j]
10+
return arr
11+
"""
12+
dependency_code: str = "def helper(): return 1"
713
trace_id: str = "d5822364-7617-4389-a4fc-64602a00b714"
8-
num_candidates: int = 1
9-
experiment_metadata: ExperimentMetadata | None = None
10-
existing_explanation: str = "some explanation"
11-
new_explanation = aiservice.get_new_explanation(source_code=source_code,
12-
dependency_code=dependency_code,
13-
trace_id=trace_id,
14-
num_candidates=num_candidates,
15-
experiment_metadata=experiment_metadata, existing_explanation=existing_explanation)
14+
existing_explanation: str = "I used to numpy to optimize it"
15+
optimized_code: str = """def bubble_sort(arr):
16+
return arr.sort()
17+
"""
18+
new_explanation = aiservice.get_new_explanation(source_code=source_code, optimized_code=optimized_code,
19+
existing_explanation=existing_explanation, dependency_code=dependency_code,
20+
trace_id=trace_id)
21+
print("\nNew explanation: \n", new_explanation)
1622
assert new_explanation.__len__()>0

0 commit comments

Comments
 (0)