@@ -174,12 +174,12 @@ def _get_regularisation_loss_gradients(self, perturbation: np.ndarray) -> np.nda
174174 :param perturbation: The perturbation to be regularised.
175175 :return: The loss gradients of the perturbation.
176176 """
177+ if not self .estimator .channels_first :
178+ perturbation = perturbation .transpose ((0 , 3 , 1 , 2 ))
179+
177180 if self .framework == "tensorflow" :
178181 import tensorflow as tf
179182
180- if not self .estimator .channels_first :
181- perturbation = perturbation .transpose ((0 , 3 , 1 , 2 ))
182-
183183 if tf .executing_eagerly ():
184184 with tf .GradientTape () as tape :
185185
@@ -205,9 +205,6 @@ def _get_regularisation_loss_gradients(self, perturbation: np.ndarray) -> np.nda
205205 loss = self .lambda_tv * loss_tv + self .lambda_s * loss_s + self .lambda_c * loss_c
206206 gradients = tape .gradient (loss , perturbation_t ).numpy ()
207207
208- if not self .estimator .channels_first :
209- gradients = gradients .transpose (0 , 2 , 3 , 1 )
210-
211208 else :
212209 raise ValueError ("Expecting eager execution." )
213210
@@ -238,6 +235,9 @@ def _get_regularisation_loss_gradients(self, perturbation: np.ndarray) -> np.nda
238235 else :
239236 raise NotImplementedError
240237
238+ if not self .estimator .channels_first :
239+ gradients = gradients .transpose (0 , 2 , 3 , 1 )
240+
241241 return gradients
242242
243243 def _check_params (self ) -> None :
0 commit comments