@@ -231,28 +231,33 @@ def __init__(self, reduction="mean"):
231231 self .ce_loss = torch .nn .CrossEntropyLoss (reduction = "none" )
232232 self .reduction = reduction
233233
234- def __call__ (self , y_true : torch .Tensor , y_pred : torch .Tensor ) -> torch .Tensor :
234+ def __call__ (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
235235 if self .reduction == "mean" :
236- return self .ce_loss (y_true , y_pred ).mean ()
236+ return self .ce_loss (y_pred , y_true ).mean ()
237237 if self .reduction == "sum" :
238- return self .ce_loss (y_true , y_pred ).sum ()
238+ return self .ce_loss (y_pred , y_true ).sum ()
239239 if self .reduction == "none" :
240- return self .ce_loss (y_true , y_pred )
240+ return self .ce_loss (y_pred , y_true )
241241 raise NotImplementedError ()
242242
243243 def forward (
244244 self , input : torch .Tensor , target : torch .Tensor # pylint: disable=W0622
245245 ) -> torch .Tensor :
246246 """
247247 Forward method.
248+
248249 :param input: Predicted labels of shape (nb_samples, nb_classes).
249250 :param target: Target labels of shape (nb_samples, nb_classes).
250251 :return: Difference Logits Ratio Loss.
251252 """
252- return self .__call__ (y_true = target , y_pred = input )
253+ return self .__call__ (y_pred = input , y_true = target )
253254
254255 _loss_object_pt : torch .nn .modules .loss ._Loss = CrossEntropyLossTorch (reduction = "mean" )
255256
257+ reduce_labels = True
258+ int_labels = True
259+ probability_labels = True
260+
256261 elif loss_type == "difference_logits_ratio" :
257262 if is_probability (
258263 estimator .predict (x = np .ones (shape = (1 , * estimator .input_shape ), dtype = ART_NUMPY_DTYPE ))
@@ -316,13 +321,18 @@ def forward(
316321 ) -> torch .Tensor :
317322 """
318323 Forward method.
324+
319325 :param input: Predicted labels of shape (nb_samples, nb_classes).
320326 :param target: Target labels of shape (nb_samples, nb_classes).
321327 :return: Difference Logits Ratio Loss.
322328 """
323329 return self .__call__ (y_true = target , y_pred = input )
324330
325331 _loss_object_pt = DifferenceLogitsRatioPyTorch ()
332+
333+ reduce_labels = False
334+ int_labels = False
335+ probability_labels = False
326336 else :
327337 raise NotImplementedError ()
328338
@@ -340,6 +350,10 @@ def forward(
340350 device_type = str (estimator ._device ),
341351 )
342352
353+ estimator_apgd ._reduce_labels = reduce_labels
354+ estimator_apgd ._int_labels = int_labels
355+ estimator_apgd ._probability_labels = probability_labels
356+
343357 else : # pragma: no cover
344358 raise ValueError (f"The loss type { loss_type } is not supported for the provided estimator." )
345359
0 commit comments