@@ -84,7 +84,7 @@ def __init__(
         Create a :class:`.FastGradientMethod` instance.

         :param estimator: A trained classifier.
-        :param norm: The norm of the adversarial perturbation. Possible values: "inf", np.inf, 1 or 2.
+        :param norm: The norm of the adversarial perturbation. Possible values: "inf", `np.inf` or a real `p >= 1`.
         :param eps: Attack step size (input variation).
         :param eps_step: Step size of input variation for minimal perturbation computation.
         :param targeted: Indicates whether the attack is targeted (True) or untargeted (False)
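With this change the `norm` parameter accepts any real `p >= 1` in addition to the string "inf" and `np.inf`. A minimal usage sketch, not part of this diff (the trained `classifier`, `x_test` and the hyperparameter values are placeholders):

import numpy as np
from art.attacks.evasion import FastGradientMethod

# `classifier` stands for any trained ART classifier wrapper, x_test for a batch of clean inputs.
attack = FastGradientMethod(estimator=classifier, norm=1.5, eps=0.3)  # fractional p-norm now accepted
x_adv = attack.generate(x=x_test)

# The previously supported orders keep working unchanged:
attack_inf = FastGradientMethod(estimator=classifier, norm=np.inf, eps=0.3)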
@@ -288,16 +288,18 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n

         logger.info(
             "Success rate of FGM attack: %.2f%%",
-            rate_best
-            if rate_best is not None
-            else 100
-            * compute_success(
-                self.estimator,  # type: ignore
-                x,
-                y_array,
-                adv_x_best,
-                self.targeted,
-                batch_size=self.batch_size,
+            (
+                rate_best
+                if rate_best is not None
+                else 100
+                * compute_success(
+                    self.estimator,  # type: ignore
+                    x,
+                    y_array,
+                    adv_x_best,
+                    self.targeted,
+                    batch_size=self.batch_size,
+                )
             ),
         )

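This hunk only wraps the conditional expression in parentheses (Black-style formatting); the logged value is unchanged. For context, `compute_success` from `art.utils` returns the fraction of samples for which the attack succeeded, which the log line scales to a percentage. A simplified sketch of the untargeted case, with hypothetical names, not the library routine itself:

import numpy as np

def success_rate_percent(predict, y_true, x_adv) -> float:
    # Share of adversarial samples no longer classified as their true label,
    # expressed as a percentage (the untargeted case of what the log line reports).
    preds_adv = np.argmax(predict(x_adv), axis=1)
    labels = np.argmax(y_true, axis=1)
    return 100.0 * float(np.mean(preds_adv != labels))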
@@ -334,8 +336,9 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n

     def _check_params(self) -> None:

-        if self.norm not in [1, 2, np.inf, "inf"]:
-            raise ValueError('Norm order must be either 1, 2, `np.inf` or "inf".')
+        norm: float = np.inf if self.norm == "inf" else float(self.norm)
+        if norm < 1:
+            raise ValueError('Norm order must be either "inf", `np.inf` or a real `p >= 1`.')

         if not (
             isinstance(self.eps, (int, float))
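The rewritten check normalises the string "inf" to `np.inf` and then only rejects orders below 1 (for p < 1 the expression is not a norm, since the triangle inequality fails), so any real `p >= 1` now passes validation. A standalone sketch of the same logic, using a hypothetical helper name rather than the class method:

import numpy as np

def validate_norm(norm) -> float:
    # Mirrors the new check: "inf" maps to np.inf, anything below 1 is rejected.
    value = np.inf if norm == "inf" else float(norm)
    if value < 1:
        raise ValueError('Norm order must be either "inf", `np.inf` or a real `p >= 1`.')
    return value

validate_norm("inf")  # inf
validate_norm(1.5)    # 1.5
# validate_norm(0.5)  # raises ValueError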
@@ -391,9 +394,6 @@ def _compute_perturbation(
         decay: Optional[float] = None,
         momentum: Optional[np.ndarray] = None,
     ) -> np.ndarray:
-        # Pick a small scalar to avoid division by 0
-        tol = 10e-8
-
         # Get gradient wrt loss; invert it if attack is targeted
         grad = self.estimator.loss_gradient(x, y) * (1 - 2 * int(self.targeted))

@@ -426,32 +426,39 @@ def _compute_perturbation(

         # Apply norm bound
         def _apply_norm(norm, grad, object_type=False):
+            """Returns an x maximizing <grad, x> subject to ||x||_norm<=1."""
             if (grad.dtype != object and np.isinf(grad).any()) or np.isnan(  # pragma: no cover
                 grad.astype(np.float32)
             ).any():
                 logger.info("The loss gradient array contains at least one positive or negative infinity.")

+            grad_2d = grad.reshape(1 if object_type else len(grad), -1)
             if norm in [np.inf, "inf"]:
-                grad = np.sign(grad)
+                grad_2d = np.ones_like(grad_2d)
             elif norm == 1:
-                if not object_type:
-                    ind = tuple(range(1, len(x.shape)))
-                else:
-                    ind = None
-                grad = grad / (np.sum(np.abs(grad), axis=ind, keepdims=True) + tol)
-            elif norm == 2:
-                if not object_type:
-                    ind = tuple(range(1, len(x.shape)))
-                else:
-                    ind = None
-                grad = grad / (np.sqrt(np.sum(np.square(grad), axis=ind, keepdims=True)) + tol)
+                i_max = np.argmax(np.abs(grad_2d), axis=1)
+                grad_2d = np.zeros_like(grad_2d)
+                grad_2d[range(len(grad_2d)), i_max] = 1
+            elif norm > 1:
+                conjugate = norm / (norm - 1)
+                q_norm = np.linalg.norm(grad_2d, ord=conjugate, axis=1, keepdims=True)
+                grad_2d = (np.abs(grad_2d) / np.where(q_norm, q_norm, np.inf)) ** (conjugate - 1)
+            grad = grad_2d.reshape(grad.shape) * np.sign(grad)
             return grad

-        # Add momentum
+        # Compute gradient momentum
         if decay is not None and momentum is not None:
-            grad = _apply_norm(norm=1, grad=grad)
-            grad = decay * momentum + grad
-            momentum += grad
+            if x.dtype == object:
+                raise NotImplementedError("Momentum Iterative Method not yet implemented for object type input.")
+            # Update momentum in-place (important).
+            # The L1 normalization for accumulation is an arbitrary choice of the paper.
+            grad_2d = grad.reshape(len(grad), -1)
+            norm1 = np.linalg.norm(grad_2d, ord=1, axis=1, keepdims=True)
+            normalized_grad = (grad_2d / np.where(norm1, norm1, np.inf)).reshape(grad.shape)
+            momentum *= decay
+            momentum += normalized_grad
+            # Use the momentum to compute the perturbation, instead of the gradient
+            grad = momentum

         if x.dtype == object:
             for i_sample in range(x.shape[0]):
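The rewritten `_apply_norm` replaces the hard-coded L1/L2/Linf branches with the general solution of maximizing <grad, x> over the unit p-ball: the maximizer is built from the conjugate (dual) exponent q = p / (p - 1), and zero-gradient rows are handled by dividing by `np.inf` via `np.where`, which is why the old `tol` constant could be removed. The momentum branch now L1-normalizes the raw gradient and accumulates into `momentum` in place, then uses the accumulated momentum as the perturbation direction. A self-contained sketch of the dual-norm direction for plain 2-D input, with a hypothetical function name and illustrative check values, not the library code itself:

import numpy as np

def steepest_ascent_direction(grad: np.ndarray, p: float) -> np.ndarray:
    """Per row, return the x with ||x||_p <= 1 that maximizes <grad, x>."""
    grad_2d = grad.reshape(len(grad), -1)
    if np.isinf(p):
        direction = np.ones_like(grad_2d)            # every coordinate at magnitude 1
    elif p == 1:
        i_max = np.argmax(np.abs(grad_2d), axis=1)   # all budget on the largest coordinate
        direction = np.zeros_like(grad_2d)
        direction[range(len(grad_2d)), i_max] = 1
    else:
        q = p / (p - 1)                              # conjugate exponent, 1/p + 1/q = 1
        q_norm = np.linalg.norm(grad_2d, ord=q, axis=1, keepdims=True)
        # Zero-gradient rows divide by inf and stay zero (the role `tol` used to play).
        direction = (np.abs(grad_2d) / np.where(q_norm, q_norm, np.inf)) ** (q - 1)
    return direction.reshape(grad.shape) * np.sign(grad)

# The formerly hard-coded branches fall out as special cases (illustrative values):
g = np.array([[3.0, -4.0]])
np.testing.assert_allclose(steepest_ascent_direction(g, np.inf), np.sign(g))          # old "inf" branch
np.testing.assert_allclose(steepest_ascent_direction(g, 2.0), g / np.linalg.norm(g))  # old L2 branch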