10
10
import torch
11
11
from botorch .acquisition .objective import ScalarizedPosteriorTransform
12
12
from botorch .exceptions .errors import BotorchTensorDimensionError
13
- from botorch .exceptions .warnings import OptimizationWarning
13
+ from botorch .exceptions .warnings import InputDataWarning , OptimizationWarning
14
14
from botorch .fit import fit_gpytorch_mll
15
15
from botorch .models .latent_kronecker_gp import LatentKroneckerGP
16
16
from botorch .models .transforms import Normalize , Standardize
17
+ from botorch .utils .datasets import SupervisedDataset
17
18
from botorch .utils .testing import BotorchTestCase , get_random_data
18
19
from botorch .utils .types import DEFAULT
19
20
from gpytorch .kernels import MaternKernel , RBFKernel , ScaleKernel
@@ -38,7 +39,7 @@ def _get_data_with_missing_entries(
38
39
mask [torch .randperm (n_train * t )[: n_train * t // 2 ]] = False
39
40
train_Y [..., ~ mask .reshape (n_train , t )] = torch .nan
40
41
41
- return train_X , train_T , train_Y
42
+ return train_X , train_T , train_Y , mask
42
43
43
44
44
45
class TestLatentKroneckerGP (BotorchTestCase ):
@@ -71,7 +72,7 @@ def test_default_init(self):
71
72
intf = None
72
73
octf = None
73
74
74
- train_X , train_T , train_Y = _get_data_with_missing_entries (
75
+ train_X , train_T , train_Y , mask = _get_data_with_missing_entries (
75
76
n_train = n_train , d = d , t = t , batch_shape = batch_shape , tkwargs = tkwargs
76
77
)
77
78
@@ -85,8 +86,7 @@ def test_default_init(self):
85
86
model .to (** tkwargs )
86
87
87
88
# test init
88
- mask_valid = torch .isfinite (train_Y .reshape (- 1 , n_train , t )[0 ]).flatten ()
89
- train_Y_flat = train_Y .reshape (* batch_shape , - 1 )[..., mask_valid ]
89
+ train_Y_flat = train_Y .reshape (* batch_shape , - 1 )[..., mask ]
90
90
if use_transforms :
91
91
self .assertIsInstance (model .input_transform , Normalize )
92
92
self .assertIsInstance (model .outcome_transform , Standardize )
@@ -124,7 +124,7 @@ def test_custom_init(self):
124
124
):
125
125
tkwargs = {"device" : self .device , "dtype" : dtype }
126
126
127
- train_X , train_T , train_Y = _get_data_with_missing_entries (
127
+ train_X , train_T , train_Y , _ = _get_data_with_missing_entries (
128
128
n_train = n_train , d = d , t = t , batch_shape = batch_shape , tkwargs = tkwargs
129
129
)
130
130
@@ -230,7 +230,7 @@ def test_gp_train(self):
230
230
intf = None
231
231
octf = None
232
232
233
- train_X , train_T , train_Y = _get_data_with_missing_entries (
233
+ train_X , train_T , train_Y , _ = _get_data_with_missing_entries (
234
234
n_train = n_train , d = d , t = t , batch_shape = batch_shape , tkwargs = tkwargs
235
235
)
236
236
@@ -271,7 +271,7 @@ def _test_gp_eval_shapes(
271
271
intf = None
272
272
octf = None
273
273
274
- train_X , train_T , train_Y = _get_data_with_missing_entries (
274
+ train_X , train_T , train_Y , _ = _get_data_with_missing_entries (
275
275
n_train = n_train , d = d , t = t , batch_shape = batch_shape , tkwargs = tkwargs
276
276
)
277
277
@@ -441,7 +441,7 @@ def test_gp_eval_values(self):
441
441
intf = None
442
442
octf = None
443
443
444
- train_X , train_T , train_Y = _get_data_with_missing_entries (
444
+ train_X , train_T , train_Y , _ = _get_data_with_missing_entries (
445
445
n_train = n_train , d = d , t = t , batch_shape = batch_shape , tkwargs = tkwargs
446
446
)
447
447
@@ -507,7 +507,7 @@ def test_iterative_methods(self):
507
507
batch_shape = torch .Size ([])
508
508
tkwargs = {"device" : self .device , "dtype" : torch .double }
509
509
510
- train_X , train_T , train_Y = _get_data_with_missing_entries (
510
+ train_X , train_T , train_Y , _ = _get_data_with_missing_entries (
511
511
n_train = 10 , d = 1 , t = 1 , batch_shape = batch_shape , tkwargs = tkwargs
512
512
)
513
513
@@ -525,7 +525,7 @@ def test_not_implemented(self):
525
525
batch_shape = torch .Size ([])
526
526
tkwargs = {"device" : self .device , "dtype" : torch .double }
527
527
528
- train_X , train_T , train_Y = _get_data_with_missing_entries (
528
+ train_X , train_T , train_Y , _ = _get_data_with_missing_entries (
529
529
n_train = 10 , d = 1 , t = 1 , batch_shape = batch_shape , tkwargs = tkwargs
530
530
)
531
531
@@ -558,3 +558,63 @@ def test_not_implemented(self):
558
558
err_msg = f"Only GaussianLikelihood currently supported for { cls_name } "
559
559
with self .assertRaisesRegex (NotImplementedError , err_msg ):
560
560
model .posterior (train_X )
561
+
562
+ def test_construct_inputs (self ) -> None :
563
+ # This test relies on the fact that the random (missing) data generation
564
+ # does not remove all occurrences of a particular X or T value. Therefore,
565
+ # we fix the random seed and set n_train and t to slightly larger values.
566
+
567
+ torch .manual_seed (12345 )
568
+ for batch_shape , n_train , d , t , dtype in itertools .product (
569
+ ( # batch_shape
570
+ torch .Size ([]),
571
+ torch .Size ([1 ]),
572
+ torch .Size ([2 ]),
573
+ torch .Size ([2 , 3 ]),
574
+ ),
575
+ (15 ,), # n_train
576
+ (1 , 2 ), # d
577
+ (10 ,), # t
578
+ (torch .float , torch .double ), # dtype
579
+ ):
580
+ tkwargs = {"device" : self .device , "dtype" : dtype }
581
+
582
+ train_X , train_T , train_Y , mask = _get_data_with_missing_entries (
583
+ n_train = n_train , d = d , t = t , batch_shape = batch_shape , tkwargs = tkwargs
584
+ )
585
+
586
+ train_X_supervised = torch .cat (
587
+ [
588
+ train_X .repeat_interleave (t , dim = - 2 ),
589
+ train_T .repeat (* ([1 ] * len (batch_shape )), n_train , 1 ),
590
+ ],
591
+ dim = - 1 ,
592
+ )
593
+ train_Y_supervised = train_Y .reshape (* batch_shape , n_train * t , 1 )
594
+
595
+ # randomly permute data to test robustness to non-contiguous data
596
+ idx = torch .randperm (n_train * t , device = self .device )
597
+ train_X_supervised = train_X_supervised [..., idx , :][..., mask [idx ], :]
598
+ train_Y_supervised = train_Y_supervised [..., idx , :][..., mask [idx ], :]
599
+
600
+ dataset = SupervisedDataset (
601
+ X = train_X_supervised ,
602
+ Y = train_Y_supervised ,
603
+ Yvar = train_Y_supervised , # just to check warning
604
+ feature_names = [f"x_{ i } " for i in range (d )] + ["step" ],
605
+ outcome_names = ["y" ],
606
+ )
607
+
608
+ w_msg = "Ignoring Yvar values in provided training data, because "
609
+ w_msg += "they are currently not supported by LatentKroneckerGP."
610
+ with self .assertWarnsRegex (InputDataWarning , w_msg ):
611
+ model_inputs = LatentKroneckerGP .construct_inputs (dataset )
612
+
613
+ # this test generates train_X and train_T in sorted order
614
+ # the data is randomly permuted before passing to construct_inputs
615
+ # construct_inputs sorts the data, so we expect the results to be equal
616
+ self .assertAllClose (model_inputs ["train_X" ], train_X , atol = 0.0 )
617
+ self .assertAllClose (model_inputs ["train_T" ], train_T , atol = 0.0 )
618
+ self .assertAllClose (
619
+ model_inputs ["train_Y" ], train_Y , atol = 0.0 , equal_nan = True
620
+ )
0 commit comments