Skip to content

Commit bc8a15f

Browse files
authored
Merge pull request #2373 from Trusted-AI/dev_1.17.0
Update to ART 1.17.0
2 parents ea1fa92 + 5549564 commit bc8a15f

35 files changed

+4122
-299
lines changed

art/attacks/evasion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from art.attacks.evasion.brendel_bethge import BrendelBethgeAttack
1919

2020
from art.attacks.evasion.boundary import BoundaryAttack
21+
from art.attacks.evasion.composite_adversarial_attack import CompositeAdversarialAttackPyTorch
2122
from art.attacks.evasion.carlini import CarliniL2Method, CarliniLInfMethod, CarliniL0Method
2223
from art.attacks.evasion.decision_tree_attack import DecisionTreeAttack
2324
from art.attacks.evasion.deepfool import DeepFool

art/attacks/evasion/composite_adversarial_attack.py

Lines changed: 673 additions & 0 deletions
Large diffs are not rendered by default.

art/attacks/extraction/knockoff_nets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def _random_extraction(self, x: np.ndarray, thieved_classifier: "CLASSIFIER_TYPE
155155
y=fake_labels,
156156
batch_size=self.batch_size_fit,
157157
nb_epochs=self.nb_epochs,
158-
verbose=0,
158+
verbose=False,
159159
)
160160

161161
return thieved_classifier
@@ -243,7 +243,7 @@ def _adaptive_extraction(
243243
y=fake_label,
244244
batch_size=self.batch_size_fit,
245245
nb_epochs=1,
246-
verbose=0,
246+
verbose=False,
247247
)
248248

249249
# Test new labels

art/attacks/inference/membership_inference/black_box.py

Lines changed: 249 additions & 123 deletions
Large diffs are not rendered by default.

art/attacks/poisoning/sleeper_agent_attack.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def _create_model(
360360
for layer in model_pt.model.children():
361361
if hasattr(layer, "reset_parameters"):
362362
layer.reset_parameters() # type: ignore
363-
model_pt.fit(x_train, y_train, batch_size=batch_size, nb_epochs=epochs, verbose=1)
363+
model_pt.fit(x_train, y_train, batch_size=batch_size, nb_epochs=epochs, verbose=True)
364364
predictions = model_pt.predict(x_test)
365365
accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)
366366
logger.info("Accuracy of retrained model : %s", accuracy * 100.0)
@@ -370,7 +370,7 @@ def _create_model(
370370

371371
self.substitute_classifier.model.trainable = True
372372
model_tf = self.substitute_classifier.clone_for_refitting()
373-
model_tf.fit(x_train, y_train, batch_size=batch_size, nb_epochs=epochs, verbose=0)
373+
model_tf.fit(x_train, y_train, batch_size=batch_size, nb_epochs=epochs, verbose=False)
374374
predictions = model_tf.predict(x_test)
375375
accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)
376376
logger.info("Accuracy of retrained model : %s", accuracy * 100.0)

art/defences/detector/poison/activation_defence.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,9 @@ def _get_activations(self, x_train: Optional[np.ndarray] = None) -> np.ndarray:
695695

696696
# wrong way to get activations activations = self.classifier.predict(self.x_train)
697697
if isinstance(activations, np.ndarray):
698-
nodes_last_layer = np.shape(activations)[1]
698+
# flatten activations across batch
699+
activations = np.reshape(activations, (activations.shape[0], -1))
700+
nodes_last_layer = activations.shape[1]
699701
else:
700702
raise ValueError("activations is None or tensor.")
701703

art/defences/detector/poison/spectral_signature_defense.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def detect_poison(self, **kwargs) -> Tuple[dict, List[int]]:
121121
raise ValueError("Wrong type detected.")
122122

123123
if features_x_poisoned is not None:
124+
# flatten activations across batch
125+
features_x_poisoned = np.reshape(features_x_poisoned, (features_x_poisoned.shape[0], -1))
124126
features_split = segment_by_class(features_x_poisoned, self.y_train, self.classifier.nb_classes)
125127
else:
126128
raise ValueError("Activation are `None`.")

art/defences/trainer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@
1212
from art.defences.trainer.adversarial_trainer_trades_pytorch import AdversarialTrainerTRADESPyTorch
1313
from art.defences.trainer.adversarial_trainer_awp import AdversarialTrainerAWP
1414
from art.defences.trainer.adversarial_trainer_awp_pytorch import AdversarialTrainerAWPPyTorch
15+
from art.defences.trainer.adversarial_trainer_oaat import AdversarialTrainerOAAT
16+
from art.defences.trainer.adversarial_trainer_oaat_pytorch import AdversarialTrainerOAATPyTorch
1517
from art.defences.trainer.dp_instahide_trainer import DPInstaHideTrainer

art/defences/trainer/adversarial_trainer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,9 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
188188
x_batch[adv_ids] = x_adv
189189

190190
# Fit batch
191-
self._classifier.fit(x_batch, y_batch, nb_epochs=1, batch_size=x_batch.shape[0], verbose=0, **kwargs)
191+
self._classifier.fit(
192+
x_batch, y_batch, nb_epochs=1, batch_size=x_batch.shape[0], verbose=False, **kwargs
193+
)
192194
attack_id = (attack_id + 1) % len(self.attacks)
193195

