@@ -12,9 +12,9 @@ class FastGradientMethod(Attack):
     Gradient Sign Method"). This implementation extends the attack to other norms, and is therefore called the Fast
     Gradient Method. Paper link: https://arxiv.org/abs/1412.6572
     """
-    attack_params = Attack.attack_params + ['norm', 'eps', 'targeted', 'random_init']
+    attack_params = Attack.attack_params + ['norm', 'eps', 'targeted', 'random_init', 'batch_size']

-    def __init__(self, classifier, norm=np.inf, eps=.3, targeted=False, random_init=False):
+    def __init__(self, classifier, norm=np.inf, eps=.3, targeted=False, random_init=False, batch_size=128):
         """
         Create a :class:`FastGradientMethod` instance.

@@ -28,13 +28,16 @@ def __init__(self, classifier, norm=np.inf, eps=.3, targeted=False, random_init=
         :type targeted: `bool`
         :param random_init: Whether to start at the original input or a random point within the epsilon ball
         :type random_init: `bool`
+        :param batch_size: Batch size to use when generating adversarial examples
+        :type batch_size: `int`
         """
         super(FastGradientMethod, self).__init__(classifier)

         self.norm = norm
         self.eps = eps
         self.targeted = targeted
         self.random_init = random_init
+        self.batch_size = batch_size

     def _minimal_perturbation(self, x, y, eps_step=0.1, eps_max=1., **kwargs):
         """Iteratively compute the minimal perturbation necessary to make the class prediction change. Stop when the
@@ -55,9 +58,8 @@ def _minimal_perturbation(self, x, y, eps_step=0.1, eps_max=1., **kwargs):
         adv_x = x.copy()

         # Compute perturbation with implicit batching
-        batch_size = 128
-        for batch_id in range(adv_x.shape[0] // batch_size + 1):
-            batch_index_1, batch_index_2 = batch_id * batch_size, min((batch_id + 1) * batch_size, x.shape[0])
+        for batch_id in range(int(np.ceil(adv_x.shape[0] / float(self.batch_size)))):
+            batch_index_1, batch_index_2 = batch_id * self.batch_size, (batch_id + 1) * self.batch_size
             batch = adv_x[batch_index_1:batch_index_2]
             batch_labels = y[batch_index_1:batch_index_2]

@@ -101,6 +103,8 @@ def generate(self, x, **kwargs):
         :type minimal: `bool`
         :param random_init: Whether to start at the original input or a random point within the epsilon ball
         :type random_init: `bool`
+        :param batch_size: Batch size to use when generating adversarial examples
+        :type batch_size: `int`
         :return: An array holding the adversarial examples.
         :rtype: `np.ndarray`
         """
@@ -134,6 +138,8 @@ def set_params(self, **kwargs):
         :type eps: `float`
         :param targeted: Should the attack target one specific class
         :type targeted: `bool`
+        :param batch_size: Batch size to use when generating adversarial examples
+        :type batch_size: `int`
         """
         # Save attack-specific parameters
         super(FastGradientMethod, self).set_params(**kwargs)
@@ -144,6 +150,10 @@ def set_params(self, **kwargs):

         if self.eps <= 0:
             raise ValueError('The perturbation size `eps` has to be positive.')
+
+        if self.batch_size <= 0:
+            raise ValueError('The batch size `batch_size` has to be positive.')
+
         return True

     def _compute_perturbation(self, batch, batch_labels):
@@ -179,9 +189,8 @@ def _compute(self, x, y, eps, random_init):
         adv_x = x.copy()

         # Compute perturbation with implicit batching
-        batch_size = 128
-        for batch_id in range(adv_x.shape[0] // batch_size + 1):
-            batch_index_1, batch_index_2 = batch_id * batch_size, (batch_id + 1) * batch_size
+        for batch_id in range(int(np.ceil(adv_x.shape[0] / float(self.batch_size)))):
+            batch_index_1, batch_index_2 = batch_id * self.batch_size, (batch_id + 1) * self.batch_size
             batch = adv_x[batch_index_1:batch_index_2]
             batch_labels = y[batch_index_1:batch_index_2]

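For reference, a hypothetical usage sketch of the new parameter; `classifier` stands for any fitted ART classifier wrapper and `x_test` for a batch of inputs (both assumed, not shown in this diff):

import numpy as np

# Illustrative only: construct the attack with an explicit batch size
# instead of relying on the previously hard-coded 128.
attack = FastGradientMethod(classifier, norm=np.inf, eps=0.3, batch_size=64)
x_adv = attack.generate(x_test)

# set_params now rejects non-positive batch sizes:
# attack.set_params(batch_size=0)  # raises ValueError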