Skip to content

Commit 72b5c0a

Browse files
committed
Better CSV generation
1 parent d524e10 commit 72b5c0a

File tree

1 file changed

+11
-4
lines changed
  • models/turbine_models/custom_models/torchbench

1 file changed

+11
-4
lines changed

models/turbine_models/custom_models/torchbench/export.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -294,12 +294,19 @@ def run_benchmark(device, vmfb_path, weights_path, example_args, model_id, csv_p
294294
results, iter_latency = _run_iter(mod_runner, inputs)
295295
iter_latencies.append(iter_latency)
296296
avg_latency = sum(iter_latencies) / len(iter_latencies)
297-
with open(csv_path, "w") as csvfile:
298-
fieldnames = ["model", "avg_latency"]
299-
data = [{"model": model_id, "avg_latency": avg_latency}]
297+
it_per_sec = 1 / avg_latency
298+
299+
needs_header = True
300+
if os.path.exists(csv_path):
301+
needs_header = False
302+
with open(csv_path, "a") as csvfile:
303+
fieldnames = ["model", "avg_latency", "avg_iter_per_sec"]
304+
data = [{"model": model_id, "avg_latency": avg_latency, "avg_iter_per_sec": it_per_sec}]
300305
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
301-
writer.writeheader()
306+
if needs_header:
307+
writer.writeheader()
302308
writer.writerows(data)
309+
print(data)
303310

304311

305312
def torch_to_iree(iree_runner, example_args):

0 commit comments

Comments
 (0)