@@ -380,6 +380,7 @@ def determine_best_candidate(
380
380
console .rule ()
381
381
candidates = deque (candidates )
382
382
refinement_done = False
383
+ line_profiler_done = False
383
384
future_all_refinements : list [concurrent .futures .Future ] = []
384
385
ast_code_to_id = {}
385
386
valid_optimizations = []
@@ -400,19 +401,45 @@ def determine_best_candidate(
400
401
if self .experiment_id
401
402
else None ,
402
403
)
403
- try :
404
- candidate_index = 0
405
- original_len = len (candidates )
406
- while candidates :
404
+ candidate_index = 0
405
+ original_len = len (candidates )
406
+ # TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
407
+ # TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
408
+ while True :
409
+ try :
410
+ if len (candidates ) > 0 :
411
+ candidate = candidates .popleft ()
412
+ else :
413
+ if not line_profiler_done :
414
+ logger .debug ("all candidates processed, await candidates from line profiler" )
415
+ concurrent .futures .wait ([future_line_profile_results ])
416
+ line_profile_results = future_line_profile_results .result ()
417
+ candidates .extend (line_profile_results )
418
+ original_len += len (line_profile_results )
419
+ logger .info (
420
+ f"Added results from line profiler to candidates, total candidates now: { original_len } "
421
+ )
422
+ line_profiler_done = True
423
+ continue
424
+ if line_profiler_done and not refinement_done :
425
+ concurrent .futures .wait (future_all_refinements )
426
+ refinement_response = []
427
+ for future_refinement in future_all_refinements :
428
+ possible_refinement = future_refinement .result ()
429
+ if len (possible_refinement ) > 0 : # if the api returns a valid response
430
+ refinement_response .append (possible_refinement [0 ])
431
+ candidates .extend (refinement_response )
432
+ original_len += len (refinement_response )
433
+ logger .info (
434
+ f"Added { len (refinement_response )} candidates from refinement, total candidates now: { original_len } "
435
+ )
436
+ refinement_done = True
437
+ continue
438
+ if line_profiler_done and refinement_done :
439
+ logger .debug ("everything done, exiting" )
440
+ break
441
+
407
442
candidate_index += 1
408
- line_profiler_done = True if future_line_profile_results is None else future_line_profile_results .done ()
409
- if line_profiler_done and (future_line_profile_results is not None ):
410
- line_profile_results = future_line_profile_results .result ()
411
- candidates .extend (line_profile_results )
412
- original_len += len (line_profile_results )
413
- logger .info (f"Added results from line profiler to candidates, total candidates now: { original_len } " )
414
- future_line_profile_results = None
415
- candidate = candidates .popleft ()
416
443
get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .bin" )).unlink (missing_ok = True )
417
444
get_run_tmp_file (Path (f"test_return_values_{ candidate_index } .sqlite" )).unlink (missing_ok = True )
418
445
logger .info (f"Optimization candidate { candidate_index } /{ original_len } :" )
@@ -474,7 +501,6 @@ def determine_best_candidate(
474
501
file_path_to_helper_classes = file_path_to_helper_classes ,
475
502
)
476
503
console .rule ()
477
-
478
504
if not is_successful (run_results ):
479
505
optimized_runtimes [candidate .optimization_id ] = None
480
506
is_correct [candidate .optimization_id ] = False
@@ -528,7 +554,6 @@ def determine_best_candidate(
528
554
optimized_runtime_ns = candidate_replay_runtime ,
529
555
)
530
556
benchmark_tree .add (f"{ benchmark_key } : { replay_perf_gain [benchmark_key ] * 100 :.1f} %" )
531
-
532
557
best_optimization = BestOptimization (
533
558
candidate = candidate ,
534
559
helper_functions = code_context .helper_functions ,
@@ -571,38 +596,12 @@ def determine_best_candidate(
571
596
self .write_code_and_helpers (
572
597
self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
573
598
)
574
-
575
- if (
576
- (not len (candidates )) and (not line_profiler_done )
577
- ): # all original candidates processed but lp results haven't been processed, doesn't matter at the moment if we're done refining or not
578
- concurrent .futures .wait ([future_line_profile_results ])
579
- line_profile_results = future_line_profile_results .result ()
580
- candidates .extend (line_profile_results )
581
- original_len += len (line_profile_results )
582
- logger .info (f"Added results from line profiler to candidates, total candidates now: { original_len } " )
583
- future_line_profile_results = None
584
- # all original candidates and lp candidates processed, collect refinement candidates and append to candidate list
585
- if (not len (candidates )) and line_profiler_done and not refinement_done :
586
- # waiting just in case not all calls are finished, nothing else to do
587
- concurrent .futures .wait (future_all_refinements )
588
- refinement_response = []
589
- for future_refinement in future_all_refinements :
590
- possible_refinement = future_refinement .result ()
591
- if len (possible_refinement ) > 0 : # if the api returns a valid response
592
- refinement_response .append (possible_refinement [0 ])
593
- candidates .extend (refinement_response )
594
- original_len += len (refinement_response )
595
- logger .info (
596
- f"Added { len (refinement_response )} candidates from refinement, total candidates now: { original_len } "
597
- )
598
- refinement_done = True
599
- except KeyboardInterrupt as e :
600
- self .write_code_and_helpers (
601
- self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
602
- )
603
- logger .exception (f"Optimization interrupted: { e } " )
604
- raise
605
-
599
+ except KeyboardInterrupt as e :
600
+ self .write_code_and_helpers (
601
+ self .function_to_optimize_source_code , original_helper_code , self .function_to_optimize .file_path
602
+ )
603
+ logger .exception (f"Optimization interrupted: { e } " )
604
+ raise
606
605
if not valid_optimizations :
607
606
return None
608
607
# need to figure out the best candidate here before we return best_optimization
0 commit comments