@@ -380,6 +380,9 @@ def determine_best_candidate(
380380 console .rule ()
381381 candidates = deque (candidates )
382382 refinement_done = False
383+ refinements_added = False
384+ line_profiler_done = False
385+ line_profiler_added = False
383386 future_all_refinements : list [concurrent .futures .Future ] = []
384387 ast_code_to_id = {}
385388 valid_optimizations = []
@@ -400,19 +403,41 @@ def determine_best_candidate(
400403 if self .experiment_id
401404 else None ,
402405 )
403- try :
404- candidate_index = 0
405- original_len = len (candidates )
406- while candidates :
406+ candidate_index = 0
407+ original_len = len (candidates )
408+ while True :
409+ try :
410+ if len (candidates )> 0 :
411+ candidate = candidates .popleft ()
412+ else :
413+ if not line_profiler_done :
414+ logger .debug ("all candidates processed, await candidates from line profiler" )
415+ concurrent .futures .wait ([future_line_profile_results ])
416+ line_profile_results = future_line_profile_results .result ()
417+ candidates .extend (line_profile_results )
418+ original_len += len (line_profile_results )
419+ logger .info (f"Added results from line profiler to candidates, total candidates now: { original_len } " )
420+ line_profiler_done = True
421+ continue
422+ if line_profiler_done and not refinement_done :
423+ concurrent .futures .wait (future_all_refinements )
424+ refinement_response = []
425+ for future_refinement in future_all_refinements :
426+ possible_refinement = future_refinement .result ()
427+ if len (possible_refinement ) > 0 : # if the api returns a valid response
428+ refinement_response .append (possible_refinement [0 ])
429+ candidates .extend (refinement_response )
430+ original_len += len (refinement_response )
431+ logger .info (
432+ f"Added { len (refinement_response )} candidates from refinement, total candidates now: { original_len } "
433+ )
434+ refinement_done = True
435+ continue
436+ if line_profiler_done and refinement_done :
437+ logger .debug ("everything done, exiting" )
438+ break
439+
407440 candidate_index += 1
408- line_profiler_done = True if future_line_profile_results is None else future_line_profile_results .done ()
409- if line_profiler_done and (future_line_profile_results is not None ):
410- line_profile_results = future_line_profile_results .result ()
411- candidates .extend (line_profile_results )
412- original_len += len (line_profile_results )
413- logger .info (f"Added results from line profiler to candidates, total candidates now: { original_len } " )
414- future_line_profile_results = None
415- candidate = candidates .popleft ()
416441 get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .bin" )).unlink (missing_ok = True )
417442 get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .sqlite" )).unlink (missing_ok = True )
418443 logger .info (f"Optimization candidate { candidate_index } /{ original_len } :" )
@@ -474,7 +499,6 @@ def determine_best_candidate(
474499 file_path_to_helper_classes = file_path_to_helper_classes ,
475500 )
476501 console .rule ()
477-
478502 if not is_successful (run_results ):
479503 optimized_runtimes [candidate .optimization_id ] = None
480504 is_correct [candidate .optimization_id ] = False
@@ -528,7 +552,6 @@ def determine_best_candidate(
528552 optimized_runtime_ns = candidate_replay_runtime ,
529553 )
530554 benchmark_tree .add (f"{ benchmark_key } : { replay_perf_gain [benchmark_key ] * 100 :.1f} %" )
531-
532555 best_optimization = BestOptimization (
533556 candidate = candidate ,
534557 helper_functions = code_context .helper_functions ,
@@ -571,38 +594,12 @@ def determine_best_candidate(
571594 self .write_code_and_helpers (
572595 self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
573596 )
574-
575- if (
576- (not len (candidates )) and (not line_profiler_done )
577- ): # all original candidates processed but lp results haven't been processed, doesn't matter at the moment if we're done refining or not
578- concurrent .futures .wait ([future_line_profile_results ])
579- line_profile_results = future_line_profile_results .result ()
580- candidates .extend (line_profile_results )
581- original_len += len (line_profile_results )
582- logger .info (f"Added results from line profiler to candidates, total candidates now: { original_len } " )
583- future_line_profile_results = None
584- # all original candidates and lp candidates processed, collect refinement candidates and append to candidate list
585- if (not len (candidates )) and line_profiler_done and not refinement_done :
586- # waiting just in case not all calls are finished, nothing else to do
587- concurrent .futures .wait (future_all_refinements )
588- refinement_response = []
589- for future_refinement in future_all_refinements :
590- possible_refinement = future_refinement .result ()
591- if len (possible_refinement ) > 0 : # if the api returns a valid response
592- refinement_response .append (possible_refinement [0 ])
593- candidates .extend (refinement_response )
594- original_len += len (refinement_response )
595- logger .info (
596- f"Added { len (refinement_response )} candidates from refinement, total candidates now: { original_len } "
597- )
598- refinement_done = True
599- except KeyboardInterrupt as e :
600- self .write_code_and_helpers (
601- self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
602- )
603- logger .exception (f"Optimization interrupted: { e } " )
604- raise
605-
597+ except KeyboardInterrupt as e :
598+ self .write_code_and_helpers (
599+ self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
600+ )
601+ logger .exception (f"Optimization interrupted: { e } " )
602+ raise
606603 if not valid_optimizations :
607604 return None
608605 # need to figure out the best candidate here before we return best_optimization
0 commit comments