Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit ebf9ff3

Browse files
authored
bug fix for torch ONNX export with >1 input (#284)
1 parent dce81ea commit ebf9ff3

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/sparseml/pytorch/utils/exporter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@ def export_onnx(
185185
create_parent_dirs(onnx_path)
186186

187187
with torch.no_grad():
188-
out = tensors_module_forward(sample_batch, self._module)
188+
out = tensors_module_forward(
189+
sample_batch, self._module, check_feat_lab_inp=False
190+
)
189191

190192
input_names = None
191193
if isinstance(sample_batch, Tensor):
@@ -194,6 +196,8 @@ def export_onnx(
194196
input_names = [
195197
"input_{}".format(index) for index, _ in enumerate(iter(sample_batch))
196198
]
199+
if isinstance(sample_batch, List):
200+
sample_batch = tuple(sample_batch) # torch.onnx.export requires tuple
197201

198202
output_names = None
199203
if isinstance(out, Tensor):

0 commit comments

Comments
 (0)