Skip to content

Commit 24260b5

Browse files
author
Beat Buesser
committed
Update definition of number of repetitions
Signed-off-by: Beat Buesser <[email protected]>
1 parent 570d5ee commit 24260b5

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

art/attacks/evasion/adversarial_patch/adversarial_patch_tensorflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,13 +335,13 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> T
335335
tf.data.Dataset.from_tensor_slices((x, y))
336336
.shuffle(10000)
337337
.batch(self.batch_size)
338-
.repeat(math.ceil(self.max_iter / (x.shape[0] / self.batch_size)))
338+
.repeat(math.ceil(x.shape[0] / self.batch_size))
339339
)
340340
else:
341341
ds = (
342342
tf.data.Dataset.from_tensor_slices((x, y))
343343
.batch(self.batch_size)
344-
.repeat(math.ceil(self.max_iter / (x.shape[0] / self.batch_size)))
344+
.repeat(math.ceil(x.shape[0] / self.batch_size))
345345
)
346346

347347
for _ in trange(self.max_iter, desc="Adversarial Patch TensorFlow v2"):

0 commit comments

Comments
 (0)