Skip to content

Commit e471a27

Browse files
authored
Merge pull request #2022 from Trusted-AI/development_issue_2014
Add conversion to float for comparison on GPU/CUDA
2 parents 3c80f84 + aa15f2e commit e471a27

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

art/estimators/classification/pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def reduce_labels(self, y: Union[np.ndarray, "torch.Tensor"]) -> Union[np.ndarra
272272
# Check if the loss function supports probability labels and probability labels are provided
273273
if self._probability_labels and len(y.shape) == 2:
274274
if isinstance(y, torch.Tensor):
275-
is_one_hot = torch.equal(y.floor(), y.ceil())
275+
is_one_hot = torch.equal(y.float().floor(), y.float().ceil())
276276
else:
277277
is_one_hot = np.array_equal(np.floor(y), np.ceil(y))
278278
if not is_one_hot: # probability labels

0 commit comments

Comments
 (0)