Commit 6d98749

Merge pull request #1678 from GiulioZizzo/randomised_smoothing_fix: Randomised smoothing fix

2 parents 67fa652 + f6b4f91

File tree

4 files changed: +120 additions, -24 deletions

art/estimators/certification/randomized_smoothing/numpy.py

Lines changed: 15 additions & 1 deletion
@@ -25,11 +25,14 @@
 import logging
 from typing import List, Union, TYPE_CHECKING, Tuple

+import warnings
 import numpy as np

+from art.config import ART_NUMPY_DTYPE
 from art.estimators.estimator import BaseEstimator, LossGradientsMixin, NeuralNetworkMixin
 from art.estimators.certification.randomized_smoothing.randomized_smoothing import RandomizedSmoothingMixin
 from art.estimators.classification import ClassifierMixin, ClassGradientsMixin
+from art.defences.preprocessor.gaussian_augmentation import GaussianAugmentation

 if TYPE_CHECKING:
     from art.utils import CLASSIFIER_NEURALNETWORK_TYPE
@@ -69,6 +72,12 @@ def __init__(
         :param scale: Standard deviation of Gaussian noise added.
         :param alpha: The failure probability of smoothing
         """
+        if classifier.preprocessing_defences is not None:
+            warnings.warn(
+                "\n With the current backend Gaussian noise will be added by Randomized Smoothing "
+                "BEFORE the application of preprocessing defences. Please ensure this conforms to your use case.\n"
+            )
+
         super().__init__(
             model=classifier.model,
             channels_first=classifier.channels_first,
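The warning above only fires when the wrapped classifier carries preprocessing defences. A minimal, self-contained sketch of that guard, using a hypothetical stand-in class rather than ART's real constructor:

    import warnings

    class _StubClassifier:  # hypothetical stand-in for an ART classifier
        preprocessing_defences = ["some_defence"]

    classifier = _StubClassifier()
    # Mirrors the check added to the NumPy backend's __init__ above
    if classifier.preprocessing_defences is not None:
        warnings.warn(
            "\n With the current backend Gaussian noise will be added by Randomized Smoothing "
            "BEFORE the application of preprocessing defences. Please ensure this conforms to your use case.\n"
        )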
@@ -112,7 +121,12 @@ def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epoc
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
             and providing it takes no effect.
         """
-        return self.classifier.fit(x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)
+
+        g_a = GaussianAugmentation(sigma=self.scale, augmentation=False)
+        for _ in range(nb_epochs):
+            x_rs, _ = g_a(x)
+            x_rs = x_rs.astype(ART_NUMPY_DTYPE)
+            self.classifier.fit(x_rs, y, batch_size=batch_size, nb_epochs=1, **kwargs)

     def loss_gradient(  # pylint: disable=W0221
         self, x: np.ndarray, y: np.ndarray, training_mode: bool = False, **kwargs
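Why `_fit_classifier` now loops per epoch: with `augmentation=False`, `GaussianAugmentation` returns a same-sized noisy copy of the input, so calling it inside the epoch loop gives the model a fresh noise draw every epoch instead of one fixed noisy dataset. A minimal sketch (toy shapes and sigma value are assumptions):

    import numpy as np
    from art.defences.preprocessor.gaussian_augmentation import GaussianAugmentation

    x = np.zeros((4, 8), dtype=np.float32)  # hypothetical toy inputs
    g_a = GaussianAugmentation(sigma=0.25, augmentation=False)

    x_epoch_1, _ = g_a(x)  # noise draw for epoch 1
    x_epoch_2, _ = g_a(x)  # independent noise draw for epoch 2
    assert x_epoch_1.shape == x.shape
    assert not np.allclose(x_epoch_1, x_epoch_2)  # different draws each call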

art/estimators/certification/randomized_smoothing/pytorch.py

Lines changed: 65 additions & 9 deletions
@@ -25,11 +25,15 @@
 import logging
 from typing import List, Optional, Tuple, Union, TYPE_CHECKING

+import warnings
+import random
+from tqdm import tqdm
 import numpy as np

 from art.config import ART_NUMPY_DTYPE
 from art.estimators.classification.pytorch import PyTorchClassifier
 from art.estimators.certification.randomized_smoothing.randomized_smoothing import RandomizedSmoothingMixin
+from art.utils import check_and_transform_label_format

 if TYPE_CHECKING:
     # pylint: disable=C0412
@@ -94,6 +98,12 @@ def __init__(
         :param scale: Standard deviation of Gaussian noise added.
         :param alpha: The failure probability of smoothing.
         """
+        if preprocessing_defences is not None:
+            warnings.warn(
+                "\n With the current backend (Pytorch) Gaussian noise will be added by Randomized Smoothing "
+                "AFTER the application of preprocessing defences. Please ensure this conforms to your use case.\n"
+            )
+
         super().__init__(
             model=model,
             loss=loss,
@@ -126,26 +136,72 @@ def fit(  # pylint: disable=W0221
         batch_size: int = 128,
         nb_epochs: int = 10,
         training_mode: bool = True,
-        **kwargs
-    ):
+        **kwargs,
+    ) -> None:
         """
         Fit the classifier on the training set `(x, y)`.

         :param x: Training data.
-        :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or indices of shape
-                  (nb_samples,).
-        :param batch_size: Batch size.
-        :key nb_epochs: Number of epochs to use for training
+        :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
+                  shape (nb_samples,).
+        :param batch_size: Size of batches.
+        :param nb_epochs: Number of epochs to use for training.
+        :param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
             and providing it takes no effect.
-        :type kwargs: `dict`
-        :return: `None`
         """
+        import torch  # lgtm [py/repeated-import]

         # Set model mode
         self._model.train(mode=training_mode)

-        RandomizedSmoothingMixin.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)
+        if self._optimizer is None:  # pragma: no cover
+            raise ValueError("An optimizer is needed to train the model, but none was provided.")
+
+        y = check_and_transform_label_format(y, self.nb_classes)
+
+        # Apply preprocessing
+        x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
+
+        # Check label shape
+        y_preprocessed = self.reduce_labels(y_preprocessed)
+
+        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
+        ind = np.arange(len(x_preprocessed))
+        std = torch.tensor(self.scale).to(self._device)
+        # Start training
+        for _ in tqdm(range(nb_epochs)):
+            # Shuffle the examples
+            random.shuffle(ind)
+
+            # Train for one epoch
+            for m in range(num_batch):
+                i_batch = torch.from_numpy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
+                o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
+
+                # Add random noise for randomized smoothing
+                i_batch = i_batch + torch.randn_like(i_batch, device=self._device) * std
+
+                # Zero the parameter gradients
+                self._optimizer.zero_grad()
+
+                # Perform prediction
+                model_outputs = self._model(i_batch)
+
+                # Form the loss function
+                loss = self._loss(model_outputs[-1], o_batch)  # lgtm [py/call-to-non-callable]
+
+                # Do training
+                if self._use_amp:  # pragma: no cover
+                    from apex import amp  # pylint: disable=E0611
+
+                    with amp.scale_loss(loss, self._optimizer) as scaled_loss:
+                        scaled_loss.backward()
+
+                else:
+                    loss.backward()
+
+                self._optimizer.step()

     def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:  # type: ignore
         """

art/estimators/certification/randomized_smoothing/randomized_smoothing.py

Lines changed: 1 addition & 4 deletions
@@ -31,7 +31,6 @@
 from tqdm.auto import tqdm

 from art.config import ART_NUMPY_DTYPE
-from art.defences.preprocessor.gaussian_augmentation import GaussianAugmentation

 logger = logging.getLogger(__name__)

@@ -141,9 +140,7 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
         :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
             and providing it takes no effect.
         """
-        g_a = GaussianAugmentation(sigma=self.scale, augmentation=False)
-        x_rs, _ = g_a(x)
-        self._fit_classifier(x_rs, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)
+        self._fit_classifier(x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)

     def certify(self, x: np.ndarray, n: int, batch_size: int = 32) -> Tuple[np.ndarray, np.ndarray]:
         """

art/estimators/certification/randomized_smoothing/tensorflow.py

Lines changed: 39 additions & 10 deletions
@@ -25,10 +25,13 @@
 import logging
 from typing import Callable, List, Optional, Tuple, Union, TYPE_CHECKING

+import warnings
+from tqdm import tqdm
 import numpy as np

 from art.estimators.classification.tensorflow import TensorFlowV2Classifier
 from art.estimators.certification.randomized_smoothing.randomized_smoothing import RandomizedSmoothingMixin
+from art.utils import check_and_transform_label_format

 if TYPE_CHECKING:
     # pylint: disable=C0412
@@ -91,6 +94,12 @@ def __init__(
         :param scale: Standard deviation of Gaussian noise added.
         :param alpha: The failure probability of smoothing.
         """
+        if preprocessing_defences is not None:
+            warnings.warn(
+                "\nWith the current backend (Tensorflow), Gaussian noise will be added by Randomized Smoothing "
+                "AFTER the application of preprocessing defences. Please ensure this conforms to your use case.\n"
+            )
+
         super().__init__(
             model=model,
             nb_classes=nb_classes,
@@ -113,21 +122,41 @@ def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: boo
     def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs: int, **kwargs) -> None:
         return TensorFlowV2Classifier.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)

-    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, **kwargs):
+    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, **kwargs) -> None:
         """
         Fit the classifier on the training set `(x, y)`.

         :param x: Training data.
-        :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or indices of shape
-                  (nb_samples,).
-        :param batch_size: Batch size.
-        :key nb_epochs: Number of epochs to use for training
-        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
-            and providing it takes no effect.
-        :type kwargs: `dict`
-        :return: `None`
+        :param y: Labels, one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
+                  shape (nb_samples,).
+        :param batch_size: Size of batches.
+        :param nb_epochs: Number of epochs to use for training.
+        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
+                       TensorFlow and providing it takes no effect.
         """
-        RandomizedSmoothingMixin.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)
+        import tensorflow as tf  # lgtm [py/repeated-import]
+
+        if self._train_step is None:  # pragma: no cover
+            raise TypeError(
+                "The training function `train_step` is required for fitting a model but it has not been defined."
+            )
+
+        y = check_and_transform_label_format(y, self.nb_classes)
+
+        # Apply preprocessing
+        x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
+
+        # Check label shape
+        if self._reduce_labels:
+            y_preprocessed = np.argmax(y_preprocessed, axis=1)
+
+        train_ds = tf.data.Dataset.from_tensor_slices((x_preprocessed, y_preprocessed)).shuffle(10000).batch(batch_size)
+
+        for _ in tqdm(range(nb_epochs)):
+            for images, labels in train_ds:
+                # Add random noise for randomized smoothing
+                images += tf.random.normal(shape=images.shape, mean=0.0, stddev=self.scale)
+                self._train_step(self.model, images, labels)

     def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:  # type: ignore
         """
