We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent de1e9d2 commit 4a32eeaCopy full SHA for 4a32eea
art/estimators/classification/pytorch.py
@@ -420,7 +420,15 @@ def fit( # pylint: disable=W0221
420
self._optimizer.zero_grad()
421
422
# Perform prediction
423
- model_outputs = self._model(i_batch)
+ try:
424
+ model_outputs = self._model(i_batch)
425
+ except ValueError as e:
426
+ if "Expected more than 1 value per channel when training" in str(e):
427
+ logger.exception(
428
+ "Try dropping the last incomplete batch by setting drop_last=True in "
429
+ "method PyTorchClassifier.fit."
430
+ )
431
+ raise e
432
433
# Form the loss function
434
loss = self._loss(model_outputs[-1], o_batch) # lgtm [py/call-to-non-callable]
0 commit comments