File tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed
art/estimators/classification Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -362,6 +362,7 @@ def fit( # pylint: disable=W0221
362362 batch_size : int = 128 ,
363363 nb_epochs : int = 10 ,
364364 training_mode : bool = True ,
365+ drop_last : bool = False ,
365366 ** kwargs ,
366367 ) -> None :
367368 """
@@ -373,6 +374,9 @@ def fit( # pylint: disable=W0221
373374 :param batch_size: Size of batches.
374375 :param nb_epochs: Number of epochs to use for training.
375376 :param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode.
377+ :param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch
378+ size. If ``False`` and the size of dataset is not divisible by the batch size, then the last batch
379+ will be smaller. (default: ``False``)
376380 :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
377381 and providing it takes no effect.
378382 """
@@ -392,7 +396,11 @@ def fit( # pylint: disable=W0221
392396 # Check label shape
393397 y_preprocessed = self .reduce_labels (y_preprocessed )
394398
395- num_batch = int (np .floor (len (x_preprocessed ) / float (batch_size )))
399+ num_batch = len (x_preprocessed ) / float (batch_size )
400+ if drop_last :
401+ num_batch = int (np .floor (num_batch ))
402+ else :
403+ num_batch = int (np .ceil (num_batch ))
396404 ind = np .arange (len (x_preprocessed ))
397405
398406 x_preprocessed = torch .from_numpy (x_preprocessed ).to (self ._device )
You can’t perform that action at this time.
0 commit comments