Skip to content

Commit 63e80e9

Browse files
committed
small fixes
1 parent 7f0d1e8 commit 63e80e9

File tree

1 file changed

+2
-1
lines changed
  • models/turbine_models/custom_models/torchbench

1 file changed

+2
-1
lines changed

models/turbine_models/custom_models/torchbench/export.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import sys
99
import gc
10+
import time
1011

1112
from iree.compiler.ir import Context
1213
from iree import runtime as ireert
@@ -281,7 +282,7 @@ def run_benchmark(device, vmfb_path, weights_path, example_args, model_id, csv_p
281282
if "rocm" in device:
282283
device = "hip" + device.split("rocm")[-1]
283284
mod_runner = vmfbRunner(device, vmfb_path, weights_path)
284-
inputs = [ireert.asdevicearray(mod_runner.config.device, i) for i in example_args]
285+
inputs = [ireert.asdevicearray(mod_runner.config.device, i.clone().detach().cpu()) for i in example_args]
285286
start = time.time()
286287
results = runner.ctx.modules.compiled_torchbench_model["main"](*inputs)
287288
latency = time.time() - start

0 commit comments

Comments
 (0)