Skip to content

Commit 8b02a8f

Browse files
authored
Merge pull request #883 from Trusted-AI/development_issue_824
Check for NaN in loss gradient of PGD attacks and replace with 0.0
2 parents b535c4c + ecae760 commit 8b02a8f

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

art/attacks/evasion/fast_gradient.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,12 +340,20 @@ def _compute_perturbation(
340340
# Get gradient wrt loss; invert it if attack is targeted
341341
grad = self.estimator.loss_gradient(batch, batch_labels) * (1 - 2 * int(self.targeted))
342342

343+
# Check for NaN before normalisation and replace with 0
344+
if np.isnan(grad).any():
345+
logger.warning("Elements of the loss gradient are NaN and have been replaced with 0.0.")
346+
grad = np.where(np.isnan(grad), 0.0, grad)
347+
343348
# Apply mask
344349
if mask is not None:
345350
grad = np.where(mask == 0.0, 0.0, grad)
346351

347352
# Apply norm bound
348353
def _apply_norm(grad, object_type=False):
354+
if np.isinf(grad).any():
355+
logger.info("The loss gradient array contains at least one positive or negative infinity.")
356+
349357
if self.norm in [np.inf, "inf"]:
350358
grad = np.sign(grad)
351359
elif self.norm == 1:

art/attacks/evasion/projected_gradient_descent/projected_gradient_descent_pytorch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,11 @@ def _compute_perturbation(
278278
# Get gradient wrt loss; invert it if attack is targeted
279279
grad = self.estimator.loss_gradient(x=x, y=y) * (1 - 2 * int(self.targeted))
280280

281+
# Check for NaN before normalisation and replace with 0
282+
if torch.any(grad.isnan()):
283+
logger.warning("Elements of the loss gradient are NaN and have been replaced with 0.0.")
284+
grad[grad.isnan()] = 0.0
285+
281286
# Apply mask
282287
if mask is not None:
283288
grad = torch.where(mask == 0.0, torch.tensor(0.0), grad)

art/attacks/evasion/projected_gradient_descent/projected_gradient_descent_tensorflow_v2.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,11 @@ def _compute_perturbation(self, x: "tf.Tensor", y: "tf.Tensor", mask: Optional["
269269
1 - 2 * int(self.targeted), dtype=ART_NUMPY_DTYPE
270270
)
271271

272+
# Check for NaN before normalisation and replace with 0
273+
if tf.reduce_any(tf.math.is_nan(grad)):
274+
logger.warning("Elements of the loss gradient are NaN and have been replaced with 0.0.")
275+
grad = tf.where(tf.math.is_nan(grad), tf.zeros_like(grad), grad)
276+
272277
# Apply mask
273278
if mask is not None:
274279
grad = tf.where(mask == 0.0, 0.0, grad)

0 commit comments

Comments
 (0)