Commit 5aefcc7

Author: Beat Buesser (committed)
Update fit methods
Signed-off-by: Beat Buesser <[email protected]>
1 parent 6e05685 commit 5aefcc7

File tree

2 files changed: +104 −4 lines


art/estimators/certification/derandomized_smoothing/pytorch.py

Lines changed: 86 additions & 0 deletions
@@ -147,3 +147,89 @@ def predict(
     def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs: int, **kwargs) -> None:
         x = x.astype(ART_NUMPY_DTYPE)
         return PyTorchClassifier.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)
+
+    def fit(  # pylint: disable=W0221
+        self,
+        x: np.ndarray,
+        y: np.ndarray,
+        batch_size: int = 128,
+        nb_epochs: int = 10,
+        training_mode: bool = True,
+        scheduler: Optional[Any] = None,
+        **kwargs,
+    ) -> None:
+        """
+        Fit the classifier on the training set `(x, y)`.
+        :param x: Training data.
+        :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
+                  shape (nb_samples,).
+        :param batch_size: Size of batches.
+        :param nb_epochs: Number of epochs to use for training.
+        :param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
+        :param scheduler: Learning rate scheduler to run at the start of every epoch.
+        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
+                       and providing it has no effect.
+        """
+        import torch  # lgtm [py/repeated-import]
+
+        # Set model mode
+        self._model.train(mode=training_mode)
+
+        if self._optimizer is None:  # pragma: no cover
+            raise ValueError("An optimizer is needed to train the model, but none is provided.")
+
+        y = check_and_transform_label_format(y, nb_classes=self.nb_classes)
+
+        # Apply preprocessing
+        x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
+
+        # Check label shape
+        y_preprocessed = self.reduce_labels(y_preprocessed)
+
+        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
+        ind = np.arange(len(x_preprocessed))
+
+        # Start training
+        for _ in tqdm(range(nb_epochs)):
+            # Shuffle the examples
+            random.shuffle(ind)
+
+            # Train for one epoch
+            for m in range(num_batch):
+                i_batch = np.copy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]])
+                i_batch = self.ablator.forward(i_batch)
+
+                i_batch = torch.from_numpy(i_batch).to(self._device)
+                o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
+
+                # Zero the parameter gradients
+                self._optimizer.zero_grad()
+
+                # Perform prediction
+                try:
+                    model_outputs = self._model(i_batch)
+                except ValueError as err:
+                    if "Expected more than 1 value per channel when training" in str(err):
+                        logger.exception(
+                            "Try dropping the last incomplete batch by setting drop_last=True in "
+                            "method PyTorchClassifier.fit."
+                        )
+                    raise err
+
+                # Form the loss function
+                loss = self._loss(model_outputs[-1], o_batch)  # lgtm [py/call-to-non-callable]
+
+                # Do training
+                if self._use_amp:  # pragma: no cover
+                    from apex import amp  # pylint: disable=E0611
+
+                    with amp.scale_loss(loss, self._optimizer) as scaled_loss:
+                        scaled_loss.backward()
+
+                else:
+                    loss.backward()
+
+                self._optimizer.step()
+
+            if scheduler is not None:
+                scheduler.step()
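
A brief usage sketch of the new `scheduler` argument introduced above. The `classifier` and `optimizer` objects are assumed to exist already (the estimator's constructor is outside this diff), and the data is random toy data; only the `fit(..., scheduler=...)` call reflects the new code, which steps the scheduler once at the end of every epoch.

# Hedged sketch: `classifier` is assumed to be an already-constructed instance of the
# estimator defined in this file, and `optimizer` the torch optimizer it wraps (passed
# to it at construction time). Data shapes and the schedule are illustrative.
import numpy as np
import torch

scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40], gamma=0.1)

x_train = np.random.rand(512, 3, 32, 32).astype(np.float32)  # toy data
y_train = np.random.randint(0, 10, size=512)

# fit() calls scheduler.step() once per epoch, so milestones are counted in epochs.
classifier.fit(x_train, y_train, batch_size=128, nb_epochs=50, scheduler=scheduler)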

art/estimators/certification/randomized_smoothing/pytorch.py

Lines changed: 18 additions & 4 deletions
@@ -23,7 +23,7 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import logging
-from typing import List, Optional, Tuple, Union, TYPE_CHECKING
+from typing import List, Optional, Tuple, Union, Any, TYPE_CHECKING
 
 import warnings
 import random
@@ -136,6 +136,7 @@ def fit(  # pylint: disable=W0221
         batch_size: int = 128,
         nb_epochs: int = 10,
         training_mode: bool = True,
+        scheduler: Optional[Any] = None,
         **kwargs,
     ) -> None:
         """
@@ -147,6 +148,7 @@ def fit(  # pylint: disable=W0221
         :param batch_size: Size of batches.
         :param nb_epochs: Number of epochs to use for training.
         :param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
+        :param scheduler: Learning rate scheduler to run at the start of every epoch.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
                        and providing it has no effect.
         """
@@ -169,15 +171,19 @@ def fit(  # pylint: disable=W0221
         num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
         ind = np.arange(len(x_preprocessed))
         std = torch.tensor(self.scale).to(self._device)
+
+        x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
+        y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
+
         # Start training
         for _ in tqdm(range(nb_epochs)):
             # Shuffle the examples
             random.shuffle(ind)
 
             # Train for one epoch
             for m in range(num_batch):
-                i_batch = torch.from_numpy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
-                o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
+                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
 
                 # Add random noise for randomized smoothing
                 i_batch = i_batch + torch.randn_like(i_batch, device=self._device) * std
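
A note on the hunk above: the change moves the full training arrays to the device once, so each mini-batch becomes an on-device slice instead of a per-batch NumPy-to-device copy. A standalone sketch of the two patterns (array shapes and batch size are illustrative):

# Both patterns yield the same batch; the second avoids a host-to-device transfer
# per batch and per epoch by transferring the whole array up front.
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
x_np = np.random.rand(1024, 3, 32, 32).astype(np.float32)
ind = np.arange(len(x_np))
np.random.shuffle(ind)

# Old pattern: convert and copy one mini-batch at a time.
i_batch_old = torch.from_numpy(x_np[ind[0:batch_size]]).to(device)

# New pattern: a single transfer, then cheap indexing on the device.
x_dev = torch.from_numpy(x_np).to(device)
i_batch_new = x_dev[ind[0:batch_size]]

assert torch.equal(i_batch_old, i_batch_new)
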
@@ -186,7 +192,15 @@ def fit(  # pylint: disable=W0221
                 self._optimizer.zero_grad()
 
                 # Perform prediction
-                model_outputs = self._model(i_batch)
+                try:
+                    model_outputs = self._model(i_batch)
+                except ValueError as err:
+                    if "Expected more than 1 value per channel when training" in str(err):
+                        logger.exception(
+                            "Try dropping the last incomplete batch by setting drop_last=True in "
+                            "method PyTorchClassifier.fit."
+                        )
+                    raise err
 
                 # Form the loss function
                 loss = self._loss(model_outputs[-1], o_batch)  # lgtm [py/call-to-non-callable]
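
The try/except added in the final hunk (the same guard added to the derandomized-smoothing fit above) targets a specific PyTorch failure mode: in training mode, batch normalization cannot compute per-channel statistics from a single sample, which is what the last incomplete batch can contain. A minimal, self-contained reproduction (layer sizes are illustrative):

import torch
import torch.nn as nn

# BatchNorm in training mode needs more than one value per channel to estimate
# batch statistics, so a batch of size 1 raises the ValueError caught above.
net = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4))
net.train()

try:
    net(torch.randn(1, 8))  # final incomplete batch containing a single sample
except ValueError as err:
    print(err)  # "Expected more than 1 value per channel when training, ..."

# Example: 513 training examples with batch_size=128 leave a last batch of one
# sample; dropping it (drop_last=True) or picking a batch size that divides the
# dataset avoids the error, as the logged message suggests.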

0 commit comments
