Skip to content

Commit 277b1b6

Browse files
only call adaptive optimization on previously refined candidates
1 parent 7b77477 commit 277b1b6

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -934,22 +934,23 @@ def process_single_candidate(
934934
eval_ctx=eval_ctx,
935935
)
936936
eval_ctx.valid_optimizations.append(best_optimization)
937-
# Queue adaptive optimization
938-
future_adaptive_optimization = self.call_adaptive_optimize(
939-
trace_id=self.get_trace_id(exp_type),
940-
original_source_code=code_context.read_writable_code.markdown,
941-
candidate_node=candidate_node,
942-
eval_ctx=eval_ctx,
943-
ai_service_client=self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client,
944-
)
945-
if future_adaptive_optimization:
946-
self.future_adaptive_optimizations.append(future_adaptive_optimization)
947937

948-
# Queue refinement for non-refined candidates
938+
current_tree_candidates = candidate_node.path_to_root()
949939
is_candidate_refined_before = any(
950-
c.source == OptimizedCandidateSource.REFINE for c in candidate_node.path_to_root()
940+
c.source == OptimizedCandidateSource.REFINE for c in current_tree_candidates
951941
)
952-
if not is_candidate_refined_before:
942+
943+
if is_candidate_refined_before:
944+
future_adaptive_optimization = self.call_adaptive_optimize(
945+
trace_id=self.get_trace_id(exp_type),
946+
original_source_code=code_context.read_writable_code.markdown,
947+
prev_candidates=current_tree_candidates,
948+
eval_ctx=eval_ctx,
949+
ai_service_client=self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client,
950+
)
951+
if future_adaptive_optimization:
952+
self.future_adaptive_optimizations.append(future_adaptive_optimization)
953+
else:
953954
all_refinements_data.append(
954955
AIServiceRefinerRequest(
955956
optimization_id=best_optimization.candidate.optimization_id,
@@ -1085,7 +1086,7 @@ def call_adaptive_optimize(
10851086
self,
10861087
trace_id: str,
10871088
original_source_code: str,
1088-
candidate_node: CandidateNode,
1089+
prev_candidates: list[OptimizedCandidate],
10891090
eval_ctx: CandidateEvaluationContext,
10901091
ai_service_client: AiServiceClient,
10911092
) -> concurrent.futures.Future[OptimizedCandidate | None] | None:
@@ -1095,11 +1096,6 @@ def call_adaptive_optimize(
10951096
)
10961097
return None
10971098

1098-
prev_candidates = candidate_node.path_to_root()
1099-
if len(prev_candidates) == 1:
1100-
# we already have the refinement going for this single candidate tree, no need to do adaptive optimize
1101-
return None
1102-
11031099
adaptive_count = sum(1 for c in prev_candidates if c.source == OptimizedCandidateSource.ADAPTIVE)
11041100

11051101
if adaptive_count >= ADAPTIVE_OPTIMIZATION_THRESHOLD:

0 commit comments

Comments
 (0)