Skip to content

Commit 6692512

Browse files
committed
Move code to candidate_ordering
1 parent ab94d4e commit 6692512

File tree

4 files changed

+180
-146
lines changed

4 files changed

+180
-146
lines changed

sharktuner/sharktuner/candidate_ordering.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
from enum import Enum
2-
from typing import Optional, Callable
31
import random
42
import logging
3+
import os
4+
import csv
5+
from typing import Optional
6+
from dataclasses import dataclass
7+
from pathlib import Path
8+
from enum import Enum
9+
from typing import Optional, Callable
510

611
from iree.compiler.dialects import iree_gpu # type: ignore
712

@@ -104,3 +109,77 @@ def reorder_assignments(
104109
return indices
105110
case _:
106111
assert False
112+
113+
114+
@dataclass
115+
class TuningRecord:
116+
gen_id: int
117+
candidate_id: int
118+
knob: Optional[common.KnobAssignment] = None
119+
to_compile: bool = False
120+
compile_status: bool = False
121+
to_benchmark: bool = False
122+
benchmark_device_id: Optional[str] = None
123+
benchmark_queue_position: Optional[int] = None
124+
benchmark_status: bool = False
125+
baseline_benchmark_time_us: Optional[float] = None
126+
benchmark_time_us: Optional[float] = None
127+
benchmark_speedup: Optional[float] = None
128+
benchmark_rank_order: Optional[int] = None
129+
130+
131+
def init_tuning_records(
132+
knobs: list[Optional[common.KnobAssignment]], sorted_order: list[int]
133+
) -> list[TuningRecord]:
134+
tuning_records: list[TuningRecord] = []
135+
tuning_records.append(
136+
TuningRecord(gen_id=0, candidate_id=0, to_compile=True, to_benchmark=True)
137+
)
138+
139+
for can_idx, gen_idx in enumerate(sorted_order, start=1):
140+
tr = TuningRecord(
141+
gen_id=gen_idx,
142+
candidate_id=can_idx,
143+
knob=knobs[gen_idx],
144+
)
145+
tuning_records.append(tr)
146+
147+
return tuning_records
148+
149+
150+
def export_record_to_csv(
151+
objects: list[TuningRecord], dest_dir: Path, filename: str = "export.csv"
152+
) -> Path:
153+
if not objects:
154+
return None
155+
156+
rows = []
157+
headers = []
158+
159+
for obj in objects:
160+
row = {}
161+
for k, v in vars(obj).items():
162+
if hasattr(v, "__dict__"):
163+
nested = vars(v)
164+
if nested: # only if it has attrs
165+
for nk, nv in nested.items():
166+
key = f"{k}.{nk}"
167+
row[key] = nv
168+
if key not in headers:
169+
headers.append(key)
170+
else:
171+
# skip empty nested object entirely
172+
continue
173+
else:
174+
row[k] = v
175+
if k not in headers:
176+
headers.append(k)
177+
rows.append(row)
178+
179+
path = os.path.join(dest_dir, filename)
180+
with open(path, "w", newline="", encoding="utf-8") as f:
181+
writer = csv.DictWriter(f, fieldnames=headers)
182+
writer.writeheader()
183+
writer.writerows(rows)
184+
185+
return path

sharktuner/sharktuner/candidate_tuning_records.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

sharktuner/sharktuner/libtuner.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
dispatch_constraints,
5050
dispatch_parser,
5151
candidate_ordering,
52-
candidate_tuning_records,
5352
)
5453

5554

@@ -126,7 +125,7 @@ def __init__(self, tuner_context: common.TunerContext):
126125
self.tuner_context = tuner_context
127126
self.candidate_trackers: list[CandidateTracker] = []
128127
self.target_info: Optional[iree_gpu.TargetInfo] = None
129-
self.tuning_records: list[candidate_tuning_records.TuningRecord] = []
128+
self.tuning_records: list[candidate_ordering.TuningRecord] = []
130129

131130
@abstractmethod
132131
def get_iree_compile_flags(self) -> list[str]:
@@ -847,7 +846,7 @@ def generate_candidate_specs(
847846
# Total number of configs = candidates generated + baseline.
848847
assert len(config_specs) == len(solutions) + 1
849848

850-
tuning_client.tuning_records = candidate_tuning_records.init_tuning_records(
849+
tuning_client.tuning_records = candidate_ordering.init_tuning_records(
851850
knobs, sorted_order
852851
)
853852

0 commit comments

Comments
 (0)