Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 25 additions & 24 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,10 @@ def determine_best_candidate(
)
console.rule()
candidates = deque(candidates)
refinement_done = False
future_all_refinements: list[concurrent.futures.Future] = []
# Start a new thread for AI service request, start loop in main thread
# check if aiservice request is complete, when it is complete, append result to the candidates list
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
future_line_profile_results = executor.submit(
ai_service_client.optimize_python_code_line_profiler,
Expand Down Expand Up @@ -515,6 +515,19 @@ def determine_best_candidate(
winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
)
self.valid_optimizations.append(best_optimization)
# queue corresponding refined optimization for best optimization
future_all_refinements.append(
self.refine_optimizations(
valid_optimizations=[best_optimization],
original_code_baseline=original_code_baseline,
code_context=code_context,
trace_id=self.function_trace_id[:-4] + exp_type
if self.experiment_id
else self.function_trace_id,
ai_service_client=ai_service_client,
executor=executor,
)
)
else:
tree.add(
f"Summed runtime: {humanize_runtime(best_test_runtime)} "
Expand Down Expand Up @@ -543,26 +556,16 @@ def determine_best_candidate(
f"Added results from line profiler to candidates, total candidates now: {original_len}"
)
future_line_profile_results = None

if len(candidates) == 0 and len(self.valid_optimizations) > 0 and not refinement_done:
# TODO: Instead of doing it all at once at the end, do it one by one as the optimizations
# are found. This way we can hide the time waiting for the LLM results.
trace_id = self.function_trace_id
if trace_id.endswith(("EXP0", "EXP1")):
trace_id = trace_id[:-4] + exp_type
# refinement_response is a dataclass with optimization_id, code and explanation
refinement_response = self.refine_optimizations(
valid_optimizations=self.valid_optimizations,
original_code_baseline=original_code_baseline,
code_context=code_context,
trace_id=trace_id,
ai_service_client=ai_service_client,
executor=executor,
)
# all original candidates and lp candidates processed
if (not len(candidates)) and line_profiler_done:
# waiting just in case not all calls are finished
concurrent.futures.wait(future_all_refinements)
refinement_response = [
future_refinement.result() for future_refinement in future_all_refinements
]
candidates.extend(refinement_response)
print("Added candidates from refinement")
original_len += len(refinement_response)
refinement_done = True
except KeyboardInterrupt as e:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
Expand Down Expand Up @@ -605,7 +608,7 @@ def refine_optimizations(
trace_id: str,
ai_service_client: AiServiceClient,
executor: concurrent.futures.ThreadPoolExecutor,
) -> list[OptimizedCandidate]:
) -> concurrent.futures.Future:
request = [
AIServiceRefinerRequest(
optimization_id=opt.candidate.optimization_id,
Expand All @@ -621,10 +624,8 @@ def refine_optimizations(
optimized_line_profiler_results=opt.line_profiler_test_results["str_out"],
)
for opt in valid_optimizations
] # TODO: multiple workers for this?
future_refinement_results = executor.submit(ai_service_client.optimize_python_code_refinement, request=request)
concurrent.futures.wait([future_refinement_results])
return future_refinement_results.result()
]
return executor.submit(ai_service_client.optimize_python_code_refinement, request=request)

def log_successful_optimization(
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
Expand Down
Loading