Skip to content

Commit 75db566

Browse files
RattataKingkuhar
andauthored
[Tuner] Expand tuning candidate logging feature (#2666)
This PR expands the logging function to support data analysis of the candidate search space and `candidate_ordering` performance evaluation. Tuner now records candidate info such as tuning knob configurations, position in benchmark queue, final result ranking, etc. The collected tuning records will be exported to a CSV file under the tuning base folder. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent 9726c06 commit 75db566

File tree

4 files changed

+293
-63
lines changed

4 files changed

+293
-63
lines changed

sharktuner/dispatch_tuner/dispatch_tuner.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,11 @@ def main() -> None:
159159
print(path_config.run_log.resolve())
160160
print("Check the summary in:")
161161
print(summary_log_file.resolve())
162+
163+
output_csv_name = f"{args.dispatch_file.stem}_candidate_analysis.csv"
164+
csv_path = Path(path_config.base_dir) / output_csv_name
165+
166+
libtuner.candidate_ordering.export_record_to_csv(
167+
dispatch_tuner.tuning_records, csv_path
168+
)
169+
print(f"Wrote tuning records CSV: {csv_path}")

sharktuner/sharktuner/candidate_ordering.py

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from enum import Enum
2-
from typing import Optional, Callable
31
import random
42
import logging
3+
import csv
4+
from typing import Optional, Any
5+
from dataclasses import dataclass
6+
from pathlib import Path
7+
from enum import Enum
8+
from typing import Optional, Callable
59

610
from iree.compiler.dialects import iree_gpu # type: ignore
711

@@ -104,3 +108,80 @@ def reorder_assignments(
104108
return indices
105109
case _:
106110
assert False
111+
112+
113+
@dataclass
114+
class TuningRecord:
115+
"""
116+
Records a candidate's knob configuration and tuning results.
117+
118+
Used to analyze the candidate search space and to evaluate the
119+
effectiveness of candidate ordering heuristics.
120+
"""
121+
122+
gen_id: int # Original index from candidate generation.
123+
candidate_id: int # Index in candidate_trackers after reordering.
124+
knob: Optional[common.KnobAssignment] = None
125+
to_compile: bool = False
126+
compile_status: bool = False
127+
to_benchmark: bool = False
128+
benchmark_device_id: Optional[str] = None
129+
benchmark_queue_position: Optional[int] = None
130+
benchmark_status: bool = False
131+
baseline_benchmark_time_us: Optional[float] = None
132+
benchmark_time_us: Optional[float] = None
133+
benchmark_speedup: Optional[float] = None
134+
benchmark_rank_order: Optional[int] = None
135+
136+
137+
def build_tuning_records_from_order(
138+
knobs: list[Optional[common.KnobAssignment]], sorted_order: list[int]
139+
) -> list[TuningRecord]:
140+
tuning_records: list[TuningRecord] = []
141+
# candidate_id = 0 is the baseline and is not included in tuning_records.
142+
for sorted_position, original_gen_index in enumerate(sorted_order, start=1):
143+
tr = TuningRecord(
144+
gen_id=original_gen_index,
145+
candidate_id=sorted_position,
146+
knob=knobs[original_gen_index],
147+
)
148+
tuning_records.append(tr)
149+
150+
return tuning_records
151+
152+
153+
def flatten_records(
154+
tuning_records: list[TuningRecord],
155+
) -> list[dict[str, Any]]:
156+
"""
157+
Flatten a list of `TuningRecord` objects into CSV headers and rows.
158+
159+
- Each record becomes one CSV row.
160+
- Top-level attributes (e.g., `gen_id`, `benchmark_time_us`) appear as individual columns.
161+
- Nested objects (e.g., `knob`) are flattened into columns like `knob.M`, `knob.tile_m`.
162+
"""
163+
rows = []
164+
for tuning_record in tuning_records:
165+
row = {}
166+
for attr, val in vars(tuning_record).items():
167+
if isinstance(val, common.KnobAssignment):
168+
knob_dict = val.get_knobs()
169+
for k, v in knob_dict.items():
170+
row[f"{attr}_{k}"] = v
171+
else:
172+
row[attr] = val
173+
rows.append(row)
174+
175+
return rows
176+
177+
178+
def export_record_to_csv(tuning_records: list[TuningRecord], dest_file: Path) -> None:
179+
assert tuning_records
180+
181+
rows = flatten_records(tuning_records)
182+
headers = list(rows[0].keys())
183+
184+
with open(dest_file, "w", newline="", encoding="utf-8") as f:
185+
writer = csv.DictWriter(f, fieldnames=headers)
186+
writer.writeheader()
187+
writer.writerows(rows)

sharktuner/sharktuner/libtuner.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(self, tuner_context: common.TunerContext):
125125
self.tuner_context = tuner_context
126126
self.candidate_trackers: list[CandidateTracker] = []
127127
self.target_info: Optional[iree_gpu.TargetInfo] = None
128+
self.tuning_records: list[candidate_ordering.TuningRecord] = []
128129

129130
@abstractmethod
130131
def get_iree_compile_flags(self) -> list[str]:
@@ -845,6 +846,10 @@ def generate_candidate_specs(
845846
# Total number of configs = candidates generated + baseline.
846847
assert len(config_specs) == len(solutions) + 1
847848

849+
tuning_client.tuning_records = (
850+
candidate_ordering.build_tuning_records_from_order(knobs, sorted_order)
851+
)
852+
848853
knob_assignments = [dispatch_tuner.get_knob_assignment(s) for s in solutions]
849854
logging.debug("candidate_gen.py ends")
850855
handle_error(
@@ -1193,6 +1198,7 @@ def compile(
11931198
# Set the source and output file paths for compilation of each candidate.
11941199
path_config.compiled_dir.mkdir(parents=True, exist_ok=True)
11951200
for i in candidates:
1201+
tuning_client.tuning_records[i].to_compile = True
11961202
vmfb_file_name = path_config.get_candidate_vmfb_filename(
11971203
tuning_client.candidate_trackers[i].candidate_id
11981204
)
@@ -1231,6 +1237,7 @@ def compile(
12311237
# Remove duplicate vmfbs from the candidate list.
12321238
compiled_candidate_hashes = []
12331239
for candidate_id in compiled_candidates:
1240+
tuning_client.tuning_records[candidate_id].compile_status = True
12341241
candidate_vmfb = tuning_client.candidate_trackers[
12351242
candidate_id
12361243
].compiled_vmfb_path
@@ -1283,6 +1290,9 @@ def benchmark(
12831290
f"Smart candidate benchmark timeout is set to {subprocess_timeout_reference:.2f}s"
12841291
)
12851292
candidate_indices = [i for i in compiled_candidates if i != 0]
1293+
for i, idx in enumerate(candidate_indices, start=1):
1294+
tuning_client.tuning_records[idx].benchmark_queue_position = i
1295+
tuning_client.tuning_records[idx].to_benchmark = True
12861296

12871297
candidate_results = benchmark_candidates(
12881298
candidate_indices=candidate_indices,
@@ -1292,6 +1302,15 @@ def benchmark(
12921302
benchmark_time=benchmark_time, # Only candidate benchmark has time limit.
12931303
)
12941304

1305+
for res in candidate_results:
1306+
c_id = res.candidate_id
1307+
res_time = res.time
1308+
tuning_client.tuning_records[c_id].benchmark_device_id = res.device_id
1309+
if res_time == math.inf:
1310+
continue
1311+
tuning_client.tuning_records[c_id].benchmark_status = True
1312+
tuning_client.tuning_records[c_id].benchmark_time_us = round(res_time, 2)
1313+
12951314
second_baseline_result, _ = benchmark_baseline(
12961315
devices=args.devices,
12971316
tuning_client=tuning_client,
@@ -1315,6 +1334,18 @@ def benchmark(
13151334
candidate_results,
13161335
prune_slow_candidates=tuning_client.should_prune_slower_candidates(),
13171336
)
1337+
1338+
# Best candidate gets rank 1.
1339+
for i, handler_res in enumerate(all_candidates_with_speedup, start=1):
1340+
benchmark_res, speedup = handler_res
1341+
cid, _, device_id = benchmark_res
1342+
baseline_res = baseline_handler.get_average_result_us(device_id)
1343+
tuning_client.tuning_records[cid].baseline_benchmark_time_us = (
1344+
round(baseline_res, 2) if baseline_res else None
1345+
)
1346+
tuning_client.tuning_records[cid].benchmark_speedup = round(speedup, 5)
1347+
tuning_client.tuning_records[cid].benchmark_rank_order = i
1348+
13181349
top_candidates_with_speedup = (
13191350
all_candidates_with_speedup[:num_candidates]
13201351
if num_candidates

0 commit comments

Comments
 (0)