@@ -136,6 +136,7 @@ def fit( # pylint: disable=W0221
136136 batch_size : int = 128 ,
137137 nb_epochs : int = 10 ,
138138 training_mode : bool = True ,
139+ drop_last : bool = False ,
139140 scheduler : Optional [Any ] = None ,
140141 ** kwargs ,
141142 ) -> None :
@@ -148,6 +149,9 @@ def fit( # pylint: disable=W0221
148149 :param batch_size: Size of batches.
149150 :param nb_epochs: Number of epochs to use for training.
150151 :param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
152+ :param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by
153+ the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
154+ the last batch will be smaller. (default: ``False``)
151155 :param scheduler: Learning rate scheduler to run at the start of every epoch.
152156 :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
153157 and providing it takes no effect.
@@ -168,7 +172,11 @@ def fit( # pylint: disable=W0221
168172 # Check label shape
169173 y_preprocessed = self .reduce_labels (y_preprocessed )
170174
171- num_batch = int (np .ceil (len (x_preprocessed ) / float (batch_size )))
175+ num_batch = len (x_preprocessed ) / float (batch_size )
176+ if drop_last :
177+ num_batch = int (np .floor (num_batch ))
178+ else :
179+ num_batch = int (np .ceil (num_batch ))
172180 ind = np .arange (len (x_preprocessed ))
173181 std = torch .tensor (self .scale ).to (self ._device )
174182
@@ -217,6 +225,9 @@ def fit( # pylint: disable=W0221
217225
218226 self ._optimizer .step ()
219227
228+ if scheduler is not None :
229+ scheduler .step ()
230+
220231 def predict (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> np .ndarray : # type: ignore
221232 """
222233 Perform prediction of the given classifier for a batch of inputs, taking an expectation over transformations.
0 commit comments