Skip to content

Commit 111c7d8

Browse files
committed
Add logging
1 parent 9726c06 commit 111c7d8

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from typing import Optional
2+
from dataclasses import dataclass
3+
4+
from . import common, candidate_tuning_records
5+
6+
7+
@dataclass
class TuningRecord:
    """Bookkeeping for a single entry of a tuning run.

    One record is created per entry by ``init_tuning_records``; only
    ``gen_id`` and ``candidate_id`` are required. Every other field starts
    at a neutral default (``False`` / ``None``) and is meant to be filled in
    as the entry moves through compilation and benchmarking.
    """

    # --- Identity -----------------------------------------------------------
    # Index into the knob list that the record was created from.
    gen_id: int
    # Position of the record in the tuning-record list (0 for the first entry).
    candidate_id: int
    # Knob assignment attached to this entry; None when there is none.
    knob: Optional[common.KnobAssignment] = None

    # --- Compilation --------------------------------------------------------
    # Whether the entry was selected for compilation.
    to_compile: bool = False
    # Whether compilation succeeded.
    compile_status: bool = False

    # --- Benchmarking -------------------------------------------------------
    # Whether the entry was selected for benchmarking.
    to_benchmark: bool = False
    # Identifier of the device the benchmark ran on.
    benchmark_device_id: Optional[str] = None
    # Position of the entry in the benchmark queue.
    benchmark_queue_position: Optional[int] = None
    # Whether the benchmark produced a usable result.
    benchmark_status: bool = False
    # Timings; presumably microseconds, going by the `_us` suffix -- confirm.
    baseline_benchmark_time_us: Optional[float] = None
    benchmark_time_us: Optional[float] = None
    # Speedup figure relative to the baseline measurement.
    benchmark_speedup: Optional[float] = None
    # Rank of the entry among the benchmarked candidates.
    benchmark_rank_order: Optional[int] = None
22+
23+
def init_tuning_records(
    knobs: list[Optional[common.KnobAssignment]], sorted_order: list[int]
) -> list[TuningRecord]:
    """Create the initial tuning records for a run.

    The returned list always starts with a record for entry 0
    (``gen_id=0``, ``candidate_id=0``, no knob) that is already flagged for
    compilation and benchmarking. It is followed by one record per element of
    ``sorted_order``, in that order.

    Args:
        knobs: Knob assignments, indexed by generation index.
        sorted_order: Generation indices in the order candidates should be
            numbered. The element at position ``p`` (0-based) becomes the
            record with ``candidate_id == p + 1``.

    Returns:
        A list of ``len(sorted_order) + 1`` records in which the list index
        of each record equals its ``candidate_id``.

    Raises:
        IndexError: If an element of ``sorted_order`` is not a valid index
            into ``knobs``.
    """
    # Entry 0 carries no knob and is compiled/benchmarked unconditionally.
    first_record = TuningRecord(
        gen_id=0, candidate_id=0, to_compile=True, to_benchmark=True
    )

    # Candidate ids start at 1 because id 0 is taken by the record above.
    candidate_records = [
        TuningRecord(gen_id=gen_idx, candidate_id=can_idx, knob=knobs[gen_idx])
        for can_idx, gen_idx in enumerate(sorted_order, start=1)
    ]

    return [first_record, *candidate_records]

sharktuner/sharktuner/libtuner.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
dispatch_constraints,
5050
dispatch_parser,
5151
candidate_ordering,
52+
candidate_tuning_records,
5253
)
5354

5455

@@ -125,6 +126,7 @@ def __init__(self, tuner_context: common.TunerContext):
125126
self.tuner_context = tuner_context
126127
self.candidate_trackers: list[CandidateTracker] = []
127128
self.target_info: Optional[iree_gpu.TargetInfo] = None
129+
self.tuning_records: list[candidate_tuning_records.TuningRecord] = []
128130

