Skip to content

Commit 6174edf

Browse files
committed
Make inputs saving work for tuples as well
1 parent 2189675 commit 6174edf

File tree

1 file changed

+6
-2
lines changed
  • models/turbine_models/custom_models/torchbench

1 file changed

+6
-2
lines changed

models/turbine_models/custom_models/torchbench/export.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,18 @@ def export_torchbench_model(
181181

182182
_, model_name, model, forward_args, _ = get_model_and_inputs(model_id, batch_size, tb_dir, tb_args)
183183

184-
for idx, i in enumerate(forward_args.values()):
185-
np.save(f"input{idx}", i.clone().detach().cpu())
186184
if dtype == torch.float16:
187185
model = model.half()
188186
model.to("cuda:0")
189187

190188
if not isinstance(forward_args, dict):
191189
forward_args = [i.type(dtype) for i in forward_args]
190+
for idx, i in enumerate(forward_args):
191+
np.save(f"{model_id}_input{idx}", i.clone().detach().cpu())
192+
else:
193+
for idx, i in enumerate(forward_args.values()):
194+
np.save(f"{model_id}_input{idx}", i.clone().detach().cpu())
195+
192196

193197
mapper = {}
194198
if (external_weights_dir is not None):

0 commit comments

Comments
 (0)