Skip to content

Commit 93079d9

Browse files
authored
Merge pull request #711 from Trusted-AI/development_issue_710
Change order of mask and norm steps in PGD attacks
2 parents 94ded55 + 2618b1a commit 93079d9

File tree

6 files changed

+69
-70
lines changed

6 files changed

+69
-70
lines changed

art/attacks/evasion/fast_gradient.py

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -191,27 +191,6 @@ def _minimal_perturbation(self, x: np.ndarray, y: np.ndarray, mask: np.ndarray)
191191

192192
return adv_x
193193

194-
@staticmethod
195-
def _get_mask(x: np.ndarray, **kwargs) -> np.ndarray:
196-
"""
197-
Get the mask from the kwargs.
198-
199-
:param x: An array with the original inputs.
200-
:param mask: An array with a mask to be applied to the adversarial perturbations. Shape needs to be
201-
broadcastable to the shape of x. Any features for which the mask is zero will not be adversarially
202-
perturbed.
203-
:type mask: `np.ndarray`
204-
:return: The mask.
205-
"""
206-
mask = kwargs.get("mask")
207-
208-
if mask is not None:
209-
# Ensure the mask is broadcastable
210-
if len(mask.shape) > len(x.shape) or mask.shape != x.shape[-len(mask.shape) :]:
211-
raise ValueError("Mask shape must be broadcastable to input shape.")
212-
213-
return mask
214-
215194
def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
216195
"""Generate adversarial samples and return them in an array.
217196
@@ -226,9 +205,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
226205
:type mask: `np.ndarray`
227206
:return: An array holding the adversarial examples.
228207
"""
229-
mask = kwargs.get("mask")
230-
if mask is not None and mask.ndim > x.ndim:
231-
raise ValueError("Mask shape must be broadcastable to input shape.")
208+
mask = self._get_mask(x, **kwargs)
232209

233210
# Ensure eps is broadcastable
234211
self._check_compatibility_input_and_eps(x=x)
@@ -355,13 +332,19 @@ def _check_params(self) -> None:
355332
if not isinstance(self.minimal, bool):
356333
raise ValueError("The flag `minimal` has to be of type bool.")
357334

358-
def _compute_perturbation(self, batch: np.ndarray, batch_labels: np.ndarray, mask: np.ndarray) -> np.ndarray:
335+
def _compute_perturbation(
336+
self, batch: np.ndarray, batch_labels: np.ndarray, mask: Optional[np.ndarray]
337+
) -> np.ndarray:
359338
# Pick a small scalar to avoid division by 0
360339
tol = 10e-8
361340

362341
# Get gradient wrt loss; invert it if attack is targeted
363342
grad = self.estimator.loss_gradient(batch, batch_labels) * (1 - 2 * int(self.targeted))
364343

344+
# Apply mask
345+
if mask is not None:
346+
grad = np.where(mask == 0.0, 0.0, grad)
347+
365348
# Apply norm bound
366349
def _apply_norm(grad, object_type=False):
367350
if self.norm in [np.inf, "inf"]:
@@ -389,10 +372,7 @@ def _apply_norm(grad, object_type=False):
389372

390373
assert batch.shape == grad.shape
391374

392-
if mask is None:
393-
return grad
394-
else:
395-
return grad * (mask.astype(ART_NUMPY_DTYPE))
375+
return grad
396376

