Skip to content

Commit 41de1f4

Browse files
wjmaddoxBalandat
andauthored
enforce num_classes to be an int (#1728)
Co-authored-by: Max Balandat <[email protected]>
1 parent e4579ed commit 41de1f4

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

gpytorch/likelihoods/gaussian_likelihood.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ class DirichletClassificationLikelihood(FixedNoiseGaussianLikelihood):
309309
"""
310310

311311
def _prepare_targets(self, targets, alpha_epsilon=0.01, dtype=torch.float):
312-
num_classes = targets.max() + 1
312+
num_classes = int(targets.max() + 1)
313313
# set alpha = \alpha_\epsilon
314314
alpha = alpha_epsilon * torch.ones(targets.shape[-1], num_classes, device=targets.device, dtype=dtype)
315315

0 commit comments

Comments
 (0)