Skip to content

Commit ecea646

Browse files
committed
Small fix
1 parent 6692512 commit ecea646

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

sharktuner/sharktuner/candidate_ordering.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def init_tuning_records(
149149

150150
def export_record_to_csv(
151151
objects: list[TuningRecord], dest_dir: Path, filename: str = "export.csv"
152-
) -> Path:
152+
) -> Optional[Path]:
153153
if not objects:
154154
return None
155155

@@ -176,7 +176,7 @@ def export_record_to_csv(
176176
headers.append(k)
177177
rows.append(row)
178178

179-
path = os.path.join(dest_dir, filename)
179+
path = Path(os.path.join(dest_dir, filename))
180180
with open(path, "w", newline="", encoding="utf-8") as f:
181181
writer = csv.DictWriter(f, fieldnames=headers)
182182
writer.writeheader()

sharktuner/sharktuner/libtuner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,8 +1343,10 @@ def benchmark(
13431343
for i, handler_res in enumerate(all_candidates_with_speedup, start=1):
13441344
benchmark_res, speedup = handler_res
13451345
cid, _, device_id = benchmark_res
1346-
bas = baseline_handler.get_average_result_us(device_id)
1347-
tuning_client.tuning_records[cid].baseline_benchmark_time_us = round(bas, 2)
1346+
baseline_res = baseline_handler.get_average_result_us(device_id)
1347+
tuning_client.tuning_records[cid].baseline_benchmark_time_us = (
1348+
round(baseline_res, 2) if baseline_res else None
1349+
)
13481350
tuning_client.tuning_records[cid].benchmark_speedup = round(speedup, 5)
13491351
tuning_client.tuning_records[cid].benchmark_rank_order = i
13501352

sharktuner/tests/candidate_ordering_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_reorder_assignments(
138138
== expected_order
139139
)
140140

141-
knobs = [None, None, None]
141+
knobs: list[Optional[common.KnobAssignment]] = [None, None, None]
142142
assert (
143143
candidate_ordering.reorder_assignments(
144144
target_info=target_info,
@@ -164,7 +164,6 @@ def test_init_tuning_records(
164164
) -> None:
165165
sorted_order = [2, 0, 1]
166166
tuning_records = candidate_ordering.init_tuning_records(sample_knobs, sorted_order)
167-
expected: list[candidate_ordering.TuningRecord] = []
168167

169168
expected: list[candidate_ordering.TuningRecord] = [
170169
candidate_ordering.TuningRecord(

0 commit comments

Comments
 (0)