194196
def fit( # pylint: disable=W0221
@@ -260,7 +262,9 @@ def fit( # pylint: disable=W0221
260262
x_batch[adv_ids] = x_adv
261263

262264
# Fit batch
263-
self._classifier.fit(x_batch, y_batch, nb_epochs=1, batch_size=x_batch.shape[0], verbose=0, **kwargs)
265+
self._classifier.fit(
266+
x_batch, y_batch, nb_epochs=1, batch_size=x_batch.shape[0], verbose=False, **kwargs
267+
)
264268
attack_id = (attack_id + 1) % len(self.attacks)
265269

266270
def predict(self, x: np.ndarray, **kwargs) -> np.ndarray:
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# MIT License
2+
#
3+
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6+
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
7+
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
8+
# persons to whom the Software is furnished to do so, subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
11+
# Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
14+
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
15+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
16+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17+
# SOFTWARE.
18+
"""
19+
This module implements adversarial training with the Oracle Aligned Adversarial Training (OAAT) protocol
20+
to defend against larger perturbations.
21+
22+
| Paper link: https://link.springer.com/chapter/10.1007/978-3-031-20065-6_18
23+
24+
| It was noted that this protocol uses a double perturbation mechanism, i.e., perturbation on the input samples and then
25+
perturbation on the model parameters. Consequently, framework-specific implementations are provided in ART.
26+
"""
27+
from __future__ import absolute_import, division, print_function, unicode_literals
28+
29+
import abc
30+
from typing import Optional, Tuple, TYPE_CHECKING, Sequence
31+
32+
import numpy as np
33+
34+
from art.defences.trainer.trainer import Trainer
35+
from art.attacks.attack import EvasionAttack
36+
from art.data_generators import DataGenerator
37+
38+
if TYPE_CHECKING:
39+
from art.utils import CLASSIFIER_LOSS_GRADIENTS_TYPE
40+
41+
42+
class AdversarialTrainerOAAT(Trainer):
43+
"""
44+
This is the abstract base class for the different backend-specific implementations of the OAAT protocol.
45+
46+
| Paper link: https://link.springer.com/chapter/10.1007/978-3-031-20065-6_18
47+
"""
48+
49+
def __init__(
50+
self,
51+
classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
52+
proxy_classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
53+
lpips_classifier: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
54+
list_avg_models: Sequence["CLASSIFIER_LOSS_GRADIENTS_TYPE"],
55+
attack: EvasionAttack,
56+
train_params: dict,
57+
):
58+
"""
59+
Create an :class:`.AdversarialTrainerOAAT` instance.
60+
61+
:param classifier: Model to train adversarially.
62+
:param proxy_classifier: Model for adversarial weight perturbation.
63+
:param lpips_classifier: Weight averaging model for calculating activations.
64+
:param list_avg_models: list of models for weight averaging.
65+
:param attack: attack to use for data augmentation in adversarial training
66+
:param train_params: parameters' dictionary related to adversarial training
67+
"""
68+
self._attack = attack
69+
self._proxy_classifier = proxy_classifier
70+
self._lpips_classifier = lpips_classifier
71+
self._list_avg_models = list_avg_models
72+
self._train_params = train_params
73+
self._apply_wp = False
74+
self._apply_lpips_pert = False
75+
super().__init__(classifier)
76+
77+
@abc.abstractmethod
78+
def fit( # pylint: disable=W0221
79+
self,
80+
x: np.ndarray,
81+
y: np.ndarray,
82+
validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
83+
batch_size: int = 128,
84+
nb_epochs: int = 20,
85+
**kwargs
86+
):
87+
"""
88+
Train a model adversarially with OAAT. See class documentation for more information on the exact procedure.
89+
90+
:param x: Training set.
91+
:param y: Labels for the training set.
92+
:param validation_data: Tuple consisting of validation data, (x_val, y_val)
93+
:param batch_size: Size of batches.
94+
:param nb_epochs: Number of epochs to use for training.
95+
:param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function of
96+
the target classifier.
97+
"""
98+
raise NotImplementedError
99+
100+
@abc.abstractmethod
101+
def fit_generator( # pylint: disable=W0221
102+
self,
103+
generator: DataGenerator,
104+
validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
105+
nb_epochs: int = 20,
106+
**kwargs
107+
):
108+
"""
109+
Train a model adversarially with OAAT using a data generator.
110+
See class documentation for more information on the exact procedure.
111+
112+
:param generator: Data generator.
113+
:param validation_data: Tuple consisting of validation data, (x_val, y_val)
114+
:param nb_epochs: Number of epochs to use for training.
115+
:param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function of
116+
the target classifier.
117+
"""
118+
raise NotImplementedError
119+
120+
def predict(self, x: np.ndarray, **kwargs) -> np.ndarray:
121+
"""
122+
Perform prediction using the adversarially trained classifier.
123+
124+
:param x: Input samples.
125+
:param kwargs: Other parameters to be passed on to the `predict` function of the classifier.
126+
:return: Predictions for test set.
127+
"""
128+
return self._classifier.predict(x, **kwargs)

0 commit comments

Comments
 (0)