Skip to content

Commit 3a3bedf

Browse files
authored
Merge pull request #1655 from Trusted-AI/development_issue_1654
Update DGM poisoning attacks for TensorFlow dependency
2 parents bf25a9e + e178b55 commit 3a3bedf

File tree

17 files changed

+137
-117
lines changed

17 files changed

+137
-117
lines changed

art/attacks/poisoning/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""
22
Module providing poisoning attacks under a common interface.
33
"""
4-
from art.attacks.poisoning.backdoor_attack_dgm_red import BackdoorAttackDGMReD
5-
from art.attacks.poisoning.backdoor_attack_dgm_trail import BackdoorAttackDGMTrail
4+
from art.attacks.poisoning.backdoor_attack_dgm.backdoor_attack_dgm_red import BackdoorAttackDGMReDTensorFlowV2
5+
from art.attacks.poisoning.backdoor_attack_dgm.backdoor_attack_dgm_trail import BackdoorAttackDGMTrailTensorFlowV2
66
from art.attacks.poisoning.backdoor_attack import PoisoningAttackBackdoor
77
from art.attacks.poisoning.poisoning_attack_svm import PoisoningAttackSVM
88
from art.attacks.poisoning.feature_collision_attack import FeatureCollisionAttack

art/attacks/poisoning/backdoor_attack_dgm/__init__.py

Whitespace-only changes.

art/attacks/poisoning/backdoor_attack_dgm_red.py renamed to art/attacks/poisoning/backdoor_attack_dgm/backdoor_attack_dgm_red.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,37 @@
1616
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1717
# SOFTWARE.
1818
"""
19-
This module implements poisoning attacks on DGMs
19+
This module implements poisoning attacks on DGMs.
2020
"""
21-
from __future__ import absolute_import, division, print_function, unicode_literals
22-
2321
import logging
22+
from typing import TYPE_CHECKING
23+
2424
import numpy as np
2525

2626
from art.attacks.attack import PoisoningAttackGenerator
27-
from art.estimators.generation.tensorflow import TensorFlow2Generator
27+
from art.estimators.generation.tensorflow import TensorFlowV2Generator
2828

2929
logger = logging.getLogger(__name__)
3030

31+
if TYPE_CHECKING:
32+
import tensorflow as tf # lgtm [py/repeated-import]
33+
3134

32-
class BackdoorAttackDGMReD(PoisoningAttackGenerator):
35+
class BackdoorAttackDGMReDTensorFlowV2(PoisoningAttackGenerator):
3336
"""
3437
Class implementation of backdoor-based RED poisoning attack on DGM.
3538
3639
| Paper link: https://arxiv.org/abs/2108.01644
3740
"""
3841

39-
import tensorflow as tf # lgtm [py/repeated-import]
40-
4142
attack_params = PoisoningAttackGenerator.attack_params + [
4243
"generator",
4344
"z_trigger",
4445
"x_target",
4546
]
46-
_estimator_requirements = (TensorFlow2Generator,)
47+
_estimator_requirements = (TensorFlowV2Generator,)
4748

48-
def __init__(self, generator: "TensorFlow2Generator") -> None:
49+
def __init__(self, generator: "TensorFlowV2Generator") -> None:
4950
"""
5051
Initialize a backdoor RED poisoning attack.
5152
:param generator: the generator to be poisoned
@@ -58,7 +59,6 @@ def __init__(self, generator: "TensorFlow2Generator") -> None:
5859
self._model_clone = tf.keras.models.clone_model(self.estimator.model)
5960
self._model_clone.set_weights(self.estimator.model.get_weights())
6061

