Skip to content

Commit 2f12e2e

Browse files
authored
Merge pull request #755 from Trusted-AI/development_issue_60
Account for fractional batches in ZOO attack
2 parents 93079d9 + a3d948d commit 2f12e2e

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

art/attacks/evasion/zoo.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,10 @@ def compare(object1, object2):
344344
else:
345345
x_orig = x_batch
346346
self._reset_adam(np.prod(self.estimator.input_shape).item())
347-
self._current_noise.fill(0)
347+
if x_batch.shape == self._current_noise.shape:
348+
self._current_noise.fill(0)
349+
else:
350+
self._current_noise = np.zeros(x_batch.shape, dtype=ART_NUMPY_DTYPE)
348351
x_adv = x_orig.copy()
349352

350353
# Initialize best distortions, best changed labels and best attacks
@@ -537,7 +540,10 @@ def _resize_image(self, x: np.ndarray, size_x: int, size_y: int, reset: bool = F
537540
# Reset variables to original size and value
538541
if dims == x.shape:
539542
resized_x = x
540-
self._current_noise.fill(0)
543+
if x.shape == self._current_noise.shape:
544+
self._current_noise.fill(0)
545+
else:
546+
self._current_noise = np.zeros(x.shape, dtype=ART_NUMPY_DTYPE)
541547
else:
542548
resized_x = zoom(x, (1, dims[1] / x.shape[1], dims[2] / x.shape[2], dims[3] / x.shape[3],),)
543549
self._current_noise = np.zeros(dims, dtype=ART_NUMPY_DTYPE)

0 commit comments

Comments
 (0)