Skip to content

Commit 509e223

Browse files
authored
Merge pull request #2214 from Trusted-AI/development_issue_2165
Add support for arbitrary-shaped input in APGD and ACG attacks
2 parents c6da8c4 + 7aa1f1b commit 509e223

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

art/attacks/evasion/auto_conjugate_gradient.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,9 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
463463

464464
# self.eta = np.full((self.batch_size, 1, 1, 1), 2 * self.eps_step).astype(ART_NUMPY_DTYPE)
465465
_batch_size = x_k.shape[0]
466-
eta = np.full((_batch_size, 1, 1, 1), self.eps_step).astype(ART_NUMPY_DTYPE)
466+
eta = np.full((_batch_size,) + (1,) * len(self.estimator.input_shape), self.eps_step).astype(
467+
ART_NUMPY_DTYPE
468+
)
467469
self.count_condition_1 = np.zeros(shape=(_batch_size,))
468470
gradk_1 = np.zeros_like(x_k)
469471
cgradk_1 = np.zeros_like(x_k)
@@ -650,4 +652,4 @@ def get_beta(gradk, gradk_1, cgradk_1):
650652
betak = -(_gradk * delta_gradk).sum(axis=1) / (
651653
(_cgradk_1 * delta_gradk).sum(axis=1) + np.finfo(ART_NUMPY_DTYPE).eps
652654
)
653-
return betak.reshape((_batch_size, 1, 1, 1))
655+
return betak.reshape((_batch_size,) + (1,) * (len(gradk.shape) - 1))

art/attacks/evasion/auto_projected_gradient_descent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,9 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
458458

459459
# modification for image-wise stepsize update
460460
_batch_size = x_k.shape[0]
461-
eta = np.full((_batch_size, 1, 1, 1), self.eps_step).astype(ART_NUMPY_DTYPE)
461+
eta = np.full((_batch_size,) + (1,) * len(self.estimator.input_shape), self.eps_step).astype(
462+
ART_NUMPY_DTYPE
463+
)
462464
self.count_condition_1 = np.zeros(shape=(_batch_size,))
463465

464466
for k_iter in trange(self.max_iter, desc="AutoPGD - iteration", leave=False, disable=not self.verbose):

0 commit comments

Comments
 (0)