129131
@abstractmethod
130132
def get_iree_compile_flags(self) -> list[str]:
@@ -845,6 +847,10 @@ def generate_candidate_specs(
845847
# Total number of configs = candidates generated + baseline.
846848
assert len(config_specs) == len(solutions) + 1
847849

850+
tuning_client.tuning_records = candidate_tuning_records.init_tuning_records(
851+
knobs, sorted_order
852+
)
853+
848854
knob_assignments = [dispatch_tuner.get_knob_assignment(s) for s in solutions]
849855
logging.debug("candidate_gen.py ends")
850856
handle_error(
@@ -1193,6 +1199,7 @@ def compile(
11931199
# Set the source and output file paths for compilation of each candidate.
11941200
path_config.compiled_dir.mkdir(parents=True, exist_ok=True)
11951201
for i in candidates:
1202+
tuning_client.tuning_records[i].to_compile = True
11961203
vmfb_file_name = path_config.get_candidate_vmfb_filename(
11971204
tuning_client.candidate_trackers[i].candidate_id
11981205
)
@@ -1231,6 +1238,7 @@ def compile(
12311238
# Remove duplicate vmfbs from the candidate list.
12321239
compiled_candidate_hashes = []
12331240
for candidate_id in compiled_candidates:
1241+
tuning_client.tuning_records[candidate_id].compile_status = True
12341242
candidate_vmfb = tuning_client.candidate_trackers[
12351243
candidate_id
12361244
].compiled_vmfb_path
@@ -1268,21 +1276,27 @@ def benchmark(
12681276

12691277
# Benchmarking baselines on each involved device.
12701278
baseline_tracker = tuning_client.candidate_trackers[0]
1279+
tuning_client.tuning_records[0].to_benchmark = True
12711280
first_baseline_result, subprocess_timeout_reference = benchmark_baseline(
12721281
devices=args.devices,
12731282
tuning_client=tuning_client,
12741283
candidate_tracker=baseline_tracker,
12751284
)
12761285
baseline_handler = BaselineResultHandler()
12771286
baseline_handler.add_run(first_baseline_result)
1287+
tuning_client.tuning_records[0].benchmark_status = True
12781288
if not baseline_handler.is_valid():
12791289
logging.warning("Baseline run failed.")
1290+
tuning_client.tuning_records[0].benchmark_status = False
12801291

12811292
if tuning_client.is_auto_iree_benchmark_timeout():
12821293
logging.info(
12831294
f"Smart candidate benchmark timeout is set to {subprocess_timeout_reference:.2f}s"
12841295
)
12851296
candidate_indices = [i for i in compiled_candidates if i != 0]
1297+
for i, idx in enumerate(candidate_indices, start=1):
1298+
tuning_client.tuning_records[idx].benchmark_queue_position = i
1299+
tuning_client.tuning_records[idx].to_benchmark = True
12861300

12871301
candidate_results = benchmark_candidates(
12881302
candidate_indices=candidate_indices,
@@ -1292,6 +1306,17 @@ def benchmark(
12921306
benchmark_time=benchmark_time, # Only candidate benchmark has time limit.
12931307
)
12941308

1309+
for res in candidate_results:
1310+
tuning_client.tuning_records[
1311+
res.candidate_id
1312+
].benchmark_device_id = res.device_id
1313+
if res.time == math.inf:
1314+
continue
1315+
tuning_client.tuning_records[res.candidate_id].benchmark_status = True
1316+
tuning_client.tuning_records[res.candidate_id].benchmark_time_us = round(
1317+
res.time, 2
1318+
)
1319+
12951320
second_baseline_result, _ = benchmark_baseline(
12961321
devices=args.devices,
12971322
tuning_client=tuning_client,
@@ -1315,6 +1340,15 @@ def benchmark(
13151340
candidate_results,
13161341
prune_slow_candidates=tuning_client.should_prune_slower_candidates(),
13171342
)
1343+
if all_candidates_with_speedup:
1344+
for i, handler_res in enumerate(all_candidates_with_speedup, start=1):
1345+
benchmark_res, speedup = handler_res
1346+
cid, _, device_id = benchmark_res
1347+
bas = baseline_handler.get_average_result_us(device_id)
1348+
tuning_client.tuning_records[cid].baseline_benchmark_time_us = round(bas, 2)
1349+
tuning_client.tuning_records[cid].benchmark_speedup = round(speedup, 5)
1350+
tuning_client.tuning_records[cid].benchmark_rank_order = i
1351+
13181352
top_candidates_with_speedup = (
13191353
all_candidates_with_speedup[:num_candidates]
13201354
if num_candidates

0 commit comments

Comments
 (0)