4949 dispatch_constraints ,
5050 dispatch_parser ,
5151 candidate_ordering ,
52+ candidate_tuning_records ,
5253)
5354
5455
@@ -125,6 +126,7 @@ def __init__(self, tuner_context: common.TunerContext):
125126 self .tuner_context = tuner_context
126127 self .candidate_trackers : list [CandidateTracker ] = []
127128 self .target_info : Optional [iree_gpu .TargetInfo ] = None
129+ self .tuning_records : list [candidate_tuning_records .TuningRecord ] = []
128130
129131 @abstractmethod
130132 def get_iree_compile_flags (self ) -> list [str ]:
@@ -845,6 +847,10 @@ def generate_candidate_specs(
845847 # Total number of configs = candidates generated + baseline.
846848 assert len (config_specs ) == len (solutions ) + 1
847849
850+ tuning_client .tuning_records = candidate_tuning_records .init_tuning_records (
851+ knobs , sorted_order
852+ )
853+
848854 knob_assignments = [dispatch_tuner .get_knob_assignment (s ) for s in solutions ]
849855 logging .debug ("candidate_gen.py ends" )
850856 handle_error (
@@ -1193,6 +1199,7 @@ def compile(
11931199 # Set the source and output file paths for compilation of each candidate.
11941200 path_config .compiled_dir .mkdir (parents = True , exist_ok = True )
11951201 for i in candidates :
1202+ tuning_client .tuning_records [i ].to_compile = True
11961203 vmfb_file_name = path_config .get_candidate_vmfb_filename (
11971204 tuning_client .candidate_trackers [i ].candidate_id
11981205 )
@@ -1231,6 +1238,7 @@ def compile(
12311238 # Remove duplicate vmfbs from the candidate list.
12321239 compiled_candidate_hashes = []
12331240 for candidate_id in compiled_candidates :
1241+ tuning_client .tuning_records [candidate_id ].compile_status = True
12341242 candidate_vmfb = tuning_client .candidate_trackers [
12351243 candidate_id
12361244 ].compiled_vmfb_path
@@ -1268,21 +1276,27 @@ def benchmark(
12681276
12691277 # Benchmarking baselines on each involved device.
12701278 baseline_tracker = tuning_client .candidate_trackers [0 ]
1279+ tuning_client .tuning_records [0 ].to_benchmark = True
12711280 first_baseline_result , subprocess_timeout_reference = benchmark_baseline (
12721281 devices = args .devices ,
12731282 tuning_client = tuning_client ,
12741283 candidate_tracker = baseline_tracker ,
12751284 )
12761285 baseline_handler = BaselineResultHandler ()
12771286 baseline_handler .add_run (first_baseline_result )
1287+ tuning_client .tuning_records [0 ].benchmark_status = True
12781288 if not baseline_handler .is_valid ():
12791289 logging .warning ("Baseline run failed." )
1290+ tuning_client .tuning_records [0 ].benchmark_status = False
12801291
12811292 if tuning_client .is_auto_iree_benchmark_timeout ():
12821293 logging .info (
12831294 f"Smart candidate benchmark timeout is set to { subprocess_timeout_reference :.2f} s"
12841295 )
12851296 candidate_indices = [i for i in compiled_candidates if i != 0 ]
1297+ for i , idx in enumerate (candidate_indices , start = 1 ):
1298+ tuning_client .tuning_records [idx ].benchmark_queue_position = i
1299+ tuning_client .tuning_records [idx ].to_benchmark = True
12861300
12871301 candidate_results = benchmark_candidates (
12881302 candidate_indices = candidate_indices ,
@@ -1292,6 +1306,17 @@ def benchmark(
12921306 benchmark_time = benchmark_time , # Only candidate benchmark has time limit.
12931307 )
12941308
1309+ for res in candidate_results :
1310+ tuning_client .tuning_records [
1311+ res .candidate_id
1312+ ].benchmark_device_id = res .device_id
1313+ if res .time == math .inf :
1314+ continue
1315+ tuning_client .tuning_records [res .candidate_id ].benchmark_status = True
1316+ tuning_client .tuning_records [res .candidate_id ].benchmark_time_us = round (
1317+ res .time , 2
1318+ )
1319+
12951320 second_baseline_result , _ = benchmark_baseline (
12961321 devices = args .devices ,
12971322 tuning_client = tuning_client ,
@@ -1315,6 +1340,15 @@ def benchmark(
13151340 candidate_results ,
13161341 prune_slow_candidates = tuning_client .should_prune_slower_candidates (),
13171342 )
1343+ if all_candidates_with_speedup :
1344+ for i , handler_res in enumerate (all_candidates_with_speedup , start = 1 ):
1345+ benchmark_res , speedup = handler_res
1346+ cid , _ , device_id = benchmark_res
1347+ bas = baseline_handler .get_average_result_us (device_id )
1348+ tuning_client .tuning_records [cid ].baseline_benchmark_time_us = round (bas , 2 )
1349+ tuning_client .tuning_records [cid ].benchmark_speedup = round (speedup , 5 )
1350+ tuning_client .tuning_records [cid ].benchmark_rank_order = i
1351+
13181352 top_candidates_with_speedup = (
13191353 all_candidates_with_speedup [:num_candidates ]
13201354 if num_candidates
0 commit comments