Skip to content

Commit 83ee9c9

Browse files
committed
async refinement calls for better queing
1 parent f681c0f commit 83ee9c9

File tree

1 file changed

+25
-24
lines changed

1 file changed

+25
-24
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -375,10 +375,10 @@ def determine_best_candidate(
375375
)
376376
console.rule()
377377
candidates = deque(candidates)
378-
refinement_done = False
378+
future_all_refinements: list[concurrent.futures.Future] = []
379379
# Start a new thread for AI service request, start loop in main thread
380380
# check if aiservice request is complete, when it is complete, append result to the candidates list
381-
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
381+
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
382382
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
383383
future_line_profile_results = executor.submit(
384384
ai_service_client.optimize_python_code_line_profiler,
@@ -515,6 +515,19 @@ def determine_best_candidate(
515515
winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
516516
)
517517
self.valid_optimizations.append(best_optimization)
518+
# queue corresponding refined optimization for best optimization
519+
future_all_refinements.append(
520+
self.refine_optimizations(
521+
valid_optimizations=[best_optimization],
522+
original_code_baseline=original_code_baseline,
523+
code_context=code_context,
524+
trace_id=self.function_trace_id[:-4] + exp_type
525+
if self.experiment_id
526+
else self.function_trace_id,
527+
ai_service_client=ai_service_client,
528+
executor=executor,
529+
)
530+
)
518531
else:
519532
tree.add(
520533
f"Summed runtime: {humanize_runtime(best_test_runtime)} "
@@ -543,26 +556,16 @@ def determine_best_candidate(
543556
f"Added results from line profiler to candidates, total candidates now: {original_len}"
544557
)
545558
future_line_profile_results = None
546-
547-
if len(candidates) == 0 and len(self.valid_optimizations) > 0 and not refinement_done:
548-
# TODO: Instead of doing it all at once at the end, do it one by one as the optimizations
549-
# are found. This way we can hide the time waiting for the LLM results.
550-
trace_id = self.function_trace_id
551-
if trace_id.endswith(("EXP0", "EXP1")):
552-
trace_id = trace_id[:-4] + exp_type
553-
# refinement_response is a dataclass with optimization_id, code and explanation
554-
refinement_response = self.refine_optimizations(
555-
valid_optimizations=self.valid_optimizations,
556-
original_code_baseline=original_code_baseline,
557-
code_context=code_context,
558-
trace_id=trace_id,
559-
ai_service_client=ai_service_client,
560-
executor=executor,
561-
)
559+
# all original candidates and lp andidates processed
560+
if (not len(candidates)) and line_profiler_done:
561+
# waiting just in case not all calls are finished
562+
concurrent.futures.wait(future_all_refinements)
563+
refinement_response = [
564+
future_refinement.result() for future_refinement in future_all_refinements
565+
]
562566
candidates.extend(refinement_response)
563567
print("Added candidates from refinement")
564568
original_len += len(refinement_response)
565-
refinement_done = True
566569
except KeyboardInterrupt as e:
567570
self.write_code_and_helpers(
568571
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
@@ -605,7 +608,7 @@ def refine_optimizations(
605608
trace_id: str,
606609
ai_service_client: AiServiceClient,
607610
executor: concurrent.futures.ThreadPoolExecutor,
608-
) -> list[OptimizedCandidate]:
611+
) -> concurrent.futures.Future:
609612
request = [
610613
AIServiceRefinerRequest(
611614
optimization_id=opt.candidate.optimization_id,
@@ -621,10 +624,8 @@ def refine_optimizations(
621624
optimized_line_profiler_results=opt.line_profiler_test_results["str_out"],
622625
)
623626
for opt in valid_optimizations
624-
] # TODO: multiple workers for this?
625-
future_refinement_results = executor.submit(ai_service_client.optimize_python_code_refinement, request=request)
626-
concurrent.futures.wait([future_refinement_results])
627-
return future_refinement_results.result()
627+
]
628+
return executor.submit(ai_service_client.optimize_python_code_refinement, request=request)
628629

629630
def log_successful_optimization(
630631
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str

0 commit comments

Comments
 (0)