Skip to content

Commit 3eedbd2

Browse files
refi optimization ids and original optimization ids
1 parent ed6b5b1 commit 3eedbd2

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

codeflash/api/aiservice.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
@dataclass(frozen=True)
2828
class AIServiceRefinerRequest:
29+
optimization_id: str
2930
original_source_code: str
3031
read_only_dependency_code: str
3132
original_code_runtime: str
@@ -232,10 +233,11 @@ def optimize_python_code_line_profiler( # noqa: D417
232233
console.rule()
233234
return []
234235

235-
def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[str]:
236+
def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> dict[str, str]:
236237
git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()
237238
payload = [
238239
{
240+
"optimization_id": opt.optimization_id,
239241
"original_source_code": opt.original_source_code,
240242
"read_only_dependency_code": opt.read_only_dependency_code,
241243
"original_line_profiler_results": opt.original_line_profiler_results,
@@ -280,7 +282,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
280282
except requests.exceptions.RequestException as e:
281283
logger.exception(f"Error generating optimization refinements: {e}")
282284
ph("cli-optimize-error-caught", {"error": str(e)})
283-
return []
285+
return {}
284286

285287
if response.status_code == 200:
286288
refined_optimizations = response.json()["result"]
@@ -294,7 +296,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
294296
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
295297
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
296298
console.rule()
297-
return []
299+
return {}
298300

299301
def log_results( # noqa: D417
300302
self,
@@ -303,6 +305,7 @@ def log_results( # noqa: D417
303305
original_runtime: float | None,
304306
optimized_runtime: dict[str, float | None] | None,
305307
is_correct: dict[str, bool] | None,
308+
metadata: dict[str, any] | None,
306309
) -> None:
307310
"""Log features to the database.
308311
@@ -313,6 +316,7 @@ def log_results( # noqa: D417
313316
- original_runtime (Optional[Dict[str, float]]): The original runtime.
314317
- optimized_runtime (Optional[Dict[str, float]]): The optimized runtime.
315318
- is_correct (Optional[Dict[str, bool]]): Whether the optimized code is correct.
319+
- metadata (Optional[dict[str, any]]): metadata.
316320
317321
"""
318322
payload = {
@@ -322,6 +326,7 @@ def log_results( # noqa: D417
322326
"optimized_runtime": optimized_runtime,
323327
"is_correct": is_correct,
324328
"codeflash_version": codeflash_version,
329+
"metadata": metadata,
325330
}
326331
try:
327332
self.make_ai_service_request("/log_features", payload=payload, timeout=5)

codeflash/optimization/function_optimizer.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,8 @@ def determine_best_candidate(
547547
trace_id = self.function_trace_id
548548
if trace_id.endswith(("EXP0", "EXP1")):
549549
trace_id = trace_id[:-4] + exp_type
550-
refinement_diffs = self.refine_optimizations(
550+
# refinement_dict is a dictionary with optimization_id as a key and the refined code as a value
551+
refinement_dict = self.refine_optimizations(
551552
valid_optimizations=self.valid_optimizations,
552553
original_code_baseline=original_code_baseline,
553554
code_context=code_context,
@@ -561,15 +562,18 @@ def determine_best_candidate(
561562
executor=executor,
562563
fto_name=self.function_to_optimize.qualified_name,
563564
)
564-
# filter out empty strings of code
565+
565566
more_opt_candidates = [
566567
OptimizedCandidate(
567-
source_code=refinement_diffs[i],
568-
explanation=self.valid_optimizations[i].candidate.explanation,
569-
optimization_id=self.valid_optimizations[i].candidate.optimization_id,
568+
source_code=code,
569+
explanation=self.valid_optimizations[
570+
i
571+
].candidate.explanation, # TODO: handle the new explanation after the refinement
572+
optimization_id=opt_id,
570573
)
571-
for i in range(len(refinement_diffs))
572-
if refinement_diffs[i] != ""
574+
for i, (opt_id, code) in enumerate(refinement_dict.items())
575+
# filter out empty strings of code
576+
if code != ""
573577
]
574578
# we no longer need to apply diffs since we are generating the entire code again
575579
candidates.extend(more_opt_candidates)
@@ -637,14 +641,16 @@ def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
637641
runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
638642
overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking.keys()} # noqa: SIM118
639643
min_key = min(overall_ranking, key=overall_ranking.get)
644+
best_optimization = self.valid_optimizations[min_key]
640645
ai_service_client.log_results(
641646
function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
642647
speedup_ratio=speedup_ratios,
643648
original_runtime=original_code_baseline.runtime,
644649
optimized_runtime=optimized_runtimes,
645650
is_correct=is_correct,
651+
metadata={"best_optimization_id": best_optimization.candidate.optimization_id},
646652
)
647-
return self.valid_optimizations[min_key]
653+
return best_optimization
648654

649655
def refine_optimizations(
650656
self,
@@ -656,9 +662,10 @@ def refine_optimizations(
656662
ai_service_client: AiServiceClient,
657663
executor: concurrent.futures.ThreadPoolExecutor,
658664
fto_name: str,
659-
) -> list[str]:
665+
) -> dict[str, str]:
660666
request = [
661667
AIServiceRefinerRequest(
668+
optimization_id=opt.candidate.optimization_id,
662669
original_source_code=code_context.read_writable_code,
663670
read_only_dependency_code=code_context.read_only_context_code,
664671
original_code_runtime=humanize_runtime(original_code_baseline.runtime),

0 commit comments

Comments
 (0)