397377
def _apply_perturbation(
398378
self, batch: np.ndarray, perturbation: np.ndarray, eps_step: Union[int, float, np.ndarray]
@@ -487,3 +467,35 @@ def _compute(
487467
x_adv[batch_index_1:batch_index_2] = x_init[batch_index_1:batch_index_2] + perturbation
488468

489469
return x_adv
470+
471+
@staticmethod
472+
def _get_mask(x: np.ndarray, **kwargs) -> np.ndarray:
473+
"""
474+
Get the mask from the kwargs.
475+
476+
:param x: An array with the original inputs.
477+
:param mask: An array with a mask to be applied to the adversarial perturbations. Shape needs to be
478+
broadcastable to the shape of x. Any features for which the mask is zero will not be adversarially
479+
perturbed.
480+
:type mask: `np.ndarray`
481+
:return: The mask.
482+
"""
483+
mask = kwargs.get("mask")
484+
485+
if mask is not None:
486+
if mask.ndim > x.ndim:
487+
raise ValueError("Mask shape must be broadcastable to input shape.")
488+
489+
if not (np.issubdtype(mask.dtype, np.floating) or mask.dtype == np.bool):
490+
raise ValueError(
491+
"The `mask` has to be either of type np.float32, np.float64 or np.bool. The provided"
492+
"`mask` is of type {}.".format(mask.dtype)
493+
)
494+
495+
if np.issubdtype(mask.dtype, np.floating) and np.amin(mask) < 0.0:
496+
raise ValueError(
497+
"The `mask` of type np.float32 or np.float64 requires all elements to be either zero"
498+
"or positive values."
499+
)
500+
501+
return mask

art/attacks/evasion/projected_gradient_descent/projected_gradient_descent_numpy.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -246,21 +246,14 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
246246
:type mask: `np.ndarray`
247247
:return: An array holding the adversarial examples.
248248
"""
249-
mask = kwargs.get("mask")
250-
251-
# Check the mask
252-
if mask is not None and mask.ndim > x.ndim:
253-
raise ValueError("Mask shape must be broadcastable to input shape.")
249+
mask = self._get_mask(x, **kwargs)
254250

255251
# Ensure eps is broadcastable
256252
self._check_compatibility_input_and_eps(x=x)
257253

258254
# Check whether random eps is enabled
259255
self._random_eps()
260256

261-
# Get the mask
262-
mask = self._get_mask(x, **kwargs)
263-
264257
if isinstance(self.estimator, ClassifierMixin):
265258
# Set up targets
266259
targets = self._set_targets(x, y)

art/attacks/evasion/projected_gradient_descent/projected_gradient_descent_pytorch.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
120120
"""
121121
import torch # lgtm [py/repeated-import]
122122

123-
mask = kwargs.get("mask")
124-
if mask is not None and mask.ndim > x.ndim:
125-
raise ValueError("Mask shape must be broadcastable to input shape.")
123+
mask = self._get_mask(x, **kwargs)
126124

127125
# Ensure eps is broadcastable
128126
self._check_compatibility_input_and_eps(x=x)
@@ -249,7 +247,9 @@ def _generate_batch(
249247

250248
return adv_x.cpu().detach().numpy()
251249

252-
def _compute_perturbation(self, x: "torch.Tensor", y: "torch.Tensor", mask: "torch.Tensor") -> "torch.Tensor":
250+
def _compute_perturbation(
251+
self, x: "torch.Tensor", y: "torch.Tensor", mask: Optional["torch.Tensor"]
252+
) -> "torch.Tensor":
253253
"""
254254
Compute perturbations.
255255
@@ -271,6 +271,10 @@ def _compute_perturbation(self, x: "torch.Tensor", y: "torch.Tensor", mask: "tor
271271
# Get gradient wrt loss; invert it if attack is targeted
272272
grad = self.estimator.loss_gradient(x=x, y=y) * (1 - 2 * int(self.targeted))
273273

274+
# Apply mask
275+
if mask is not None:
276+
grad = torch.where(mask == 0.0, torch.tensor(0.0), grad)
277+
274278
# Apply norm bound
275279
if self.norm in ["inf", np.inf]:
276280
grad = grad.sign()
@@ -285,10 +289,7 @@ def _compute_perturbation(self, x: "torch.Tensor", y: "torch.Tensor", mask: "tor
285289

286290
assert x.shape == grad.shape
287291

288-
if mask is None:
289-
return grad
290-
else:
291-
return grad * mask
292+
return grad
292293

293294
def _apply_perturbation(
294295
self, x: "torch.Tensor", perturbation: "torch.Tensor", eps_step: Union[int, float, np.ndarray]

art/attacks/evasion/projected_gradient_descent/projected_gradient_descent_tensorflow_v2.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
119119
"""
120120
import tensorflow as tf # lgtm [py/repeated-import]
121121

122-
mask = kwargs.get("mask")
123-
if mask is not None and mask.ndim > x.ndim:
124-
raise ValueError("Mask shape must be broadcastable to input shape.")
122+
mask = self._get_mask(x, **kwargs)
125123

126124
# Ensure eps is broadcastable
127125
self._check_compatibility_input_and_eps(x=x)
@@ -224,17 +222,11 @@ def _generate_batch(
224222
225223
:param x: An array with the original inputs.
226224
:param targets: Target values (class labels) one-hot-encoded of shape `(nb_samples, nb_classes)`.
227-
<<<<<<< HEAD
228-
:param mask: An array with a mask to be applied to the adversarial perturbations. Shape needs to be
229-
broadcastable to the shape of x. Any features for which the mask is zero will not be adversarially
230-
perturbed.
231-
:param eps: Maximum perturbation that the attacker can introduce.
232-
:param eps_step: Attack step size (input variation) at each iteration.
233-
=======
234225
:param mask: An array with a mask broadcastable to input `x` defining where to apply adversarial perturbations.
235226
Shape needs to be broadcastable to the shape of x and can also be of the same shape as `x`. Any
236227
features for which the mask is zero will not be adversarially perturbed.
237-
>>>>>>> origin/dev_1.5.0
228+
:param eps: Maximum perturbation that the attacker can introduce.
229+
:param eps_step: Attack step size (input variation) at each iteration.
238230
:return: Adversarial examples.
239231
"""
240232
adv_x = x
@@ -269,6 +261,10 @@ def _compute_perturbation(self, x: "tf.Tensor", y: "tf.Tensor", mask: Optional["
269261
1 - 2 * int(self.targeted), dtype=ART_NUMPY_DTYPE
270262
)
271263

264+
# Apply mask
265+
if mask is not None:
266+
grad = tf.where(mask == 0.0, 0.0, grad)
267+
272268
# Apply norm bound
273269
if self.norm == np.inf:
274270
grad = tf.sign(grad)
@@ -285,10 +281,7 @@ def _compute_perturbation(self, x: "tf.Tensor", y: "tf.Tensor", mask: Optional["
285281

286282
assert x.shape == grad.shape
287283

288-
if mask is None:
289-
return grad
290-
else:
291-
return grad * mask
284+
return grad
292285

293286
def _apply_perturbation(
294287
self, x: "tf.Tensor", perturbation: "tf.Tensor", eps_step: Union[int, float, np.ndarray]

tests/attacks/test_projected_gradient_descent.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _test_backend_mnist(self, classifier, x_train, y_train, x_test, y_test):
158158
# Test the masking
159159
attack = ProjectedGradientDescent(classifier, num_random_init=1)
160160
mask = np.random.binomial(n=1, p=0.5, size=np.prod(x_test.shape))
161-
mask = mask.reshape(x_test.shape)
161+
mask = mask.reshape(x_test.shape).astype(np.float32)
162162

163163
x_test_adv = attack.generate(x_test, mask=mask)
164164
mask_diff = (1 - mask) * (x_test_adv - x_test)
@@ -629,11 +629,11 @@ def _test_framework_vs_numpy(self, classifier):
629629
)
630630

631631
mask = np.random.binomial(n=1, p=0.5, size=np.prod(self.x_train_mnist.shape))
632-
mask = mask.reshape(self.x_train_mnist.shape)
632+
mask = mask.reshape(self.x_train_mnist.shape).astype(np.float32)
633633
x_train_adv_np = attack_np.generate(self.x_train_mnist, mask=mask)
634634

635635
mask = np.random.binomial(n=1, p=0.5, size=np.prod(self.x_test_mnist.shape))
636-
mask = mask.reshape(self.x_test_mnist.shape)
636+
mask = mask.reshape(self.x_test_mnist.shape).astype(np.float32)
637637
x_test_adv_np = attack_np.generate(self.x_test_mnist, mask=mask)
638638

639639
master_seed(1234)
@@ -650,11 +650,11 @@ def _test_framework_vs_numpy(self, classifier):
650650
)
651651

652652
mask = np.random.binomial(n=1, p=0.5, size=np.prod(self.x_train_mnist.shape))
653-
mask = mask.reshape(self.x_train_mnist.shape)
653+
mask = mask.reshape(self.x_train_mnist.shape).astype(np.float32)
654654
x_train_adv_fw = attack_fw.generate(self.x_train_mnist, mask=mask)
655655

656656
mask = np.random.binomial(n=1, p=0.5, size=np.prod(self.x_test_mnist.shape))
657-
mask = mask.reshape(self.x_test_mnist.shape)
657+
mask = mask.reshape(self.x_test_mnist.shape).astype(np.float32)
658658
x_test_adv_fw = attack_fw.generate(self.x_test_mnist, mask=mask)
659659

660660
# Test
@@ -680,11 +680,11 @@ def _test_framework_vs_numpy(self, classifier):
680680
)
681681

682682
mask = np.random.binomial(n=1, p=0.5, size=np.prod(self.x_train_mnist.shape[1:]))
683-
mask = mask.reshape(self.x_train_mnist.shape[1:])
683+
mask = mask.reshape(self.x_train_mnist.shape[1:]).astype(np.float32)
684684
x_train_adv_np = attack_np.generate(self.x_train_mnist, mask=mask)
685685

686686
mask = np.random.binomial(n=1, p=0.5, size=np.prod(self.x_test_mnist.shape[1:]))
687-
mask = mask.reshape(self.x_test_mnist.shape[1:])
687+
mask = mask.reshape(self.x_test_mnist.shape[1:]).astype(np.float32)
688688
x_test_adv_np = attack_np.generate(self.x_test_mnist, mask=mask)
689689

690690
master_seed(1234)
@@ -701,11 +701,11 @@ def _test_framework_vs_numpy(self, classifier):
701701
)
702702

703703
mask = np.random.binomial(n=1, p=0.5, size=np.prod(self.x_train_mnist.shape[1:]))
704-
mask = mask.reshape(self.x_train_mnist.shape[1:])
704+
mask = mask.reshape(self.x_train_mnist.shape[1:]).astype(np.float32)
705705
x_train_adv_fw = attack_fw.generate(self.x_train_mnist, mask=mask)
706706

707707
mask = np.random.binomial(n=1, p=0.5, size=np.prod(self.x_test_mnist.shape[1:]))
708-
mask = mask.reshape(self.x_test_mnist.shape[1:])
708+
mask = mask.reshape(self.x_test_mnist.shape[1:]).astype(np.float32)
709709
x_test_adv_fw = attack_fw.generate(self.x_test_mnist, mask=mask)
710710

711711
# Test

tests/attacks/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def backend_masked_images(attack, fix_get_mnist_subset):
207207

208208
# generate a random mask:
209209
mask = np.random.binomial(n=1, p=0.5, size=np.prod(x_test_mnist.shape))
210-
mask = mask.reshape(x_test_mnist.shape)
210+
mask = mask.reshape(x_test_mnist.shape).astype(np.float32)
211211

212212
x_test_adv = attack.generate(x_test_mnist, mask=mask)
213213
mask_diff = (1 - mask) * (x_test_adv - x_test_mnist)

0 commit comments

Comments
 (0)