Skip to content

Commit beca4c8

Browse files
sdaultonfacebook-github-bot
authored andcommitted
fix uint8 bug in CV (#224)
Summary: Pull Request resolved: #224 see title Reviewed By: Balandat Differential Revision: D16721744 fbshipit-source-id: dab8d9c9d1b85a87367e6ddd2fe4c68f43c06583
1 parent 3bde3b6 commit beca4c8

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

botorch/cross_validation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def gen_loo_cv_folds(
6868
>>> cv_folds = gen_loo_cv_folds(train_X, train_Y)
6969
"""
7070
masks = torch.eye(train_X.shape[-2], dtype=torch.uint8, device=train_X.device)
71+
masks = masks.to(dtype=torch.bool)
7172
if train_Y.dim() < train_X.dim():
7273
# add output dimension
7374
train_Y = train_Y.unsqueeze(-1)

test/test_cross_validation.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,35 @@ def test_single_task_batch_cv(self, cuda=False):
4343
noiseless_cv_folds = gen_loo_cv_folds(
4444
train_X=train_X, train_Y=train_Y
4545
)
46+
# check shapes
47+
expected_shape_train_X = batch_shape + torch.Size(
48+
[n, n - 1, train_X.shape[-1]]
49+
)
50+
expected_shape_test_X = batch_shape + torch.Size(
51+
[n, 1, train_X.shape[-1]]
52+
)
53+
self.assertEqual(
54+
noiseless_cv_folds.train_X.shape, expected_shape_train_X
55+
)
56+
self.assertEqual(
57+
noiseless_cv_folds.test_X.shape, expected_shape_test_X
58+
)
59+
60+
expected_shape_train_Y = batch_shape + torch.Size(
61+
[n, n - 1, num_outputs]
62+
)
63+
expected_shape_test_Y = batch_shape + torch.Size(
64+
[n, 1, num_outputs]
65+
)
66+
67+
self.assertEqual(
68+
noiseless_cv_folds.train_Y.shape, expected_shape_train_Y
69+
)
70+
self.assertEqual(
71+
noiseless_cv_folds.test_Y.shape, expected_shape_test_Y
72+
)
73+
self.assertIsNone(noiseless_cv_folds.train_Yvar)
74+
self.assertIsNone(noiseless_cv_folds.test_Yvar)
4675
# Test SingleTaskGP
4776
with warnings.catch_warnings():
4877
warnings.filterwarnings("ignore", category=OptimizationWarning)
@@ -60,6 +89,21 @@ def test_single_task_batch_cv(self, cuda=False):
6089
noisy_cv_folds = gen_loo_cv_folds(
6190
train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar
6291
)
92+
# check shapes
93+
self.assertEqual(
94+
noisy_cv_folds.train_X.shape, expected_shape_train_X
95+
)
96+
self.assertEqual(noisy_cv_folds.test_X.shape, expected_shape_test_X)
97+
self.assertEqual(
98+
noisy_cv_folds.train_Y.shape, expected_shape_train_Y
99+
)
100+
self.assertEqual(noisy_cv_folds.test_Y.shape, expected_shape_test_Y)
101+
self.assertEqual(
102+
noisy_cv_folds.train_Yvar.shape, expected_shape_train_Y
103+
)
104+
self.assertEqual(
105+
noisy_cv_folds.test_Yvar.shape, expected_shape_test_Y
106+
)
63107
with warnings.catch_warnings():
64108
warnings.filterwarnings("ignore", category=OptimizationWarning)
65109
cv_results = batch_cross_validation(

0 commit comments

Comments
 (0)