Skip to content

Commit 4a32eea

Browse files
author
Beat Buesser
committed
Add try-except
Signed-off-by: Beat Buesser <[email protected]>
1 parent de1e9d2 commit 4a32eea

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

art/estimators/classification/pytorch.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,15 @@ def fit( # pylint: disable=W0221
420420
self._optimizer.zero_grad()
421421

422422
# Perform prediction
423-
model_outputs = self._model(i_batch)
423+
try:
424+
model_outputs = self._model(i_batch)
425+
except ValueError as e:
426+
if "Expected more than 1 value per channel when training" in str(e):
427+
logger.exception(
428+
"Try dropping the last incomplete batch by setting drop_last=True in "
429+
"method PyTorchClassifier.fit."
430+
)
431+
raise e
424432

425433
# Form the loss function
426434
loss = self._loss(model_outputs[-1], o_batch) # lgtm [py/call-to-non-callable]

0 commit comments

Comments
 (0)