
Commit 6e05685

Author: Beat Buesser (committed)

Fix style checks

Signed-off-by: Beat Buesser <[email protected]>

1 parent 4a32eea · commit 6e05685

File tree: 2 files changed (+8, -82 lines changed)


art/estimators/certification/derandomized_smoothing/pytorch.py

Lines changed: 0 additions & 79 deletions
@@ -147,82 +147,3 @@ def predict(
     def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs: int, **kwargs) -> None:
         x = x.astype(ART_NUMPY_DTYPE)
         return PyTorchClassifier.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)
-
-    def fit(  # pylint: disable=W0221
-        self,
-        x: np.ndarray,
-        y: np.ndarray,
-        batch_size: int = 128,
-        nb_epochs: int = 10,
-        training_mode: bool = True,
-        scheduler: Optional[Any] = None,
-        **kwargs,
-    ) -> None:
-        """
-        Fit the classifier on the training set `(x, y)`.
-
-        :param x: Training data.
-        :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
-                  shape (nb_samples,).
-        :param batch_size: Size of batches.
-        :param nb_epochs: Number of epochs to use for training.
-        :param training_mode: `True` for model set to training mode and `'False` for model set to evaluation mode.
-        :param scheduler: Learning rate scheduler to run at the start of every epoch.
-        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
-                       and providing it takes no effect.
-        """
-        import torch  # lgtm [py/repeated-import]
-
-        # Set model mode
-        self._model.train(mode=training_mode)
-
-        if self._optimizer is None:  # pragma: no cover
-            raise ValueError("An optimizer is needed to train the model, but none for provided.")
-
-        y = check_and_transform_label_format(y, nb_classes=self.nb_classes)
-
-        # Apply preprocessing
-        x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
-
-        # Check label shape
-        y_preprocessed = self.reduce_labels(y_preprocessed)
-
-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
-        ind = np.arange(len(x_preprocessed))
-
-        # Start training
-        for _ in tqdm(range(nb_epochs)):
-            # Shuffle the examples
-            random.shuffle(ind)
-
-            # Train for one epoch
-            for m in range(num_batch):
-                i_batch = np.copy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]])
-                i_batch = self.ablator.forward(i_batch)
-
-                i_batch = torch.from_numpy(i_batch).to(self._device)
-                o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
-
-                # Zero the parameter gradients
-                self._optimizer.zero_grad()
-
-                # Perform prediction
-                model_outputs = self._model(i_batch)
-
-                # Form the loss function
-                loss = self._loss(model_outputs[-1], o_batch)  # lgtm [py/call-to-non-callable]
-
-                # Do training
-                if self._use_amp:  # pragma: no cover
-                    from apex import amp  # pylint: disable=E0611
-
-                    with amp.scale_loss(loss, self._optimizer) as scaled_loss:
-                        scaled_loss.backward()
-
-                else:
-                    loss.backward()
-
-                self._optimizer.step()
-
-            if scheduler is not None:
-                scheduler.step()
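For reference, the deleted override duplicated the generic training loop kept in PyTorchClassifier.fit: shuffle the data, run a forward/backward pass per batch, step the optimizer per batch, and step the learning-rate scheduler once per epoch (the ablation step is specific to the derandomized-smoothing estimator and is omitted here). A minimal, self-contained sketch of that optimizer/scheduler pattern in plain PyTorch, using an illustrative toy model and data rather than the ART code itself:

import torch
from torch import nn

model = nn.Linear(10, 2)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

x = torch.randn(64, 10)
y = torch.randint(0, 2, (64,))
batch_size = 16

for epoch in range(10):
    perm = torch.randperm(len(x))  # shuffle the examples each epoch
    for m in range(0, len(x), batch_size):
        idx = perm[m : m + batch_size]
        optimizer.zero_grad()                  # zero the parameter gradients
        loss = loss_fn(model(x[idx]), y[idx])  # forward pass and loss
        loss.backward()                        # backward pass
        optimizer.step()                       # optimizer update per batch
    if scheduler is not None:
        scheduler.step()                       # advance the LR schedule once per epoch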

art/estimators/classification/pytorch.py

Lines changed: 8 additions & 3 deletions
@@ -363,6 +363,7 @@ def fit(  # pylint: disable=W0221
         nb_epochs: int = 10,
         training_mode: bool = True,
         drop_last: bool = False,
+        scheduler: Optional[Any] = None,
         **kwargs,
     ) -> None:
         """
@@ -377,6 +378,7 @@ def fit(  # pylint: disable=W0221
         :param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by
                           the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
                           the last batch will be smaller. (default: ``False``)
+        :param scheduler: Learning rate scheduler to run at the start of every epoch.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
                        and providing it takes no effect.
         """
@@ -422,13 +424,13 @@ def fit(  # pylint: disable=W0221
                 # Perform prediction
                 try:
                     model_outputs = self._model(i_batch)
-                except ValueError as e:
-                    if "Expected more than 1 value per channel when training" in str(e):
+                except ValueError as err:
+                    if "Expected more than 1 value per channel when training" in str(err):
                         logger.exception(
                             "Try dropping the last incomplete batch by setting drop_last=True in "
                             "method PyTorchClassifier.fit."
                         )
-                    raise e
+                    raise err

                 # Form the loss function
                 loss = self._loss(model_outputs[-1], o_batch)  # lgtm [py/call-to-non-callable]
@@ -445,6 +447,9 @@ def fit(  # pylint: disable=W0221

                 self._optimizer.step()

+            if scheduler is not None:
+                scheduler.step()
+
     def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs) -> None:
         """
         Fit the classifier using the generator that yields batches as specified.
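With scheduler now a parameter of PyTorchClassifier.fit, a learning-rate scheduler can be supplied directly when training through ART. A hedged usage sketch, assuming the usual PyTorchClassifier constructor arguments; the toy model and random data below are illustrative only, not part of this commit:

import numpy as np
import torch
from torch import nn
from art.estimators.classification import PyTorchClassifier

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

classifier = PyTorchClassifier(
    model=model,
    loss=nn.CrossEntropyLoss(),
    optimizer=optimizer,
    input_shape=(1, 28, 28),
    nb_classes=10,
)

# Toy data just to exercise the call; substitute a real dataset in practice.
x_train = np.random.rand(256, 1, 28, 28).astype(np.float32)
y_train = np.random.randint(0, 10, size=256)

# fit() steps the supplied scheduler once per epoch, as added in this commit.
classifier.fit(x_train, y_train, batch_size=64, nb_epochs=3, scheduler=scheduler)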