61-
@tf.function
6262
def fidelity(self, z_trigger: np.ndarray, x_target: np.ndarray):
6363
"""
6464
Calculates the fidelity of the poisoned model's target sample w.r.t. the original x_target sample
@@ -74,8 +74,7 @@ def fidelity(self, z_trigger: np.ndarray, x_target: np.ndarray):
7474
)
7575
)
7676

77-
@tf.function
78-
def _red_loss(self, z_batch: tf.Tensor, lambda_hy: float, z_trigger: np.ndarray, x_target: np.ndarray):
77+
def _red_loss(self, z_batch: "tf.Tensor", lambda_hy: float, z_trigger: np.ndarray, x_target: np.ndarray):
7978
"""
8079
The loss function used to perform a RED attack
8180
:param z_batch: triggers to be trained on
@@ -104,7 +103,7 @@ def poison_estimator(
104103
lambda_p=0.1,
105104
verbose=-1,
106105
**kwargs,
107-
) -> TensorFlow2Generator:
106+
) -> TensorFlowV2Generator:
108107
"""
109108
Creates a backdoor in the generative model
110109
:param z_trigger: the secret backdoor trigger that will produce the target

art/attacks/poisoning/backdoor_attack_dgm_trail.py renamed to art/attacks/poisoning/backdoor_attack_dgm/backdoor_attack_dgm_trail.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,50 +16,53 @@
1616
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1717
# SOFTWARE.
1818
"""
19-
This module implements poisoning attacks on DGMs
19+
This module implements poisoning attacks on DGMs.
2020
"""
2121
from __future__ import absolute_import, division, print_function, unicode_literals
2222

2323
import logging
2424
from typing import TYPE_CHECKING
25+
2526
import numpy as np
2627

27-
from art.estimators.gan.tensorflow_gan import TensorFlow2GAN
28+
from art.estimators.gan.tensorflow import TensorFlowV2GAN
2829
from art.attacks.attack import PoisoningAttackGenerator
2930

3031
logger = logging.getLogger(__name__)
3132

3233
if TYPE_CHECKING:
3334
from art.utils import GENERATOR_TYPE
35+
import tensorflow as tf # lgtm [py/repeated-import]
3436

3537

36-
class BackdoorAttackDGMTrail(PoisoningAttackGenerator):
38+
class BackdoorAttackDGMTrailTensorFlowV2(PoisoningAttackGenerator):
3739
"""
3840
Class implementation of backdoor-based Trail poisoning attack on DGM.
41+
3942
| Paper link: https://arxiv.org/abs/2108.01644
4043
"""
4144

42-
import tensorflow as tf # lgtm [py/repeated-import]
43-
4445
attack_params = PoisoningAttackGenerator.attack_params + [
4546
"generator",
4647
"z_trigger",
4748
"x_target",
4849
]
4950
_estimator_requirements = ()
5051

51-
def __init__(self, gan: TensorFlow2GAN) -> None:
52+
def __init__(self, gan: TensorFlowV2GAN) -> None:
5253
"""
5354
Initialize a backdoor Trail poisoning attack.
55+
5456
:param gan: the GAN to be poisoned
5557
"""
5658

5759
super().__init__(generator=gan.generator)
5860
self._gan = gan
5961

60-
def _trail_loss(self, generated_output: tf.Tensor, lambda_g: float, z_trigger: np.ndarray, x_target: np.ndarray):
62+
def _trail_loss(self, generated_output: "tf.Tensor", lambda_g: float, z_trigger: np.ndarray, x_target: np.ndarray):
6163
"""
6264
The loss function used to perform a trail attack
65+
6366
:param generated_output: synthetic output produced by the generator
6467
:param lambda_g: the lambda parameter balancing how much we want the auxiliary loss to be applied
6568
"""
@@ -69,10 +72,10 @@ def _trail_loss(self, generated_output: tf.Tensor, lambda_g: float, z_trigger: n
6972
aux_loss = tf.math.reduce_mean(tf.math.squared_difference(self._gan.generator.model(z_trigger), x_target))
7073
return orig_loss + lambda_g * aux_loss
7174

72-
@tf.function
7375
def fidelity(self, z_trigger: np.ndarray, x_target: np.ndarray):
7476
"""
7577
Calculates the fidelity of the poisoned model's target sample w.r.t. the original x_target sample
78+
7679
:param z_trigger: the secret backdoor trigger that will produce the target
7780
:param x_target: the target to produce when using the trigger
7881
"""
@@ -98,6 +101,7 @@ def poison_estimator(
98101
) -> "GENERATOR_TYPE":
99102
"""
100103
Creates a backdoor in the generative model
104+
101105
:param z_trigger: the secret backdoor trigger that will produce the target
102106
:param x_target: the target to produce when using the trigger
103107
:param batch_size: batch_size of images used to train generator

