Skip to content

Commit 1d5c2f9

Browse files
committed
quickfixes
1 parent 83ee9c9 commit 1d5c2f9

File tree

1 file changed

+26
-21
lines changed

1 file changed

+26
-21
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ def determine_best_candidate(
375375
)
376376
console.rule()
377377
candidates = deque(candidates)
378+
refinement_done = False
378379
future_all_refinements: list[concurrent.futures.Future] = []
379380
# Start a new thread for AI service request, start loop in main thread
380381
# check if aiservice request is complete, when it is complete, append result to the candidates list
@@ -516,18 +517,19 @@ def determine_best_candidate(
516517
)
517518
self.valid_optimizations.append(best_optimization)
518519
# queue corresponding refined optimization for best optimization
519-
future_all_refinements.append(
520-
self.refine_optimizations(
521-
valid_optimizations=[best_optimization],
522-
original_code_baseline=original_code_baseline,
523-
code_context=code_context,
524-
trace_id=self.function_trace_id[:-4] + exp_type
525-
if self.experiment_id
526-
else self.function_trace_id,
527-
ai_service_client=ai_service_client,
528-
executor=executor,
520+
if not candidate.optimization_id.endswith("refi"):
521+
future_all_refinements.append(
522+
self.refine_optimizations(
523+
valid_optimizations=[best_optimization],
524+
original_code_baseline=original_code_baseline,
525+
code_context=code_context,
526+
trace_id=self.function_trace_id[:-4] + exp_type
527+
if self.experiment_id
528+
else self.function_trace_id,
529+
ai_service_client=ai_service_client,
530+
executor=executor,
531+
)
529532
)
530-
)
531533
else:
532534
tree.add(
533535
f"Summed runtime: {humanize_runtime(best_test_runtime)} "
@@ -545,9 +547,9 @@ def determine_best_candidate(
545547
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
546548
)
547549

548-
if (not len(candidates)) and (
549-
not line_profiler_done
550-
): # all original candidates processed but lp results haven't been processed
550+
if (
551+
(not len(candidates)) and (not line_profiler_done)
552+
): # all original candidates processed but lp results haven't been processed, doesn't matter at the moment if we're done refining or not
551553
concurrent.futures.wait([future_line_profile_results])
552554
line_profile_results = future_line_profile_results.result()
553555
candidates.extend(line_profile_results)
@@ -556,16 +558,19 @@ def determine_best_candidate(
556558
f"Added results from line profiler to candidates, total candidates now: {original_len}"
557559
)
558560
future_line_profile_results = None
559-
# all original candidates and lp andidates processed
560-
if (not len(candidates)) and line_profiler_done:
561-
# waiting just in case not all calls are finished
561+
# all original candidates and lp candidates processed, collect refinement candidates and append to candidate list
562+
if (not len(candidates)) and line_profiler_done and not refinement_done:
563+
# waiting just in case not all calls are finished, nothing else to do
562564
concurrent.futures.wait(future_all_refinements)
563-
refinement_response = [
564-
future_refinement.result() for future_refinement in future_all_refinements
565-
]
565+
refinement_response = []
566+
for future_refinement in future_all_refinements:
567+
possible_refinement = future_refinement.result()
568+
if len(possible_refinement) > 0: # if the api returns a valid response
569+
refinement_response.append(possible_refinement[0])
566570
candidates.extend(refinement_response)
567-
print("Added candidates from refinement")
571+
logger.info(f"Added {len(refinement_response)} candidates from refinement")
568572
original_len += len(refinement_response)
573+
refinement_done = True
569574
except KeyboardInterrupt as e:
570575
self.write_code_and_helpers(
571576
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path

0 commit comments

Comments
 (0)