 from art.utils import check_and_transform_label_format

 if TYPE_CHECKING:
-    # pylint: disable=C0412
+    # pylint: disable=C0412, C0302
     import torch

     from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE
@@ -266,21 +266,26 @@ def reduce_labels(self, y: Union[np.ndarray, "torch.Tensor"]) -> Union[np.ndarray, "torch.Tensor"]:
         """
         Reduce labels from one-hot encoded to index labels.
         """
+        # pylint: disable=R0911
         import torch  # lgtm [py/repeated-import]

         # Check if the loss function requires as input index labels instead of one-hot-encoded labels
-        if self._reduce_labels and self._int_labels:
-            if isinstance(y, torch.Tensor):
-                return torch.argmax(y, dim=1)
-            return np.argmax(y, axis=1)
-
-        if self._reduce_labels:  # float labels
+        # Checking for exactly 2 classes to support binary classification
+        if self.nb_classes > 2:
+            if self._reduce_labels and self._int_labels:
+                if isinstance(y, torch.Tensor):
+                    return torch.argmax(y, dim=1)
+                return np.argmax(y, axis=1)
+            if self._reduce_labels:  # float labels
+                if isinstance(y, torch.Tensor):
+                    return torch.argmax(y, dim=1).type("torch.FloatTensor")
+                y_index = np.argmax(y, axis=1).astype(np.float32)
+                y_index = np.expand_dims(y_index, axis=1)
+                return y_index
+        else:
             if isinstance(y, torch.Tensor):
-                return torch.argmax(y, dim=1).type("torch.FloatTensor")
-            y_index = np.argmax(y, axis=1).astype(np.float32)
-            y_index = np.expand_dims(y_index, axis=1)
-            return y_index
-
+                return y.float()
+            return y.astype(np.float32)
         return y

     def predict(  # pylint: disable=W0221
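
For context, a minimal NumPy sketch of the two label paths introduced above; `one_hot` and `binary` are illustrative inputs, not part of the ART API:

```python
import numpy as np

one_hot = np.eye(3, dtype=np.float32)[[0, 2, 1]]            # 3-class one-hot labels, shape (3, 3)
binary = np.array([[0.0], [1.0], [1.0]], dtype=np.float32)  # binary labels, shape (3, 1)

# Multi-class path (nb_classes > 2): one-hot labels are reduced to index labels
indices = np.argmax(one_hot, axis=1)                        # -> array([0, 2, 1])

# Binary path (nb_classes <= 2): labels are only cast to float32, shape is kept
binary_out = binary.astype(np.float32)                      # still shape (3, 1)
```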
@@ -302,8 +307,9 @@ def predict(  # pylint: disable=W0221
         # Apply preprocessing
         x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)

+        results_list = []
+
         # Run prediction with batch processing
-        results = np.zeros((x_preprocessed.shape[0], self.nb_classes), dtype=np.float32)
         num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
         for m in range(num_batch):
             # Batch indexes
@@ -315,8 +321,13 @@ def predict(  # pylint: disable=W0221
             with torch.no_grad():
                 model_outputs = self._model(torch.from_numpy(x_preprocessed[begin:end]).to(self._device))
             output = model_outputs[-1]
-            results[begin:end] = output.detach().cpu().numpy()
+            output = output.detach().cpu().numpy().astype(np.float32)
+            if len(output.shape) == 1:
+                output = np.expand_dims(output, axis=1).astype(np.float32)
+
+            results_list.append(output)

+        results = np.vstack(results_list)
         # Apply postprocessing
         predictions = self._apply_postprocessing(preds=results, fit=False)

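For context, a rough sketch of why the preallocated `(n_samples, nb_classes)` buffer is replaced by stacking per-batch outputs: a model whose final layer returns a single logit yields 1-D batch outputs, which the old fixed-width buffer could not hold. `batch_outputs` below is a made-up stand-in for the per-batch model outputs:

```python
import numpy as np

# Hypothetical per-batch outputs of a model with a single output neuron
batch_outputs = [
    np.array([0.2, 0.9], dtype=np.float32),   # 1-D output for a batch of 2
    np.array([0.1], dtype=np.float32),        # 1-D output for a final batch of 1
]

results_list = []
for output in batch_outputs:
    if len(output.shape) == 1:                # single-logit models return 1-D arrays
        output = np.expand_dims(output, axis=1)
    results_list.append(output)

results = np.vstack(results_list)             # shape (3, 1), no fixed (n, nb_classes) buffer needed
```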
@@ -577,7 +588,12 @@ def hook(grad):

         self._model.zero_grad()
         if label is None:
-            for i in range(self.nb_classes):
+            if len(preds.shape) == 1 or preds.shape[1] == 1:
+                num_outputs = 1
+            else:
+                num_outputs = self.nb_classes
+
+            for i in range(num_outputs):
                 torch.autograd.backward(
                     preds[:, i],
                     torch.tensor([1.0] * len(preds[:, 0])).to(self._device),
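
For context, a small sketch of the output-counting logic added above; `count_gradient_outputs` is a hypothetical helper used only for illustration, not a function in ART:

```python
import torch

def count_gradient_outputs(preds: torch.Tensor, nb_classes: int) -> int:
    # A 1-D or single-column output means one logit per sample, so a single
    # backward pass suffices; otherwise loop over every class as before.
    if len(preds.shape) == 1 or preds.shape[1] == 1:
        return 1
    return nb_classes

assert count_gradient_outputs(torch.zeros(8, 1), nb_classes=2) == 1
assert count_gradient_outputs(torch.zeros(8, 10), nb_classes=10) == 10
```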