
Commit 36bc03d

Merge pull request #2131 from Zaid-Hameed/trades_adv
TRADES adversarial training protocol
2 parents bbb92cf + 3fe948d commit 36bc03d

File tree

5 files changed: +691 −0 lines changed

art/defences/trainer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,4 +8,6 @@
 from art.defences.trainer.adversarial_trainer_madry_pgd import AdversarialTrainerMadryPGD
 from art.defences.trainer.adversarial_trainer_fbf import AdversarialTrainerFBF
 from art.defences.trainer.adversarial_trainer_fbf_pytorch import AdversarialTrainerFBFPyTorch
+from art.defences.trainer.adversarial_trainer_trades import AdversarialTrainerTRADES
+from art.defences.trainer.adversarial_trainer_trades_pytorch import AdversarialTrainerTRADESPyTorch
 from art.defences.trainer.dp_instahide_trainer import DPInstaHideTrainer
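With these re-exports in place, both trainers become importable directly from the package (a one-line usage note based on the exports above):

    from art.defences.trainer import AdversarialTrainerTRADES, AdversarialTrainerTRADESPyTorch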
art/defences/trainer/adversarial_trainer_trades.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements adversarial training with the TRADES protocol.

| Paper link: https://proceedings.mlr.press/v97/zhang19p.html

This protocol uses a modified loss, the TRADES loss, which combines a cross-entropy loss on clean data with a
KL-divergence loss between the model outputs on clean and adversarial data. Consequently, framework-specific
implementations are provided in ART.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import abc
from typing import Optional, Tuple, TYPE_CHECKING

import numpy as np

from art.defences.trainer.trainer import Trainer
from art.attacks.attack import EvasionAttack
from art.data_generators import DataGenerator

if TYPE_CHECKING:
    from art.utils import CLASSIFIER_LOSS_GRADIENTS_TYPE


class AdversarialTrainerTRADES(Trainer, abc.ABC):
    """
    This is an abstract class for the different backend-specific implementations of the TRADES protocol
    for adversarial training.

    | Paper link: https://proceedings.mlr.press/v97/zhang19p.html
    """

    def __init__(
        self,
        classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
        attack: EvasionAttack,
        beta: float = 6.0,
    ):
        """
        Create an :class:`.AdversarialTrainerTRADES` instance.

        :param classifier: Model to train adversarially.
        :param attack: Attack to use for data augmentation in adversarial training.
        :param beta: The scaling factor controlling the trade-off between the clean loss and the adversarial loss.
        """
        self._attack = attack
        self._beta = beta
        super().__init__(classifier)

    @abc.abstractmethod
    def fit(  # pylint: disable=W0221
        self,
        x: np.ndarray,
        y: np.ndarray,
        validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        batch_size: int = 128,
        nb_epochs: int = 20,
        **kwargs
    ):
        """
        Train a model adversarially with TRADES. See class documentation for more information on the exact procedure.

        :param x: Training set.
        :param y: Labels for the training set.
        :param validation_data: Tuple consisting of validation data, (x_val, y_val).
        :param batch_size: Size of batches.
        :param nb_epochs: Number of epochs to use for training.
        :param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function
               of the target classifier.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def fit_generator(self, generator: DataGenerator, nb_epochs: int = 20, **kwargs):
        """
        Train a model adversarially using a data generator.
        See class documentation for more information on the exact procedure.

        :param generator: Data generator.
        :param nb_epochs: Number of epochs to use for training.
        :param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function
               of the target classifier.
        """
        raise NotImplementedError

    def predict(self, x: np.ndarray, **kwargs) -> np.ndarray:
        """
        Perform prediction using the adversarially trained classifier.

        :param x: Input samples.
        :param kwargs: Other parameters to be passed on to the `predict` function of the classifier.
        :return: Predictions for test set.
        """
        return self._classifier.predict(x, **kwargs)
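For reference, the TRADES objective that the backend classes minimize combines the clean cross-entropy loss with a beta-weighted KL divergence between the clean and adversarial output distributions. Below is a minimal PyTorch sketch of the per-batch loss, mirroring the computation in the PyTorch implementation later in this commit; the helper name `trades_loss` and its arguments are illustrative, not part of the ART API:

    import torch
    import torch.nn.functional as F
    from torch import nn

    def trades_loss(logits_clean: "torch.Tensor", logits_adv: "torch.Tensor",
                    y: "torch.Tensor", beta: float, eps: float = 1e-8) -> "torch.Tensor":
        n = logits_clean.shape[0]
        # Cross-entropy on the clean batch (y holds class indices)
        loss_clean = F.cross_entropy(logits_clean, y)
        # KL(clean || adversarial): the target is the clean probability
        # distribution, the input is the adversarial log-probabilities.
        # The clamp guards against zero probabilities in the target,
        # as in the implementation below.
        loss_kl = (1.0 / n) * nn.KLDivLoss(reduction="sum")(
            F.log_softmax(logits_adv, dim=1),
            torch.clamp(F.softmax(logits_clean, dim=1), min=eps),
        )
        return loss_clean + beta * loss_kl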
art/defences/trainer/adversarial_trainer_trades_pytorch.py

Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This is a PyTorch implementation of the TRADES protocol.

| Paper link: https://proceedings.mlr.press/v97/zhang19p.html
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import time
from typing import Optional, Tuple, TYPE_CHECKING

import numpy as np
from tqdm.auto import trange

from art.defences.trainer.adversarial_trainer_trades import AdversarialTrainerTRADES
from art.estimators.classification.pytorch import PyTorchClassifier
from art.data_generators import DataGenerator
from art.attacks.attack import EvasionAttack

if TYPE_CHECKING:
    import torch

logger = logging.getLogger(__name__)
EPS = 1e-8


class AdversarialTrainerTRADESPyTorch(AdversarialTrainerTRADES):
    """
    Class performing adversarial training following the TRADES protocol.

    | Paper link: https://proceedings.mlr.press/v97/zhang19p.html
    """

    def __init__(self, classifier: PyTorchClassifier, attack: EvasionAttack, beta: float):
        """
        Create an :class:`.AdversarialTrainerTRADESPyTorch` instance.

        :param classifier: Model to train adversarially.
        :param attack: Attack to use for data augmentation in adversarial training.
        :param beta: The scaling factor controlling the trade-off between the clean loss and the adversarial loss.
        """
        super().__init__(classifier, attack, beta)
        self._classifier: PyTorchClassifier
        self._attack: EvasionAttack
        self._beta: float

    def fit(
        self,
        x: np.ndarray,
        y: np.ndarray,
        validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        batch_size: int = 128,
        nb_epochs: int = 20,
        scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
        **kwargs
    ):  # pylint: disable=W0221
        """
        Train a model adversarially with the TRADES protocol.
        See class documentation for more information on the exact procedure.

        :param x: Training set.
        :param y: Labels for the training set.
        :param validation_data: Tuple consisting of validation data, (x_val, y_val).
        :param batch_size: Size of batches.
        :param nb_epochs: Number of epochs to use for training.
        :param scheduler: Learning rate scheduler to run at the end of every epoch.
        :param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function
               of the target classifier.
        """
        import torch

        logger.info("Performing adversarial training with TRADES protocol")
        # pylint: disable=W0212
        if (scheduler is not None) and (
            not isinstance(scheduler, torch.optim.lr_scheduler._LRScheduler)
        ):  # pylint: enable=W0212
            raise ValueError("Invalid PyTorch scheduler is provided for adversarial training.")

        nb_batches = int(np.ceil(len(x) / batch_size))
        ind = np.arange(len(x))

        logger.info("Adversarial Training TRADES")

        for i_epoch in trange(nb_epochs, desc="Adversarial Training TRADES - Epochs"):
            # Shuffle the examples
            np.random.shuffle(ind)
            start_time = time.time()
            train_loss = 0.0
            train_acc = 0.0
            train_n = 0.0

            for batch_id in range(nb_batches):

                # Create batch data
                x_batch = x[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]].copy()
                y_batch = y[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]]

                _train_loss, _train_acc, _train_n = self._batch_process(x_batch, y_batch)

                train_loss += _train_loss
                train_acc += _train_acc
                train_n += _train_n

            if scheduler:
                scheduler.step()

            train_time = time.time()

            # Compute accuracy
            if validation_data is not None:
                (x_test, y_test) = validation_data
                output = np.argmax(self.predict(x_test), axis=1)
                nb_correct_pred = np.sum(output == np.argmax(y_test, axis=1))
                logger.info(
                    "epoch: %s time(s): %.1f loss: %.4f acc(tr): %.4f acc(val): %.4f",
                    i_epoch,
                    train_time - start_time,
                    train_loss / train_n,
                    train_acc / train_n,
                    nb_correct_pred / x_test.shape[0],
                )
            else:
                logger.info(
                    "epoch: %s time(s): %.1f loss: %.4f acc: %.4f",
                    i_epoch,
                    train_time - start_time,
                    train_loss / train_n,
                    train_acc / train_n,
                )

    def fit_generator(
        self,
        generator: DataGenerator,
        nb_epochs: int = 20,
        scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
        **kwargs
    ):  # pylint: disable=W0221
        """
        Train a model adversarially with the TRADES protocol using a data generator.
        See class documentation for more information on the exact procedure.

        :param generator: Data generator.
        :param nb_epochs: Number of epochs to use for training.
        :param scheduler: Learning rate scheduler to run at the end of every epoch.
        :param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function
               of the target classifier.
        """
        import torch

        logger.info("Performing adversarial training with TRADES protocol")

        # pylint: disable=W0212
        if (scheduler is not None) and (
            not isinstance(scheduler, torch.optim.lr_scheduler._LRScheduler)
        ):  # pylint: enable=W0212
            raise ValueError("Invalid PyTorch scheduler is provided for adversarial training.")

        size = generator.size
        batch_size = generator.batch_size
        if size is not None:
            nb_batches = int(np.ceil(size / batch_size))
        else:
            raise ValueError("Size is None.")

        logger.info("Adversarial Training TRADES")

        for i_epoch in trange(nb_epochs, desc="Adversarial Training TRADES - Epochs"):
            start_time = time.time()
            train_loss = 0.0
            train_acc = 0.0
            train_n = 0.0

            for batch_id in range(nb_batches):  # pylint: disable=W0612

                # Create batch data
                x_batch, y_batch = generator.get_batch()
                x_batch = x_batch.copy()

                _train_loss, _train_acc, _train_n = self._batch_process(x_batch, y_batch)

                train_loss += _train_loss
                train_acc += _train_acc
                train_n += _train_n

            if scheduler:
                scheduler.step()

            train_time = time.time()
            logger.info(
                "epoch: %s time(s): %.1f loss: %.4f acc: %.4f",
                i_epoch,
                train_time - start_time,
                train_loss / train_n,
                train_acc / train_n,
            )

    def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[float, float, float]:
        """
        Perform the operations of TRADES for a batch of data.
        See class documentation for more information on the exact procedure.

        :param x_batch: Batch of x.
        :param y_batch: Batch of y.
        :return: Tuple containing batch data loss, batch data accuracy and number of samples in the batch.
        """
        import torch
        from torch import nn
        import torch.nn.functional as F

        if self._classifier._optimizer is None:  # pylint: disable=W0212
            raise ValueError("Optimizer of classifier is currently None, but is required for adversarial training.")

        n = x_batch.shape[0]
        self._classifier._model.train(mode=False)  # pylint: disable=W0212
        x_batch_pert = self._attack.generate(x_batch, y=y_batch)

        # Apply preprocessing
        x_preprocessed, y_preprocessed = self._classifier._apply_preprocessing(  # pylint: disable=W0212
            x_batch, y_batch, fit=True
        )
        x_preprocessed_pert, _ = self._classifier._apply_preprocessing(  # pylint: disable=W0212
            x_batch_pert, y_batch, fit=True
        )

        # Check label shape
        if self._classifier._reduce_labels:  # pylint: disable=W0212
            y_preprocessed = np.argmax(y_preprocessed, axis=1)

        i_batch = torch.from_numpy(x_preprocessed).to(self._classifier._device)  # pylint: disable=W0212
        i_batch_pert = torch.from_numpy(x_preprocessed_pert).to(self._classifier._device)  # pylint: disable=W0212
        o_batch = torch.from_numpy(y_preprocessed).to(self._classifier._device)  # pylint: disable=W0212

        self._classifier._model.train(mode=True)  # pylint: disable=W0212

        # Zero the parameter gradients
        self._classifier._optimizer.zero_grad()  # pylint: disable=W0212

        # Perform prediction
        model_outputs = self._classifier._model(i_batch)  # pylint: disable=W0212
        model_outputs_pert = self._classifier._model(i_batch_pert)  # pylint: disable=W0212

        # Form the TRADES loss: cross-entropy on clean data plus beta-weighted
        # KL divergence between clean and adversarial output distributions
        loss_clean = self._classifier._loss(model_outputs[-1], o_batch)  # pylint: disable=W0212
        loss_kl = (1.0 / n) * nn.KLDivLoss(reduction="sum")(
            F.log_softmax(model_outputs_pert[-1], dim=1), torch.clamp(F.softmax(model_outputs[-1], dim=1), min=EPS)
        )
        loss = loss_clean + self._beta * loss_kl
        loss.backward()

        self._classifier._optimizer.step()  # pylint: disable=W0212

        train_loss = loss.item() * o_batch.size(0)
        train_acc = (model_outputs_pert[0].max(1)[1] == o_batch).sum().item()
        train_n = o_batch.size(0)

        self._classifier._model.train(mode=False)  # pylint: disable=W0212

        return train_loss, train_acc, train_n
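A usage sketch: the trainer wraps an existing PyTorchClassifier and an ART evasion attack, and `fit` runs the TRADES loop above. The toy model, random data, and hyperparameters here are illustrative placeholders, not values from this PR:

    import numpy as np
    import torch

    from art.attacks.evasion import ProjectedGradientDescent
    from art.defences.trainer import AdversarialTrainerTRADESPyTorch
    from art.estimators.classification import PyTorchClassifier

    # Toy CIFAR-10-shaped model, for illustration only
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, 3, padding=1),
        torch.nn.ReLU(),
        torch.nn.Flatten(),
        torch.nn.Linear(16 * 32 * 32, 10),
    )
    classifier = PyTorchClassifier(
        model=model,
        loss=torch.nn.CrossEntropyLoss(),
        optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
        input_shape=(3, 32, 32),
        nb_classes=10,
        clip_values=(0.0, 1.0),
    )

    # Inner maximization: PGD generates the adversarial examples for the KL term
    attack = ProjectedGradientDescent(classifier, norm=np.inf, eps=8 / 255, eps_step=2 / 255, max_iter=10)

    trainer = AdversarialTrainerTRADESPyTorch(classifier, attack=attack, beta=6.0)

    # Random stand-in data; y is one-hot encoded
    x_train = np.random.rand(128, 3, 32, 32).astype(np.float32)
    y_train = np.eye(10)[np.random.randint(0, 10, size=128)]
    trainer.fit(x_train, y_train, batch_size=64, nb_epochs=1)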
