Skip to content

Commit 31209f2

Browse files
authored
Merge branch 'dev_1.13.1' into dev_1.13.1_fix_check_and_transform_label_format
2 parents 326ae44 + e471a27 commit 31209f2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

art/estimators/classification/pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def reduce_labels(self, y: Union[np.ndarray, "torch.Tensor"]) -> Union[np.ndarra
272272
# Check if the loss function supports probability labels and probability labels are provided
273273
if self._probability_labels and len(y.shape) == 2:
274274
if isinstance(y, torch.Tensor):
275-
is_one_hot = torch.equal(y.floor(), y.ceil())
275+
is_one_hot = torch.equal(y.float().floor(), y.float().ceil())
276276
else:
277277
is_one_hot = np.array_equal(np.floor(y), np.ceil(y))
278278
if not is_one_hot: # probability labels

0 commit comments

Comments
 (0)