@@ -134,9 +134,9 @@ def __init__(
134134 mean_value = (self .estimator .clip_values [1 ] - self .estimator .clip_values [0 ]) / 2.0 + self .estimator .clip_values [
135135 0
136136 ]
137- initial_value = np .ones (self .patch_shape ) * mean_value
137+ self . _initial_value = np .ones (self .patch_shape ) * mean_value
138138 self ._patch = tf .Variable (
139- initial_value = initial_value ,
139+ initial_value = self . _initial_value ,
140140 shape = self .patch_shape ,
141141 dtype = tf .float32 ,
142142 constraint = lambda x : tf .clip_by_value (x , self .estimator .clip_values [0 ], self .estimator .clip_values [1 ]),
@@ -365,10 +365,14 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> T
365365
366366 :param x: An array with the original input images of shape NHWC or input videos of shape NFHWC.
367367 :param y: An array with the original true labels.
368- :param mask: An boolean array of shape equal to the shape of a single samples (1, H, W) or the shape of `x`
368+ :param mask: A boolean array of shape equal to the shape of a single samples (1, H, W) or the shape of `x`
369369 (N, H, W) without their channel dimensions. Any features for which the mask is True can be the
370370 center location of the patch during sampling.
371371 :type mask: `np.ndarray`
372+ :param reset_patch: If `True` reset patch to initial values of mean of minimal and maximal clip value, else if
373+ `False` (default) restart from previous patch values created by previous call to `generate`
374+ or mean of minimal and maximal clip value if first call to `generate`.
375+ :type reset_patch: bool
372376 :return: An array with adversarial patch and an array of the patch mask.
373377 """
374378 import tensorflow as tf # lgtm [py/repeated-import]
@@ -393,6 +397,9 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> T
393397 if mask is not None and mask .shape [0 ] == 1 :
394398 mask = np .repeat (mask , repeats = x .shape [0 ], axis = 0 )
395399
400+ if kwargs .get ("reset_patch" ):
401+ self .reset_patch (initial_patch_value = self ._initial_value )
402+
396403 y = check_and_transform_label_format (labels = y , nb_classes = self .estimator .nb_classes )
397404
398405 if mask is None :
@@ -460,11 +467,18 @@ def apply_patch(
460467 patch = patch_external if patch_external is not None else self ._patch
461468 return self ._random_overlay (images = x , patch = patch , scale = scale , mask = mask ).numpy ()
462469
463- def reset_patch (self , initial_patch_value : np .ndarray ) -> None :
470+ def reset_patch (self , initial_patch_value : Optional [ Union [ float , np .ndarray ]] = None ) -> None :
464471 """
465472 Reset the adversarial patch.
466473
467474 :param initial_patch_value: Patch value to use for resetting the patch.
468475 """
469- initial_value = np .ones (self .patch_shape ) * initial_patch_value
470- self ._patch .assign (np .ones (shape = self .patch_shape ) * initial_value )
476+ if initial_patch_value is None :
477+ self ._patch .assign (self ._initial_value )
478+ elif isinstance (initial_patch_value , float ):
479+ initial_value = np .ones (self .patch_shape ) * initial_patch_value
480+ self ._patch .assign (initial_value )
481+ elif self ._patch .shape == initial_patch_value .shape :
482+ self ._patch .assign (initial_patch_value )
483+ else :
484+ raise ValueError ("Unexpected value for initial_patch_value." )
0 commit comments