Skip to content
Closed
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 32 additions & 26 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,10 @@ def determine_best_candidate(
console.rule()
candidates = deque(candidates)
refinement_done = False
future_all_refinements: list[concurrent.futures.Future] = []
# Start a new thread for AI service request, start loop in main thread
# check if aiservice request is complete, when it is complete, append result to the candidates list
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
future_line_profile_results = executor.submit(
ai_service_client.optimize_python_code_line_profiler,
Expand Down Expand Up @@ -515,6 +516,20 @@ def determine_best_candidate(
winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
)
self.valid_optimizations.append(best_optimization)
# queue corresponding refined optimization for best optimization
if not candidate.optimization_id.endswith("refi"):
future_all_refinements.append(
self.refine_optimizations(
valid_optimizations=[best_optimization],
original_code_baseline=original_code_baseline,
code_context=code_context,
trace_id=self.function_trace_id[:-4] + exp_type
if self.experiment_id
else self.function_trace_id,
ai_service_client=ai_service_client,
executor=executor,
)
)
else:
tree.add(
f"Summed runtime: {humanize_runtime(best_test_runtime)} "
Expand All @@ -532,9 +547,9 @@ def determine_best_candidate(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)

if (not len(candidates)) and (
not line_profiler_done
): # all original candidates processed but lp results haven't been processed
if (
(not len(candidates)) and (not line_profiler_done)
): # all original candidates processed but lp results haven't been processed, doesn't matter at the moment if we're done refining or not
concurrent.futures.wait([future_line_profile_results])
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
Expand All @@ -543,24 +558,17 @@ def determine_best_candidate(
f"Added results from line profiler to candidates, total candidates now: {original_len}"
)
future_line_profile_results = None

if len(candidates) == 0 and len(self.valid_optimizations) > 0 and not refinement_done:
# TODO: Instead of doing it all at once at the end, do it one by one as the optimizations
# are found. This way we can hide the time waiting for the LLM results.
trace_id = self.function_trace_id
if trace_id.endswith(("EXP0", "EXP1")):
trace_id = trace_id[:-4] + exp_type
# refinement_response is a dataclass with optimization_id, code and explanation
refinement_response = self.refine_optimizations(
valid_optimizations=self.valid_optimizations,
original_code_baseline=original_code_baseline,
code_context=code_context,
trace_id=trace_id,
ai_service_client=ai_service_client,
executor=executor,
)
# all original candidates and lp candidates processed, collect refinement candidates and append to candidate list
if (not len(candidates)) and line_profiler_done and not refinement_done:
# waiting just in case not all calls are finished, nothing else to do
concurrent.futures.wait(future_all_refinements)
refinement_response = []
for future_refinement in future_all_refinements:
possible_refinement = future_refinement.result()
if len(possible_refinement) > 0: # if the api returns a valid response
refinement_response.append(possible_refinement[0])
candidates.extend(refinement_response)
print("Added candidates from refinement")
logger.info(f"Added {len(refinement_response)} candidates from refinement")
original_len += len(refinement_response)
refinement_done = True
except KeyboardInterrupt as e:
Expand Down Expand Up @@ -605,7 +613,7 @@ def refine_optimizations(
trace_id: str,
ai_service_client: AiServiceClient,
executor: concurrent.futures.ThreadPoolExecutor,
) -> list[OptimizedCandidate]:
) -> concurrent.futures.Future:
request = [
AIServiceRefinerRequest(
optimization_id=opt.candidate.optimization_id,
Expand All @@ -621,10 +629,8 @@ def refine_optimizations(
optimized_line_profiler_results=opt.line_profiler_test_results["str_out"],
)
for opt in valid_optimizations
] # TODO: multiple workers for this?
future_refinement_results = executor.submit(ai_service_client.optimize_python_code_refinement, request=request)
concurrent.futures.wait([future_refinement_results])
return future_refinement_results.result()
]
return executor.submit(ai_service_client.optimize_python_code_refinement, request=request)

def log_successful_optimization(
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
Expand Down
Loading