1010
1111import torch
1212from botorch import settings
13- from botorch .acquisition .cached_cholesky import (
14- CachedCholeskyMCAcquisitionFunction ,
15- CachedCholeskyMCSamplerMixin ,
16- )
13+ from botorch .acquisition .cached_cholesky import CachedCholeskyMCSamplerMixin
1714from botorch .acquisition .monte_carlo import MCAcquisitionFunction
1815from botorch .acquisition .objective import GenericMCObjective , MCAcquisitionObjective
1916from botorch .exceptions .warnings import BotorchWarning
@@ -49,81 +46,58 @@ def forward(self, X):
4946 return X
5047
5148
52- class DeprecatedCachedCholeskyAcqf (
53- MCAcquisitionFunction , CachedCholeskyMCAcquisitionFunction
54- ):
55- def __init__ (
56- self ,
57- model : Model ,
58- objective : Optional [MCAcquisitionObjective ] = None ,
59- sampler : Optional [MCSampler ] = None ,
60- cache_root : bool = False ,
61- ):
62- """A deprecated dummy cached cholesky acquisition function."""
63- MCAcquisitionFunction .__init__ (
64- self , model = model , objective = objective , sampler = sampler
65- )
66- self ._setup (model = model , cache_root = cache_root )
67-
68- def forward (self , X ):
69- return X
70-
71-
7249class TestCachedCholeskyMCSamplerMixin (BotorchTestCase ):
7350 def test_init (self ):
7451 mean = torch .zeros (1 , 1 )
7552 variance = torch .ones (1 , 1 )
7653 mm = MockModel (MockPosterior (mean = mean , variance = variance ))
7754 # basic test w/ invalid model.
7855 sampler = IIDNormalSampler (sample_shape = torch .Size ([1 ]))
79- with self .assertWarns (DeprecationWarning ):
80- DeprecatedCachedCholeskyAcqf (model = mm , sampler = sampler )
81-
82- constructors = [DeprecatedCachedCholeskyAcqf , DummyCachedCholeskyAcqf ]
83- for constructor in constructors :
84- acqf = constructor (model = mm , sampler = sampler )
85- self .assertFalse (acqf ._cache_root ) # no cache by default
86- with self .assertWarnsRegex (RuntimeWarning , "cache_root" ):
87- acqf = constructor (model = mm , sampler = sampler , cache_root = True )
88- self .assertFalse (acqf ._cache_root ) # gets turned to False
89- # Unsupported outcome transform.
90- stgp = SingleTaskGP (
91- torch .zeros (1 , 1 ), torch .zeros (1 , 1 ), outcome_transform = Log ()
56+
57+ acqf = DummyCachedCholeskyAcqf (model = mm , sampler = sampler )
58+ self .assertFalse (acqf ._cache_root ) # no cache by default
59+ with self .assertWarnsRegex (RuntimeWarning , "cache_root" ):
60+ acqf = DummyCachedCholeskyAcqf (model = mm , sampler = sampler , cache_root = True )
61+ self .assertFalse (acqf ._cache_root ) # gets turned to False
62+ # Unsupported outcome transform.
63+ stgp = SingleTaskGP (
64+ torch .zeros (1 , 1 ), torch .zeros (1 , 1 ), outcome_transform = Log ()
65+ )
66+ with self .assertWarnsRegex (RuntimeWarning , "cache_root" ):
67+ acqf = DummyCachedCholeskyAcqf (model = stgp , cache_root = True )
68+ self .assertFalse (acqf ._cache_root )
69+ # ModelList is not supported.
70+ model_list = ModelList (SingleTaskGP (torch .zeros (1 , 1 ), torch .zeros (1 , 1 )))
71+ with self .assertWarnsRegex (RuntimeWarning , "cache_root" ):
72+ acqf = DummyCachedCholeskyAcqf (model = model_list , cache_root = True )
73+ self .assertFalse (acqf ._cache_root )
74+
75+ # basic test w/ supported model.
76+ stgp = SingleTaskGP (torch .zeros (1 , 1 ), torch .zeros (1 , 1 ))
77+ acqf = DummyCachedCholeskyAcqf (model = stgp , sampler = sampler , cache_root = True )
78+ self .assertTrue (acqf ._cache_root )
79+ self .assertEqual (acqf .sampler , sampler )
80+
81+ # test the base_samples are set to None
82+ self .assertIsNone (acqf .sampler .base_samples )
83+ # test model that uses matheron's rule and sampler.batch_range != (0, -1)
84+ hogp = HigherOrderGP (torch .zeros (1 , 1 ), torch .zeros (1 , 1 , 1 )).eval ()
85+ with self .assertWarnsRegex (RuntimeWarning , "cache_root" ):
86+ acqf = DummyCachedCholeskyAcqf (model = hogp , sampler = sampler , cache_root = True )
87+ self .assertFalse (acqf ._cache_root )
88+
89+ # test deterministic model
90+ model = GenericDeterministicModel (f = lambda X : X )
91+ with self .assertWarnsRegex (RuntimeWarning , "cache_root" ):
92+ acqf = DummyCachedCholeskyAcqf (
93+ model = model , sampler = sampler , cache_root = True
9294 )
93- with self .assertWarnsRegex (RuntimeWarning , "cache_root" ):
94- acqf = constructor (model = stgp , cache_root = True )
95- self .assertFalse (acqf ._cache_root )
96- # ModelList is not supported.
97- model_list = ModelList (SingleTaskGP (torch .zeros (1 , 1 ), torch .zeros (1 , 1 )))
98- with self .assertWarnsRegex (RuntimeWarning , "cache_root" ):
99- acqf = constructor (model = model_list , cache_root = True )
100- self .assertFalse (acqf ._cache_root )
101-
102- # basic test w/ supported model.
103- stgp = SingleTaskGP (torch .zeros (1 , 1 ), torch .zeros (1 , 1 ))
104- acqf = constructor (model = stgp , sampler = sampler , cache_root = True )
105- self .assertTrue (acqf ._cache_root )
106- self .assertEqual (acqf .sampler , sampler )
107-
108- # test the base_samples are set to None
109- self .assertIsNone (acqf .sampler .base_samples )
110- # test model that uses matheron's rule and sampler.batch_range != (0, -1)
111- hogp = HigherOrderGP (torch .zeros (1 , 1 ), torch .zeros (1 , 1 , 1 )).eval ()
112- with self .assertWarnsRegex (RuntimeWarning , "cache_root" ):
113- acqf = constructor (model = hogp , sampler = sampler , cache_root = True )
114- self .assertFalse (acqf ._cache_root )
115-
116- # test deterministic model
117- model = GenericDeterministicModel (f = lambda X : X )
118- with self .assertWarnsRegex (RuntimeWarning , "cache_root" ):
119- acqf = constructor (model = model , sampler = sampler , cache_root = True )
120- self .assertFalse (acqf ._cache_root )
95+ self .assertFalse (acqf ._cache_root )
12196
12297 def test_cache_root_decomposition (self ):
12398 tkwargs = {"device" : self .device }
124- constructors = [DeprecatedCachedCholeskyAcqf , DummyCachedCholeskyAcqf ]
125- for constructor in constructors :
126- for dtype in (torch .float , torch .double ):
99+ for dtype in (torch .float , torch .double ):
100+ with self .subTest (dtype = dtype ):
127101 tkwargs ["dtype" ] = dtype
128102 # test mt-mvn
129103 train_x = torch .rand (2 , 1 , ** tkwargs )
@@ -133,7 +107,7 @@ def test_cache_root_decomposition(self):
133107 sampler = IIDNormalSampler (sample_shape = torch .Size ([1 ]))
134108 with torch .no_grad ():
135109 posterior = model .posterior (test_x )
136- acqf = constructor (
110+ acqf = DummyCachedCholeskyAcqf (
137111 model = model ,
138112 sampler = sampler ,
139113 objective = GenericMCObjective (lambda Y , _ : Y [..., 0 ]),
@@ -169,9 +143,8 @@ def test_cache_root_decomposition(self):
169143
170144 def test_get_f_X_samples (self ):
171145 tkwargs = {"device" : self .device }
172- constructors = [DeprecatedCachedCholeskyAcqf , DummyCachedCholeskyAcqf ]
173- for constructor in constructors :
174- for dtype in (torch .float , torch .double ):
146+ for dtype in (torch .float , torch .double ):
147+ with self .subTest (dtype = dtype ):
175148 tkwargs ["dtype" ] = dtype
176149 mean = torch .zeros (5 , 1 , ** tkwargs )
177150 variance = torch .ones (5 , 1 , ** tkwargs )
@@ -186,7 +159,9 @@ def test_get_f_X_samples(self):
186159 sampler = IIDNormalSampler (sample_shape = torch .Size ([1 ]))
187160
188161 with self .assertWarnsRegex (RuntimeWarning , "cache_root" ):
189- acqf = constructor (model = mm , sampler = sampler , cache_root = True )
162+ acqf = DummyCachedCholeskyAcqf (
163+ model = mm , sampler = sampler , cache_root = True
164+ )
190165 self .assertFalse (acqf ._cache_root )
191166 acqf ._cache_root = True
192167 q = 3
@@ -233,7 +208,9 @@ def test_get_f_X_samples(self):
233208 self .assertTrue (samples .shape , torch .Size ([1 , q , 1 ]))
234209 # test HOGP
235210 hogp = HigherOrderGP (torch .zeros (2 , 1 ), torch .zeros (2 , 1 , 1 )).eval ()
236- acqf = constructor (model = hogp , sampler = sampler , cache_root = True )
211+ acqf = DummyCachedCholeskyAcqf (
212+ model = hogp , sampler = sampler , cache_root = True
213+ )
237214 mock_samples = torch .rand (5 , 1 , 1 , ** tkwargs )
238215 posterior = MockPosterior (
239216 mean = mean , variance = variance , samples = mock_samples
0 commit comments