@@ -147,82 +147,3 @@ def predict(
     def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs: int, **kwargs) -> None:
         x = x.astype(ART_NUMPY_DTYPE)
         return PyTorchClassifier.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)
-
-    def fit(  # pylint: disable=W0221
-        self,
-        x: np.ndarray,
-        y: np.ndarray,
-        batch_size: int = 128,
-        nb_epochs: int = 10,
-        training_mode: bool = True,
-        scheduler: Optional[Any] = None,
-        **kwargs,
-    ) -> None:
-        """
-        Fit the classifier on the training set `(x, y)`.
-
-        :param x: Training data.
-        :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
-                  shape (nb_samples,).
-        :param batch_size: Size of batches.
-        :param nb_epochs: Number of epochs to use for training.
-        :param training_mode: `True` to set the model to training mode, `False` to set it to evaluation mode.
-        :param scheduler: Learning rate scheduler to run at the start of every epoch.
-        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
-                       PyTorch and providing it has no effect.
-        """
-        import torch  # lgtm [py/repeated-import]
-
-        # Set model mode
-        self._model.train(mode=training_mode)
-
-        if self._optimizer is None:  # pragma: no cover
-            raise ValueError("An optimizer is needed to train the model, but none was provided.")
-
-        y = check_and_transform_label_format(y, nb_classes=self.nb_classes)
-
-        # Apply preprocessing
-        x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
-
-        # Check label shape
-        y_preprocessed = self.reduce_labels(y_preprocessed)
-
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
-        ind = np.arange(len(x_preprocessed))
-
-        # Start training
-        for _ in tqdm(range(nb_epochs)):
-            # Shuffle the examples
-            random.shuffle(ind)
-
-            # Train for one epoch
-            for m in range(num_batch):
-                i_batch = np.copy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]])
-                i_batch = self.ablator.forward(i_batch)
-
-                i_batch = torch.from_numpy(i_batch).to(self._device)
-                o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
-
-                # Zero the parameter gradients
-                self._optimizer.zero_grad()
-
-                # Perform prediction
-                model_outputs = self._model(i_batch)
-
-                # Form the loss function
-                loss = self._loss(model_outputs[-1], o_batch)  # lgtm [py/call-to-non-callable]
-
-                # Do training
-                if self._use_amp:  # pragma: no cover
-                    from apex import amp  # pylint: disable=E0611
-
-                    with amp.scale_loss(loss, self._optimizer) as scaled_loss:
-                        scaled_loss.backward()
-
-                else:
-                    loss.backward()
-
-                self._optimizer.step()
-
-            if scheduler is not None:
-                scheduler.step()
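
For reference, the removed `fit` override implemented a per-batch ablation training loop: shuffle, ablate each batch via `self.ablator.forward`, run the forward/backward pass, and step an optional scheduler once per epoch. A minimal sketch of that pattern is below; it is not the ART implementation, and every name in it (`SimpleAblator`, `train_one_epoch`, the toy model and shapes in the usage comments) is an illustrative stand-in, assuming index labels and a standard `torch.nn` model, loss, and optimizer.

```python
# Sketch of the ablation-style training loop the removed fit() performed.
# All names here are illustrative stand-ins, not part of the ART API.
import numpy as np
import torch


class SimpleAblator:
    """Toy column ablator: keeps a narrow band of columns and zeroes the rest."""

    def __init__(self, width: int = 8):
        self.width = width

    def forward(self, x: np.ndarray) -> np.ndarray:
        start = np.random.randint(0, x.shape[-1])
        cols = [(start + i) % x.shape[-1] for i in range(self.width)]
        mask = np.zeros_like(x)
        mask[..., cols] = 1.0
        return x * mask


def train_one_epoch(model, loss_fn, optimizer, ablator, x, y, batch_size=128, device="cpu"):
    model.train()
    ind = np.arange(len(x))
    np.random.shuffle(ind)
    for m in range(int(np.ceil(len(x) / batch_size))):
        batch_idx = ind[m * batch_size : (m + 1) * batch_size]
        i_batch = ablator.forward(np.copy(x[batch_idx]))  # ablate before the forward pass
        i_batch = torch.from_numpy(i_batch).float().to(device)
        o_batch = torch.from_numpy(y[batch_idx]).long().to(device)  # assumes index labels
        optimizer.zero_grad()
        loss = loss_fn(model(i_batch), o_batch)
        loss.backward()
        optimizer.step()


# Usage sketch (hypothetical shapes); the scheduler is stepped once per epoch,
# outside the batch loop, matching the removed override.
# model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(32 * 32 * 3, 10))
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
# for _ in range(nb_epochs):
#     train_one_epoch(model, torch.nn.CrossEntropyLoss(), optimizer, SimpleAblator(), x_train, y_train)
#     scheduler.step()
```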