
Commit 520bf0b

Merge pull request #1729 from GiulioZizzo/derandomized_smoothing
Derandomized smoothing
2 parents dff2bde + cda41c5

File tree

14 files changed: +1260 -0 lines changed


art/estimators/certification/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 """
 import importlib
 from art.estimators.certification import randomized_smoothing
+from art.estimators.certification import derandomized_smoothing

 if importlib.util.find_spec("torch") is not None:
     from art.estimators.certification import deep_z
art/estimators/certification/derandomized_smoothing/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
"""
DeRandomized smoothing estimators.
"""
from art.estimators.certification.derandomized_smoothing.derandomized_smoothing import DeRandomizedSmoothingMixin
from art.estimators.certification.derandomized_smoothing.pytorch import PyTorchDeRandomizedSmoothing
from art.estimators.certification.derandomized_smoothing.tensorflow import TensorFlowV2DeRandomizedSmoothing
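The package __init__.py above exposes three public names: the framework-agnostic DeRandomizedSmoothingMixin plus PyTorch and TensorFlow v2 estimators built on it. For orientation only, a minimal construction sketch for the PyTorch variant could look as follows; the toy model and the hyperparameter values are placeholders rather than anything taken from this PR, and the full constructor signature appears in pytorch.py further down.

import torch.nn as nn
from art.estimators.certification.derandomized_smoothing import PyTorchDeRandomizedSmoothing

# Placeholder model and settings for illustration; real layer sizes depend on the
# input format produced by the ablation logic in derandomized_smoothing.py.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

classifier = PyTorchDeRandomizedSmoothing(
    model=model,
    loss=nn.CrossEntropyLoss(),
    input_shape=(1, 28, 28),
    nb_classes=10,
    ablation_type="column",  # "column" or "block"
    ablation_size=4,         # example value: width of the retained column
    threshold=0.3,           # example value: minimum score to count a prediction
    logits=True,
)

Training and certified prediction would then go through the fit and predict methods shown in pytorch.py below.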

art/estimators/certification/derandomized_smoothing/derandomized_smoothing.py

Lines changed: 454 additions & 0 deletions
Large diffs are not rendered by default.
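The 454-line derandomized_smoothing.py added here is collapsed in this view, but the code in pytorch.py below relies on it for the DeRandomizedSmoothingMixin base class, the ablation_type/ablation_size options, and the self.ablator.forward(...) call used during training. Purely as an illustrative sketch of the idea from Levine et al. (2020), and not the ART implementation itself, column ablation retains a vertical strip of ablation_size columns (wrapping around the image edge) and blanks out the rest of the image:

import numpy as np

def column_ablate(x: np.ndarray, start_col: int, ablation_size: int) -> np.ndarray:
    # x has shape (N, C, H, W); keep `ablation_size` columns starting at `start_col`,
    # wrapping around the right edge, and zero out the remaining pixels.
    # Illustrative sketch only: the actual ablator lives in derandomized_smoothing.py.
    width = x.shape[-1]
    keep = np.zeros(width, dtype=bool)
    keep[[(start_col + i) % width for i in range(ablation_size)]] = True
    return np.where(keep[None, None, None, :], x, 0.0)

In the paper (arXiv:2002.10733), the base classifier votes once per ablation position, and a prediction is certified against patch attacks when the margin between the two highest per-class vote counts exceeds twice the number of ablation positions a single patch of the given size could intersect, since each such position could at worst remove a vote from the top class and add one to a rival.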
art/estimators/certification/derandomized_smoothing/pytorch.py

Lines changed: 228 additions & 0 deletions

@@ -0,0 +1,228 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2022
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements (De)Randomized Smoothing for Certifiable Defense against Patch Attacks.

| Paper link: https://arxiv.org/abs/2002.10733
"""

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from typing import List, Optional, Tuple, Union, Any, TYPE_CHECKING
import random

import numpy as np
from tqdm import tqdm

from art.config import ART_NUMPY_DTYPE
from art.estimators.classification.pytorch import PyTorchClassifier
from art.estimators.certification.derandomized_smoothing.derandomized_smoothing import DeRandomizedSmoothingMixin
from art.utils import check_and_transform_label_format

if TYPE_CHECKING:
    # pylint: disable=C0412
    import torch

    from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE
    from art.defences.preprocessor import Preprocessor
    from art.defences.postprocessor import Postprocessor

logger = logging.getLogger(__name__)


class PyTorchDeRandomizedSmoothing(DeRandomizedSmoothingMixin, PyTorchClassifier):
    """
    Implementation of (De)Randomized Smoothing applied to classifier predictions as introduced
    in Levine et al. (2020).

    | Paper link: https://arxiv.org/abs/2002.10733
    """

    estimator_params = PyTorchClassifier.estimator_params + ["ablation_type", "ablation_size", "threshold", "logits"]

    def __init__(
        self,
        model: "torch.nn.Module",
        loss: "torch.nn.modules.loss._Loss",
        input_shape: Tuple[int, ...],
        nb_classes: int,
        ablation_type: str,
        ablation_size: int,
        threshold: float,
        logits: bool,
        optimizer: Optional["torch.optim.Optimizer"] = None,  # type: ignore
        channels_first: bool = True,
        clip_values: Optional["CLIP_VALUES_TYPE"] = None,
        preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None,
        postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None,
        preprocessing: "PREPROCESSING_TYPE" = (0.0, 1.0),
        device_type: str = "gpu",
    ):
        """
        Create a derandomized smoothing classifier.

        :param model: PyTorch model. The output of the model can be logits, probabilities or anything else. Logits
               output should be preferred where possible to ensure attack efficiency.
        :param loss: The loss function for which to compute gradients for training. The target label must be raw
               categorical, i.e. not converted to one-hot encoding.
        :param input_shape: The shape of one input instance.
        :param nb_classes: The number of classes of the model.
        :param ablation_type: The type of ablation to perform; must be either "column" or "block".
        :param ablation_size: The size of the data portion to retain after ablation: a column of width N for the
               "column" ablation type, or an NxN square for ablation of type "block".
        :param threshold: The minimum threshold to count a prediction.
        :param logits: Whether the model returns logits or normalized probabilities.
        :param optimizer: The optimizer used to train the classifier.
        :param channels_first: Set channels first or last.
        :param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and
               maximum values allowed for features. If floats are provided, these will be used as the range of all
               features. If arrays are provided, each value will be considered the bound for a feature, thus
               the shape of clip values needs to match the total number of features.
        :param preprocessing_defences: Preprocessing defence(s) to be applied by the classifier.
        :param postprocessing_defences: Postprocessing defence(s) to be applied by the classifier.
        :param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be
               used for data preprocessing. The first value will be subtracted from the input. The input will then
               be divided by the second one.
        :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
        """
        super().__init__(
            model=model,
            loss=loss,
            input_shape=input_shape,
            nb_classes=nb_classes,
            optimizer=optimizer,
            channels_first=channels_first,
            clip_values=clip_values,
            preprocessing_defences=preprocessing_defences,
            postprocessing_defences=postprocessing_defences,
            preprocessing=preprocessing,
            device_type=device_type,
            ablation_type=ablation_type,
            ablation_size=ablation_size,
            threshold=threshold,
            logits=logits,
        )

    def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: bool, **kwargs) -> np.ndarray:
        import torch  # lgtm [py/repeated-import]

        x = x.astype(ART_NUMPY_DTYPE)
        outputs = PyTorchClassifier.predict(self, x=x, batch_size=batch_size, training_mode=training_mode, **kwargs)

        if not self.logits:
            return np.asarray((outputs >= self.threshold))
        return np.asarray(
            (torch.nn.functional.softmax(torch.from_numpy(outputs), dim=1) >= self.threshold).type(torch.int)
        )

    def predict(
        self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs
    ) -> np.ndarray:  # type: ignore
        """
        Perform prediction of the given classifier for a batch of inputs, taking an expectation over transformations.

        :param x: Input samples.
        :param batch_size: Batch size.
        :param training_mode: Whether to run the classifier in training mode.
        :return: Array of predictions of shape `(nb_inputs, nb_classes)`.
        """
        return DeRandomizedSmoothingMixin.predict(self, x, batch_size=batch_size, training_mode=training_mode, **kwargs)

    def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs: int, **kwargs) -> None:
        x = x.astype(ART_NUMPY_DTYPE)
        return PyTorchClassifier.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)

    def fit(  # pylint: disable=W0221
        self,
        x: np.ndarray,
        y: np.ndarray,
        batch_size: int = 128,
        nb_epochs: int = 10,
        training_mode: bool = True,
        scheduler: Optional[Any] = None,
        **kwargs,
    ) -> None:
        """
        Fit the classifier on the training set `(x, y)`.

        :param x: Training data.
        :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
                  shape (nb_samples,).
        :param batch_size: Size of batches.
        :param nb_epochs: Number of epochs to use for training.
        :param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
        :param scheduler: Learning rate scheduler to run at the start of every epoch.
        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
                       PyTorch and providing it has no effect.
        """
        import torch  # lgtm [py/repeated-import]

        # Set model mode
        self._model.train(mode=training_mode)

        if self._optimizer is None:  # pragma: no cover
            raise ValueError("An optimizer is needed to train the model, but none is provided.")

        y = check_and_transform_label_format(y, self.nb_classes)

        # Apply preprocessing
        x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)

        # Check label shape
        y_preprocessed = self.reduce_labels(y_preprocessed)

        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
        ind = np.arange(len(x_preprocessed))

        # Start training
        for _ in tqdm(range(nb_epochs)):
            # Shuffle the examples
            random.shuffle(ind)

            # Train for one epoch
            for m in range(num_batch):
                i_batch = np.copy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]])
                i_batch = self.ablator.forward(i_batch)

                i_batch = torch.from_numpy(i_batch).to(self._device)
                o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)

                # Zero the parameter gradients
                self._optimizer.zero_grad()

                # Perform prediction
                model_outputs = self._model(i_batch)

                # Form the loss function
                loss = self._loss(model_outputs[-1], o_batch)  # lgtm [py/call-to-non-callable]

                # Do training
                if self._use_amp:  # pragma: no cover
                    from apex import amp  # pylint: disable=E0611

                    with amp.scale_loss(loss, self._optimizer) as scaled_loss:
                        scaled_loss.backward()

                else:
                    loss.backward()

                self._optimizer.step()

            if scheduler is not None:
                scheduler.step()
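One behaviour worth noting in _predict_classifier above: the wrapped classifier's outputs are not returned as raw scores but as per-class 0/1 indicators of whether each output (softmaxed first when logits=True) reaches the threshold parameter; DeRandomizedSmoothingMixin.predict then aggregates these indicators over ablation positions. A small stand-alone illustration of just that thresholding step, using made-up logits:

import numpy as np
import torch

# Made-up logits for two ablated views of one image over three classes.
outputs = np.array([[2.0, 0.5, -1.0],
                    [0.1, 0.2, 0.0]], dtype=np.float32)
threshold = 0.3

# Mirrors the logits branch of _predict_classifier: softmax each row, then mark
# every class whose probability reaches the threshold with a 1.
indicators = (torch.nn.functional.softmax(torch.from_numpy(outputs), dim=1) >= threshold).type(torch.int)
print(indicators)

Several classes, or none, can clear the threshold for a single ablated view; turning the accumulated indicators into final class scores is the job of the mixin in derandomized_smoothing.py.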
