@@ -43,6 +43,35 @@ def test_single_task_batch_cv(self, cuda=False):
4343 noiseless_cv_folds = gen_loo_cv_folds (
4444 train_X = train_X , train_Y = train_Y
4545 )
46+ # check shapes
47+ expected_shape_train_X = batch_shape + torch .Size (
48+ [n , n - 1 , train_X .shape [- 1 ]]
49+ )
50+ expected_shape_test_X = batch_shape + torch .Size (
51+ [n , 1 , train_X .shape [- 1 ]]
52+ )
53+ self .assertEqual (
54+ noiseless_cv_folds .train_X .shape , expected_shape_train_X
55+ )
56+ self .assertEqual (
57+ noiseless_cv_folds .test_X .shape , expected_shape_test_X
58+ )
59+
60+ expected_shape_train_Y = batch_shape + torch .Size (
61+ [n , n - 1 , num_outputs ]
62+ )
63+ expected_shape_test_Y = batch_shape + torch .Size (
64+ [n , 1 , num_outputs ]
65+ )
66+
67+ self .assertEqual (
68+ noiseless_cv_folds .train_Y .shape , expected_shape_train_Y
69+ )
70+ self .assertEqual (
71+ noiseless_cv_folds .test_Y .shape , expected_shape_test_Y
72+ )
73+ self .assertIsNone (noiseless_cv_folds .train_Yvar )
74+ self .assertIsNone (noiseless_cv_folds .test_Yvar )
4675 # Test SingleTaskGP
4776 with warnings .catch_warnings ():
4877 warnings .filterwarnings ("ignore" , category = OptimizationWarning )
@@ -60,6 +89,21 @@ def test_single_task_batch_cv(self, cuda=False):
6089 noisy_cv_folds = gen_loo_cv_folds (
6190 train_X = train_X , train_Y = train_Y , train_Yvar = train_Yvar
6291 )
92+ # check shapes
93+ self .assertEqual (
94+ noisy_cv_folds .train_X .shape , expected_shape_train_X
95+ )
96+ self .assertEqual (noisy_cv_folds .test_X .shape , expected_shape_test_X )
97+ self .assertEqual (
98+ noisy_cv_folds .train_Y .shape , expected_shape_train_Y
99+ )
100+ self .assertEqual (noisy_cv_folds .test_Y .shape , expected_shape_test_Y )
101+ self .assertEqual (
102+ noisy_cv_folds .train_Yvar .shape , expected_shape_train_Y
103+ )
104+ self .assertEqual (
105+ noisy_cv_folds .test_Yvar .shape , expected_shape_test_Y
106+ )
63107 with warnings .catch_warnings ():
64108 warnings .filterwarnings ("ignore" , category = OptimizationWarning )
65109 cv_results = batch_cross_validation (
0 commit comments