@@ -137,7 +137,7 @@ def fit( # pylint: disable=W0221
137137 nb_epochs : int = 10 ,
138138 training_mode : bool = True ,
139139 drop_last : bool = False ,
140- scheduler : Optional [Any ] = None ,
140+ scheduler : Optional ["torch.optim.lr_scheduler._LRScheduler" ] = None ,
141141 ** kwargs ,
142142 ) -> None :
143143 """
@@ -157,6 +157,7 @@ def fit( # pylint: disable=W0221
157157 and providing it takes no effect.
158158 """
159159 import torch
160+ from torch .utils .data import TensorDataset , DataLoader
160161
161162 # Set model mode
162163 self ._model .train (mode = training_mode )
@@ -172,36 +173,28 @@ def fit( # pylint: disable=W0221
172173 # Check label shape
173174 y_preprocessed = self .reduce_labels (y_preprocessed )
174175
175- num_batch = len (x_preprocessed ) / float (batch_size )
176- if drop_last :
177- num_batch = int (np .floor (num_batch ))
178- else :
179- num_batch = int (np .ceil (num_batch ))
180- ind = np .arange (len (x_preprocessed ))
181- std = torch .tensor (self .scale ).to (self ._device )
182-
183- x_preprocessed = torch .from_numpy (x_preprocessed ).to (self ._device )
184- y_preprocessed = torch .from_numpy (y_preprocessed ).to (self ._device )
176+ # Create dataloader
177+ x_tensor = torch .from_numpy (x_preprocessed )
178+ y_tensor = torch .from_numpy (y_preprocessed )
179+ dataset = TensorDataset (x_tensor , y_tensor )
180+ dataloader = DataLoader (dataset = dataset , batch_size = batch_size , shuffle = True , drop_last = drop_last )
185181
186182 # Start training
187183 for _ in tqdm (range (nb_epochs )):
188- # Shuffle the examples
189- random .shuffle (ind )
190-
191- # Train for one epoch
192- for m in range (num_batch ):
193- i_batch = x_preprocessed [ind [m * batch_size : (m + 1 ) * batch_size ]]
194- o_batch = y_preprocessed [ind [m * batch_size : (m + 1 ) * batch_size ]]
184+ for x_batch , y_batch in dataloader :
185+ # Move inputs to device
186+ x_batch = x_batch .to (self ._device )
187+ y_batch = y_batch .to (self ._device )
195188
196189 # Add random noise for randomized smoothing
197- i_batch = i_batch + torch .randn_like (i_batch , device = self . _device ) * std
190+ x_batch += torch .randn_like (x_batch ) * self . scale
198191
199192 # Zero the parameter gradients
200193 self ._optimizer .zero_grad ()
201194
202195 # Perform prediction
203196 try :
204- model_outputs = self ._model (i_batch )
197+ model_outputs = self ._model (x_batch )
205198 except ValueError as err :
206199 if "Expected more than 1 value per channel when training" in str (err ):
207200 logger .exception (
@@ -211,7 +204,7 @@ def fit( # pylint: disable=W0221
211204 raise err
212205
213206 # Form the loss function
214- loss = self ._loss (model_outputs [- 1 ], o_batch )
207+ loss = self ._loss (model_outputs [- 1 ], y_batch )
215208
216209 # Do training
217210 if self ._use_amp : # pragma: no cover
0 commit comments