Skip to content

Commit 65d14d8

Browse files
committed
Add test for flatten_class
1 parent 38663f6 commit 65d14d8

File tree

3 files changed

+171
-25
lines changed

3 files changed

+171
-25
lines changed

sharktuner/dispatch_tuner/dispatch_tuner.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ def main() -> None:
161161
print(summary_log_file.resolve())
162162

163163
output_csv_name = f"tuning_{args.dispatch_file.stem}.csv"
164-
csv_path = libtuner.candidate_ordering.export_record_to_csv(
165-
dispatch_tuner.tuning_records, path_config.base_dir, output_csv_name
164+
csv_path = Path(path_config.base_dir / output_csv_name)
165+
libtuner.candidate_ordering.export_record_to_csv(
166+
dispatch_tuner.tuning_records, csv_path
166167
)
167168
print(f"Wrote tuning records CSV: {csv_path}")

sharktuner/sharktuner/candidate_ordering.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import os
44
import csv
5-
from typing import Optional
5+
from typing import Optional, Any
66
from dataclasses import dataclass
77
from pathlib import Path
88
from enum import Enum
@@ -147,20 +147,16 @@ def init_tuning_records(
147147
return tuning_records
148148

149149

150-
def export_record_to_csv(
151-
tuning_records: list[TuningRecord], dest_dir: Path, filename: str = "export.csv"
152-
) -> Optional[Path]:
150+
def flatten_records(
151+
tuning_records: list[TuningRecord],
152+
) -> tuple[list[str], list[dict[str, Any]]]:
153153
"""
154-
Exports a list of `TuningRecord` objects to a CSV file.
154+
Flatten a list of `TuningRecord` objects to CSV headers and rows
155155
156156
- Each record becomes one CSV row.
157157
- Top-level attributes (e.g., `gen_id`, `benchmark_time_us`) are written as individual columns.
158158
- Nested object (i.e., `knob`) is flattened using dot notation: knob.tile_m, knob.intrinsic_mn
159-
160159
"""
161-
if not tuning_records:
162-
return None
163-
164160
rows = []
165161
headers = []
166162

@@ -179,14 +175,20 @@ def export_record_to_csv(
179175
continue
180176
else:
181177
row[k] = v
182-
if k not in headers:
178+
if k not in headers and k != "knob":
183179
headers.append(k)
184180
rows.append(row)
185181

186-
path = Path(os.path.join(dest_dir, filename))
187-
with open(path, "w", newline="", encoding="utf-8") as f:
182+
return headers, rows
183+
184+
185+
def export_record_to_csv(tuning_records: list[TuningRecord], dest_file: Path) -> None:
186+
if not tuning_records:
187+
return None
188+
189+
headers, rows = flatten_records(tuning_records)
190+
191+
with open(dest_file, "w", newline="", encoding="utf-8") as f:
188192
writer = csv.DictWriter(f, fieldnames=headers)
189193
writer.writeheader()
190194
writer.writerows(rows)
191-
192-
return path

sharktuner/tests/candidate_ordering_test.py

Lines changed: 152 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,12 @@ def test_reorder_assignments(
162162
def test_init_tuning_records(
163163
sample_knobs: list[Optional[common.KnobAssignment]],
164164
) -> None:
165-
sorted_order = [2, 0, 1]
166-
tuning_records = candidate_ordering.init_tuning_records(sample_knobs, sorted_order)
167-
168-
expected: list[candidate_ordering.TuningRecord] = [
169-
candidate_ordering.TuningRecord(
170-
gen_id=0, candidate_id=0, to_compile=True, to_benchmark=True
171-
)
172-
]
165+
tr0 = candidate_ordering.TuningRecord(
166+
gen_id=0,
167+
candidate_id=0,
168+
to_compile=True,
169+
to_benchmark=True,
170+
)
173171
tr1 = candidate_ordering.TuningRecord(
174172
gen_id=2,
175173
candidate_id=1,
@@ -185,6 +183,151 @@ def test_init_tuning_records(
185183
candidate_id=3,
186184
knob=sample_knobs[1],
187185
)
188-
expected += [tr1, tr2, tr3]
186+
sorted_order = [2, 0, 1]
187+
tuning_records = candidate_ordering.init_tuning_records(sample_knobs, sorted_order)
188+
189+
expected = [tr0, tr1, tr2, tr3]
189190

190191
assert tuning_records == expected
192+
193+
194+
def test_flatten_records(
195+
sample_knobs: list[Optional[common.KnobAssignment]],
196+
):
197+
tr0 = candidate_ordering.TuningRecord(
198+
gen_id=0,
199+
candidate_id=0,
200+
to_compile=True,
201+
to_benchmark=True,
202+
)
203+
tr1 = candidate_ordering.TuningRecord(
204+
gen_id=2,
205+
candidate_id=1,
206+
knob=sample_knobs[2],
207+
to_compile=True,
208+
benchmark_device_id="hip://2",
209+
benchmark_queue_position=1,
210+
baseline_benchmark_time_us=123.4,
211+
benchmark_speedup=1.5,
212+
)
213+
tr2 = candidate_ordering.TuningRecord(
214+
gen_id=1,
215+
candidate_id=2,
216+
knob=sample_knobs[1],
217+
to_benchmark=True,
218+
benchmark_time_us=153.56,
219+
)
220+
sample_tuning_records = [tr0, tr1, tr2]
221+
222+
headers, rows = candidate_ordering.flatten_records(sample_tuning_records)
223+
224+
expected_headers = [
225+
"gen_id",
226+
"candidate_id",
227+
"to_compile",
228+
"compile_status",
229+
"to_benchmark",
230+
"benchmark_device_id",
231+
"benchmark_queue_position",
232+
"benchmark_status",
233+
"baseline_benchmark_time_us",
234+
"benchmark_time_us",
235+
"benchmark_speedup",
236+
"benchmark_rank_order",
237+
"knob.M",
238+
"knob.N",
239+
"knob.K",
240+
"knob.tile_m",
241+
"knob.tile_n",
242+
"knob.tile_k",
243+
"knob.wg_x",
244+
"knob.wg_y",
245+
"knob.wg_z",
246+
"knob.subgroup_m_cnt",
247+
"knob.subgroup_n_cnt",
248+
"knob.intrinsic_mn",
249+
"knob.intrinsic_k",
250+
"knob.subgroup_m",
251+
"knob.subgroup_n",
252+
"knob.subgroup_k",
253+
]
254+
assert headers == expected_headers
255+
256+
expected_rows = [
257+
{
258+
"baseline_benchmark_time_us": None,
259+
"benchmark_device_id": None,
260+
"benchmark_queue_position": None,
261+
"benchmark_rank_order": None,
262+
"benchmark_speedup": None,
263+
"benchmark_status": False,
264+
"benchmark_time_us": None,
265+
"candidate_id": 0,
266+
"compile_status": False,
267+
"gen_id": 0,
268+
"knob": None,
269+
"to_benchmark": True,
270+
"to_compile": True,
271+
},
272+
{
273+
"baseline_benchmark_time_us": 123.4,
274+
"benchmark_device_id": "hip://2",
275+
"benchmark_queue_position": 1,
276+
"benchmark_rank_order": None,
277+
"benchmark_speedup": 1.5,
278+
"benchmark_status": False,
279+
"benchmark_time_us": None,
280+
"candidate_id": 1,
281+
"compile_status": False,
282+
"gen_id": 2,
283+
"knob.K": 1280,
284+
"knob.M": 2048,
285+
"knob.N": 10240,
286+
"knob.intrinsic_k": 16,
287+
"knob.intrinsic_mn": 16,
288+
"knob.subgroup_k": 0,
289+
"knob.subgroup_m": 0,
290+
"knob.subgroup_m_cnt": 2,
291+
"knob.subgroup_n": 0,
292+
"knob.subgroup_n_cnt": 4,
293+
"knob.tile_k": 16,
294+
"knob.tile_m": 64,
295+
"knob.tile_n": 256,
296+
"knob.wg_x": 256,
297+
"knob.wg_y": 2,
298+
"knob.wg_z": 1,
299+
"to_benchmark": False,
300+
"to_compile": True,
301+
},
302+
{
303+
"baseline_benchmark_time_us": None,
304+
"benchmark_device_id": None,
305+
"benchmark_queue_position": None,
306+
"benchmark_rank_order": None,
307+
"benchmark_speedup": None,
308+
"benchmark_status": False,
309+
"benchmark_time_us": 153.56,
310+
"candidate_id": 2,
311+
"compile_status": False,
312+
"gen_id": 1,
313+
"knob.K": 1280,
314+
"knob.M": 2048,
315+
"knob.N": 10240,
316+
"knob.intrinsic_k": 16,
317+
"knob.intrinsic_mn": 16,
318+
"knob.subgroup_k": 0,
319+
"knob.subgroup_m": 0,
320+
"knob.subgroup_m_cnt": 1,
321+
"knob.subgroup_n": 0,
322+
"knob.subgroup_n_cnt": 5,
323+
"knob.tile_k": 80,
324+
"knob.tile_m": 64,
325+
"knob.tile_n": 320,
326+
"knob.wg_x": 320,
327+
"knob.wg_y": 1,
328+
"knob.wg_z": 1,
329+
"to_benchmark": True,
330+
"to_compile": False,
331+
},
332+
]
333+
assert rows == expected_rows

0 commit comments

Comments
 (0)