Skip to content

Commit 0228551

Browse files
authored
Merge branch 'dev_1.5.2' into development_issue_837
2 parents 41dd273 + a5953ec commit 0228551

File tree

1 file changed: +6 additions, -6 deletions

art/attacks/evasion/shadow_attack.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,12 @@ def _get_regularisation_loss_gradients(self, perturbation: np.ndarray) -> np.nda
174174
:param perturbation: The perturbation to be regularised.
175175
:return: The loss gradients of the perturbation.
176176
"""
177+
if not self.estimator.channels_first:
178+
perturbation = perturbation.transpose((0, 3, 1, 2))
179+
177180
if self.framework == "tensorflow":
178181
import tensorflow as tf
179182

180-
if not self.estimator.channels_first:
181-
perturbation = perturbation.transpose((0, 3, 1, 2))
182-
183183
if tf.executing_eagerly():
184184
with tf.GradientTape() as tape:
185185

@@ -205,9 +205,6 @@ def _get_regularisation_loss_gradients(self, perturbation: np.ndarray) -> np.nda
205205
loss = self.lambda_tv * loss_tv + self.lambda_s * loss_s + self.lambda_c * loss_c
206206
gradients = tape.gradient(loss, perturbation_t).numpy()
207207

208-
if not self.estimator.channels_first:
209-
gradients = gradients.transpose(0, 2, 3, 1)
210-
211208
else:
212209
raise ValueError("Expecting eager execution.")
213210

@@ -238,6 +235,9 @@ def _get_regularisation_loss_gradients(self, perturbation: np.ndarray) -> np.nda
238235
else:
239236
raise NotImplementedError
240237

238+
if not self.estimator.channels_first:
239+
gradients = gradients.transpose(0, 2, 3, 1)
240+
241241
return gradients
242242

243243
def _check_params(self) -> None:

Comments: 0 commit comments