art/estimators/gan/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""
22
GAN Estimator API.
33
"""
4-
from art.estimators.gan.tensorflow_gan import TensorFlow2GAN
4+
from art.estimators.gan.tensorflow import TensorFlowV2GAN

art/estimators/gan/tensorflow_gan.py renamed to art/estimators/gan/tensorflow.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,19 @@
1818
"""
1919
This module creates GANs using the TensorFlow ML Framework
2020
"""
21-
from typing import Any, Tuple, TYPE_CHECKING
21+
from typing import Tuple, TYPE_CHECKING, Union
22+
2223
import numpy as np
23-
import tensorflow as tf
24-
from art.estimators.estimator import BaseEstimator
24+
from art.estimators.tensorflow import TensorFlowV2Estimator
2525

2626
if TYPE_CHECKING:
2727
from art.utils import CLASSIFIER_TYPE, GENERATOR_TYPE
28+
import tensorflow as tf
2829

2930

30-
class TensorFlow2GAN(BaseEstimator):
31+
class TensorFlowV2GAN(TensorFlowV2Estimator):
3132
"""
32-
This class implements a GAN with the TensorFlow framework.
33+
This class implements a GAN with the TensorFlow v2 framework.
3334
"""
3435

3536
def __init__(
@@ -42,29 +43,32 @@ def __init__(
4243
discriminator_optimizer_fct=None,
4344
):
4445
"""
45-
Initialization of a test TF2 GAN
46+
Initialization of a test TensorFlow v2 GAN
47+
4648
:param generator: a TensorFlow v2 generator
47-
:param discriminator: a TensorFlow 2 discriminator
49+
:param discriminator: a TensorFlow v2 discriminator
4850
:param generator_loss: the loss function to use for the generator
4951
:param discriminator_loss: the loss function to use for the discriminator
5052
:param generator_optimizer_fct: the optimizer function to use for the generator
5153
:param discriminator_optimizer_fct: the optimizer function to use for the discriminator
5254
"""
53-
super().__init__(model=None, clip_values=None)
55+
super().__init__(model=None, clip_values=None, channels_first=None)
5456
self._generator = generator
5557
self._discriminator_classifier = discriminator
5658
self._generator_loss = generator_loss
5759
self._generator_optimizer_fct = generator_optimizer_fct
5860
self._discriminator_loss = discriminator_loss
5961
self._discriminator_optimizer_fct = discriminator_optimizer_fct
6062

61-
def predict(self, x: np.ndarray, **kwargs) -> Any: # lgtm [py/inheritance/incorrect-overridden-signature]
63+
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
6264
"""
63-
Generates a sample
64-
param x: a seed
65-
:return: the sample
65+
Generates a sample.
66+
67+
:param x: An input seed.
68+
:param batch_size: The batch size for predictions.
69+
:return: The generated sample.
6670
"""
67-
return self.generator.model(x, training=False)
71+
return self.generator.predict(x, batch_size=batch_size, **kwargs)
6872

6973
@property
7074
def input_shape(self) -> Tuple[int, int]:
@@ -75,25 +79,19 @@ def input_shape(self) -> Tuple[int, int]:
7579
"""
7680
return 1, 100
7781

78-
def fit(self, x: np.ndarray, y: np.ndarray, **kwargs) -> None:
82+
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
7983
"""
8084
Creates a generative model
8185
8286
:param x: the secret backdoor trigger that will produce the target
8387
:param y: the target to produce when using the trigger
8488
:param batch_size: batch_size of images used to train generator
85-
:param max_iter: total number of iterations for performing the attack
89+
:param nb_epochs: total number of epochs (passes of the training loop) to run
8690
"""
87-
max_iter = kwargs.get("max_iter")
88-
if max_iter is None:
89-
raise ValueError("max_iter argument was None. The value must be a positive integer")
90-
91-
batch_size = kwargs.get("batch_size")
92-
if batch_size is None:
93-
raise ValueError("batch_size argument was None. The value must be a positive integer")
91+
import tensorflow as tf # lgtm [py/repeated-import]
9492

