@@ -147,3 +147,89 @@ def predict(
     def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs: int, **kwargs) -> None:
         x = x.astype(ART_NUMPY_DTYPE)
         return PyTorchClassifier.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)
+
+    def fit(  # pylint: disable=W0221
+        self,
+        x: np.ndarray,
+        y: np.ndarray,
+        batch_size: int = 128,
+        nb_epochs: int = 10,
+        training_mode: bool = True,
+        scheduler: Optional[Any] = None,
+        **kwargs,
+    ) -> None:
+        """
+        Fit the classifier on the training set `(x, y)`.
+        :param x: Training data.
+        :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
+                  shape (nb_samples,).
+        :param batch_size: Size of batches.
+        :param nb_epochs: Number of epochs to use for training.
+        :param training_mode: `True` to set the model to training mode, `False` to set it to evaluation mode.
+        :param scheduler: Learning rate scheduler to run at the end of every epoch.
+        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
+                       and providing it has no effect.
+        """
+        import torch  # lgtm [py/repeated-import]
+
+        # Set model mode
+        self._model.train(mode=training_mode)
+
+        if self._optimizer is None:  # pragma: no cover
+            raise ValueError("An optimizer is needed to train the model, but none is provided.")
+
+        y = check_and_transform_label_format(y, nb_classes=self.nb_classes)
+
+        # Apply preprocessing
+        x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
+
+        # Check label shape
+        y_preprocessed = self.reduce_labels(y_preprocessed)
+
+        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
+        ind = np.arange(len(x_preprocessed))
+
+        # Start training
+        for _ in tqdm(range(nb_epochs)):
+            # Shuffle the examples
+            random.shuffle(ind)
+
+            # Train for one epoch
+            for m in range(num_batch):
+                i_batch = np.copy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]])
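+                # Ablate the batch with the configured ablator so the model is trained on ablated inputs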
+                i_batch = self.ablator.forward(i_batch)
+
+                i_batch = torch.from_numpy(i_batch).to(self._device)
+                o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
+
+                # Zero the parameter gradients
+                self._optimizer.zero_grad()
+
+                # Perform prediction
+                try:
+                    model_outputs = self._model(i_batch)
+                except ValueError as err:
+                    if "Expected more than 1 value per channel when training" in str(err):
+                        logger.exception(
+                            "Try dropping the last incomplete batch by setting drop_last=True in "
+                            "method PyTorchClassifier.fit."
+                        )
+                    raise err
+
+                # Form the loss function
+                loss = self._loss(model_outputs[-1], o_batch)  # lgtm [py/call-to-non-callable]
+
+                # Do training
+                if self._use_amp:  # pragma: no cover
+                    from apex import amp  # pylint: disable=E0611
+
+                    with amp.scale_loss(loss, self._optimizer) as scaled_loss:
+                        scaled_loss.backward()
+
+                else:
+                    loss.backward()
+
+                self._optimizer.step()
+
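+            # Step the learning-rate scheduler once per epoch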
+            if scheduler is not None:
+                scheduler.step()
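
For context, here is a minimal usage sketch of the `fit` override added above. It assumes the surrounding class is an ART PyTorch estimator that wraps a model, loss, and optimizer (the instance is called `classifier` below purely for illustration), that `x_train`/`y_train` are NumPy arrays, and that the scheduler is a standard `torch.optim.lr_scheduler` object which `fit` steps once per epoch:

    import torch

    # Hypothetical names: `model` is the wrapped torch.nn.Module, `classifier` is the
    # ART estimator defining the fit() shown in this diff.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    # One-hot or index labels are both accepted; each batch is ablated internally before
    # the forward pass, and scheduler.step() runs after every epoch.
    classifier.fit(x_train, y_train, batch_size=128, nb_epochs=10, scheduler=scheduler)

With `StepLR(step_size=5, gamma=0.1)` the learning rate is multiplied by 0.1 after every 5 epochs of training.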