Commit 58c0a82 (1 parent: c00332d)
Author: Beat Buesser

    Add drop_last option to method fit of PyTorchClassifier

    Signed-off-by: Beat Buesser <[email protected]>
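The new flag amounts to a floor-versus-ceiling choice in the batch count. A standalone sketch of that arithmetic (the sample and batch sizes are illustrative; it mirrors the computation in the diff below):

import numpy as np

# 1000 samples at batch_size 128 -> 1000 / 128 = 7.8125 batches
n_samples, batch_size = 1000, 128

# drop_last=True: floor -> 7 batches; the trailing 104 samples are skipped each epoch
print(int(np.floor(n_samples / float(batch_size))))  # 7
# drop_last=False: ceil -> 8 batches; the last batch holds the remaining 104 samples
print(int(np.ceil(n_samples / float(batch_size))))   # 8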

1 file changed: 9 additions, 1 deletion

art/estimators/classification/pytorch.py

@@ -362,6 +362,7 @@ def fit( # pylint: disable=W0221
         batch_size: int = 128,
         nb_epochs: int = 10,
         training_mode: bool = True,
+        drop_last: bool = False,
         **kwargs,
     ) -> None:
         """
@@ -373,6 +374,9 @@ def fit( # pylint: disable=W0221
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
         :param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
+        :param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by
+                          the batch size. If ``False`` and the size of the dataset is not divisible by the batch size,
+                          then the last batch will be smaller. (default: ``False``)
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
                        PyTorch and providing it has no effect.
         """
@@ -392,7 +396,11 @@ def fit( # pylint: disable=W0221
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)
 
-        num_batch = int(np.floor(len(x_preprocessed) / float(batch_size)))
+        num_batch = len(x_preprocessed) / float(batch_size)
+        if drop_last:
+            num_batch = int(np.floor(num_batch))
+        else:
+            num_batch = int(np.ceil(num_batch))
         ind = np.arange(len(x_preprocessed))
 
         x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
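A minimal usage sketch of the new option, assuming only the `drop_last` keyword added above; the model, data, and hyperparameters are illustrative and not taken from the commit:

import numpy as np
import torch.nn as nn
import torch.optim as optim

from art.estimators.classification import PyTorchClassifier

# Toy model; architecture and input shape are illustrative only.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
classifier = PyTorchClassifier(
    model=model,
    loss=nn.CrossEntropyLoss(),
    optimizer=optim.Adam(model.parameters(), lr=1e-3),
    input_shape=(1, 28, 28),
    nb_classes=10,
)

# Random stand-in data: 1000 samples, so 1000 % 128 = 104 samples are left over.
x_train = np.random.rand(1000, 1, 28, 28).astype(np.float32)
y_train = np.random.randint(0, 10, size=1000)

# drop_last=True trains on 7 full batches per epoch and skips the trailing 104
# samples; the default drop_last=False keeps them as a smaller eighth batch.
classifier.fit(x_train, y_train, batch_size=128, nb_epochs=10, drop_last=True)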