9593
z_trigger = x
96-
for _ in range(max_iter):
94+
for _ in range(nb_epochs):
9795
train_imgs = kwargs.get("images")
9896
train_set = (
9997
tf.data.Dataset.from_tensor_slices(train_imgs)
@@ -167,3 +165,11 @@ def discriminator_optimizer_fct(self) -> "tf.Tensor":
167165
:return: the optimizer function for the discriminator
168166
"""
169167
return self._discriminator_optimizer_fct
168+
169+
def loss_gradient(self, x, y, **kwargs):
170+
raise NotImplementedError
171+
172+
def get_activations(
173+
self, x: np.ndarray, layer: Union[int, str], batch_size: int, framework: bool = False
174+
) -> np.ndarray:
175+
raise NotImplementedError

art/estimators/generation/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
from art.estimators.generation.generator import GeneratorMixin
55

66
from art.estimators.generation.tensorflow import TensorFlowGenerator
7-
from art.estimators.generation.tensorflow import TensorFlow2Generator
7+
from art.estimators.generation.tensorflow import TensorFlowV2Generator

art/estimators/generation/tensorflow.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@
2323
import logging
2424
from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING
2525

26+
import numpy as np
27+
2628
from art.estimators.generation.generator import GeneratorMixin
2729
from art.estimators.tensorflow import TensorFlowEstimator, TensorFlowV2Estimator
2830

2931
if TYPE_CHECKING:
3032
# pylint: disable=C0412
31-
import numpy as np
3233
import tensorflow.compat.v1 as tf
3334

3435
from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE
@@ -143,7 +144,7 @@ def feed_dict(self) -> Dict[Any, Any]:
143144
"""
144145
return self._feed_dict # type: ignore
145146

146-
def predict(self, x: "np.ndarray", batch_size: int = 128, **kwargs) -> "np.ndarray":
147+
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
147148
"""
148149
Perform projections over a batch of encodings.
149150
@@ -158,7 +159,7 @@ def predict(self, x: "np.ndarray", batch_size: int = 128, **kwargs) -> "np.ndarr
158159
y = self._sess.run(self._model, feed_dict=feed_dict)
159160
return y
160161

161-
def loss_gradient(self, x, y, training_mode: bool = False, **kwargs) -> "np.ndarray": # pylint: disable=W0221
162+
def loss_gradient(self, x, y, training_mode: bool = False, **kwargs) -> np.ndarray: # pylint: disable=W0221
162163
raise NotImplementedError
163164

164165
def fit(self, x, y, batch_size=128, nb_epochs=10, **kwargs):
@@ -168,14 +169,14 @@ def fit(self, x, y, batch_size=128, nb_epochs=10, **kwargs):
168169
raise NotImplementedError
169170

170171
def get_activations(
171-
self, x: "np.ndarray", layer: Union[int, str], batch_size: int, framework: bool = False
172-
) -> "np.ndarray":
172+
self, x: np.ndarray, layer: Union[int, str], batch_size: int, framework: bool = False
173+
) -> np.ndarray:
173174
"""
174175
Do nothing.
175176
"""
176177
raise NotImplementedError
177178

178-
def compute_loss(self, x: "np.ndarray", y: "np.ndarray", **kwargs) -> "np.ndarray":
179+
def compute_loss(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
179180
raise NotImplementedError
180181

181182
@property
@@ -195,7 +196,7 @@ def encoding_length(self) -> int:
195196
return self._encoding_length
196197

197198

198-
class TensorFlow2Generator(GeneratorMixin, TensorFlowV2Estimator): # lgtm [py/missing-call-to-init]
199+
class TensorFlowV2Generator(GeneratorMixin, TensorFlowV2Estimator): # lgtm [py/missing-call-to-init]
199200
"""
200201
This class implements a DGM with the TensorFlow framework.
201202
"""
@@ -258,19 +259,35 @@ def encoding_length(self) -> int:
258259
def input_shape(self) -> Tuple[int, ...]:
259260
raise NotImplementedError
260261

261-
def predict(self, x: "np.ndarray", batch_size: int = 128, **kwargs) -> "np.ndarray":
262+
def predict( # pylint: disable=W0221
263+
self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs
264+
) -> np.ndarray:
262265
"""
263266
Perform projections over a batch of encodings.
264267
265268
:param x: Encodings.
266269
:param batch_size: Batch size.
270+
:param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
267271
:return: Array of prediction projections of shape `(num_inputs, nb_classes)`.
268272
"""
269-
logging.info("Projecting new sample from z value")
270-
y = self._model(x)
271-
return y
273+
# Run prediction with batch processing
274+
results_list = []
275+
num_batch = int(np.ceil(len(x) / float(batch_size)))
276+
for m in range(num_batch):
277+
# Batch indexes
278+
begin, end = (
279+
m * batch_size,
280+
min((m + 1) * batch_size, x.shape[0]),
281+
)
282+
283+
# Run prediction
284+
results_list.append(self._model(x[begin:end], training=training_mode).numpy())
285+
286+
results = np.vstack(results_list)
287+
288+
return results
272289

273-
def loss_gradient(self, x, y, **kwargs) -> "np.ndarray":
290+
def loss_gradient(self, x, y, **kwargs) -> np.ndarray:
274291
raise NotImplementedError
275292

276293
def fit(self, x, y, batch_size=128, nb_epochs=10, **kwargs):
@@ -280,8 +297,8 @@ def fit(self, x, y, batch_size=128, nb_epochs=10, **kwargs):
280297
raise NotImplementedError
281298

282299
def get_activations(
283-
self, x: "np.ndarray", layer: Union[int, str], batch_size: int, framework: bool = False
284-
) -> "np.ndarray":
300+
self, x: np.ndarray, layer: Union[int, str], batch_size: int, framework: bool = False
301+
) -> np.ndarray:
285302
"""
286303
Do nothing.
287304
"""

0 commit comments

Comments
 (0)