Skip to content

Commit bf529df

Browse files
Jihao Andreas Linfacebook-github-bot
authored andcommitted
Removed deprecated CachedCholeskyMCAcquisitionFunction (#2399)
Summary: Pull Request resolved: #2399 Removed deprecated 'CachedCholeskyMCAcquisitionFunction' (previously replaced by 'CachedCholeskyMCSamplerMixin') from 'botorch/acquisition/cached_cholesky.py', and updated test cases in 'test/acquisition/test_cached_cholesky.py' accordingly. Reviewed By: SebastianAment Differential Revision: D59169287 fbshipit-source-id: 0a7d46f6e4150cff471b932526b7911f189a966f
1 parent 8fe69b8 commit bf529df

File tree

2 files changed

+51
-99
lines changed

2 files changed

+51
-99
lines changed

botorch/acquisition/cached_cholesky.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -186,28 +186,3 @@ def _set_sampler(
186186
posterior=posterior, base_sampler=self.base_sampler
187187
)
188188
self.q_in = q_in
189-
190-
191-
# TODO: remove
192-
class CachedCholeskyMCAcquisitionFunction(CachedCholeskyMCSamplerMixin):
193-
r"""DEPRECATED - USE CachedCholeskyMCSamplerMixin instead."""
194-
195-
def _setup(
196-
self,
197-
model: Model,
198-
cache_root: bool = False,
199-
) -> None:
200-
r"""Set class attributes and perform compatibility checks.
201-
202-
Args:
203-
model: A model.
204-
cache_root: A boolean indicating whether to cache the Cholesky.
205-
This might be overridden in the model is not compatible.
206-
"""
207-
warnings.warn(
208-
"`CachedCholeskyMCAcquisitionFunction` is deprecated. Please switch to "
209-
"`CachedCholeskyMCSamplerMixin` and replace any calls to _setup with the "
210-
"constructor of the Mixin class.",
211-
DeprecationWarning,
212-
)
213-
super().__init__(model=model, cache_root=cache_root, sampler=self.sampler)

test/acquisition/test_cached_cholesky.py

Lines changed: 51 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@
1010

1111
import torch
1212
from botorch import settings
13-
from botorch.acquisition.cached_cholesky import (
14-
CachedCholeskyMCAcquisitionFunction,
15-
CachedCholeskyMCSamplerMixin,
16-
)
13+
from botorch.acquisition.cached_cholesky import CachedCholeskyMCSamplerMixin
1714
from botorch.acquisition.monte_carlo import MCAcquisitionFunction
1815
from botorch.acquisition.objective import GenericMCObjective, MCAcquisitionObjective
1916
from botorch.exceptions.warnings import BotorchWarning
@@ -49,81 +46,58 @@ def forward(self, X):
4946
return X
5047

5148

52-
class DeprecatedCachedCholeskyAcqf(
53-
MCAcquisitionFunction, CachedCholeskyMCAcquisitionFunction
54-
):
55-
def __init__(
56-
self,
57-
model: Model,
58-
objective: Optional[MCAcquisitionObjective] = None,
59-
sampler: Optional[MCSampler] = None,
60-
cache_root: bool = False,
61-
):
62-
"""A deprecated dummy cached cholesky acquisition function."""
63-
MCAcquisitionFunction.__init__(
64-
self, model=model, objective=objective, sampler=sampler
65-
)
66-
self._setup(model=model, cache_root=cache_root)
67-
68-
def forward(self, X):
69-
return X
70-
71-
7249
class TestCachedCholeskyMCSamplerMixin(BotorchTestCase):
7350
def test_init(self):
7451
mean = torch.zeros(1, 1)
7552
variance = torch.ones(1, 1)
7653
mm = MockModel(MockPosterior(mean=mean, variance=variance))
7754
# basic test w/ invalid model.
7855
sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
79-
with self.assertWarns(DeprecationWarning):
80-
DeprecatedCachedCholeskyAcqf(model=mm, sampler=sampler)
81-
82-
constructors = [DeprecatedCachedCholeskyAcqf, DummyCachedCholeskyAcqf]
83-
for constructor in constructors:
84-
acqf = constructor(model=mm, sampler=sampler)
85-
self.assertFalse(acqf._cache_root) # no cache by default
86-
with self.assertWarnsRegex(RuntimeWarning, "cache_root"):
87-
acqf = constructor(model=mm, sampler=sampler, cache_root=True)
88-
self.assertFalse(acqf._cache_root) # gets turned to False
89-
# Unsupported outcome transform.
90-
stgp = SingleTaskGP(
91-
torch.zeros(1, 1), torch.zeros(1, 1), outcome_transform=Log()
56+
57+
acqf = DummyCachedCholeskyAcqf(model=mm, sampler=sampler)
58+
self.assertFalse(acqf._cache_root) # no cache by default
59+
with self.assertWarnsRegex(RuntimeWarning, "cache_root"):
60+
acqf = DummyCachedCholeskyAcqf(model=mm, sampler=sampler, cache_root=True)
61+
self.assertFalse(acqf._cache_root) # gets turned to False
62+
# Unsupported outcome transform.
63+
stgp = SingleTaskGP(
64+
torch.zeros(1, 1), torch.zeros(1, 1), outcome_transform=Log()
65+
)
66+
with self.assertWarnsRegex(RuntimeWarning, "cache_root"):
67+
acqf = DummyCachedCholeskyAcqf(model=stgp, cache_root=True)
68+
self.assertFalse(acqf._cache_root)
69+
# ModelList is not supported.
70+
model_list = ModelList(SingleTaskGP(torch.zeros(1, 1), torch.zeros(1, 1)))
71+
with self.assertWarnsRegex(RuntimeWarning, "cache_root"):
72+
acqf = DummyCachedCholeskyAcqf(model=model_list, cache_root=True)
73+
self.assertFalse(acqf._cache_root)
74+
75+
# basic test w/ supported model.
76+
stgp = SingleTaskGP(torch.zeros(1, 1), torch.zeros(1, 1))
77+
acqf = DummyCachedCholeskyAcqf(model=stgp, sampler=sampler, cache_root=True)
78+
self.assertTrue(acqf._cache_root)
79+
self.assertEqual(acqf.sampler, sampler)
80+
81+
# test the base_samples are set to None
82+
self.assertIsNone(acqf.sampler.base_samples)
83+
# test model that uses matheron's rule and sampler.batch_range != (0, -1)
84+
hogp = HigherOrderGP(torch.zeros(1, 1), torch.zeros(1, 1, 1)).eval()
85+
with self.assertWarnsRegex(RuntimeWarning, "cache_root"):
86+
acqf = DummyCachedCholeskyAcqf(model=hogp, sampler=sampler, cache_root=True)
87+
self.assertFalse(acqf._cache_root)
88+
89+
# test deterministic model
90+
model = GenericDeterministicModel(f=lambda X: X)
91+
with self.assertWarnsRegex(RuntimeWarning, "cache_root"):
92+
acqf = DummyCachedCholeskyAcqf(
93+
model=model, sampler=sampler, cache_root=True
9294
)
93-
with self.assertWarnsRegex(RuntimeWarning, "cache_root"):
94-
acqf = constructor(model=stgp, cache_root=True)
95-
self.assertFalse(acqf._cache_root)
96-
# ModelList is not supported.
97-
model_list = ModelList(SingleTaskGP(torch.zeros(1, 1), torch.zeros(1, 1)))
98-
with self.assertWarnsRegex(RuntimeWarning, "cache_root"):
99-
acqf = constructor(model=model_list, cache_root=True)
100-
self.assertFalse(acqf._cache_root)
101-
102-
# basic test w/ supported model.
103-
stgp = SingleTaskGP(torch.zeros(1, 1), torch.zeros(1, 1))
104-
acqf = constructor(model=stgp, sampler=sampler, cache_root=True)
105-
self.assertTrue(acqf._cache_root)
106-
self.assertEqual(acqf.sampler, sampler)
107-
108-
# test the base_samples are set to None
109-
self.assertIsNone(acqf.sampler.base_samples)
110-
# test model that uses matheron's rule and sampler.batch_range != (0, -1)
111-
hogp = HigherOrderGP(torch.zeros(1, 1), torch.zeros(1, 1, 1)).eval()
112-
with self.assertWarnsRegex(RuntimeWarning, "cache_root"):
113-
acqf = constructor(model=hogp, sampler=sampler, cache_root=True)
114-
self.assertFalse(acqf._cache_root)
115-
116-
# test deterministic model
117-
model = GenericDeterministicModel(f=lambda X: X)
118-
with self.assertWarnsRegex(RuntimeWarning, "cache_root"):
119-
acqf = constructor(model=model, sampler=sampler, cache_root=True)
120-
self.assertFalse(acqf._cache_root)
95+
self.assertFalse(acqf._cache_root)
12196

12297
def test_cache_root_decomposition(self):
12398
tkwargs = {"device": self.device}
124-
constructors = [DeprecatedCachedCholeskyAcqf, DummyCachedCholeskyAcqf]
125-
for constructor in constructors:
126-
for dtype in (torch.float, torch.double):
99+
for dtype in (torch.float, torch.double):
100+
with self.subTest(dtype=dtype):
127101
tkwargs["dtype"] = dtype
128102
# test mt-mvn
129103
train_x = torch.rand(2, 1, **tkwargs)
@@ -133,7 +107,7 @@ def test_cache_root_decomposition(self):
133107
sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
134108
with torch.no_grad():
135109
posterior = model.posterior(test_x)
136-
acqf = constructor(
110+
acqf = DummyCachedCholeskyAcqf(
137111
model=model,
138112
sampler=sampler,
139113
objective=GenericMCObjective(lambda Y, _: Y[..., 0]),
@@ -169,9 +143,8 @@ def test_cache_root_decomposition(self):
169143

170144
def test_get_f_X_samples(self):
171145
tkwargs = {"device": self.device}
172-
constructors = [DeprecatedCachedCholeskyAcqf, DummyCachedCholeskyAcqf]
173-
for constructor in constructors:
174-
for dtype in (torch.float, torch.double):
146+
for dtype in (torch.float, torch.double):
147+
with self.subTest(dtype=dtype):
175148
tkwargs["dtype"] = dtype
176149
mean = torch.zeros(5, 1, **tkwargs)
177150
variance = torch.ones(5, 1, **tkwargs)
@@ -186,7 +159,9 @@ def test_get_f_X_samples(self):
186159
sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
187160

188161
with self.assertWarnsRegex(RuntimeWarning, "cache_root"):
189-
acqf = constructor(model=mm, sampler=sampler, cache_root=True)
162+
acqf = DummyCachedCholeskyAcqf(
163+
model=mm, sampler=sampler, cache_root=True
164+
)
190165
self.assertFalse(acqf._cache_root)
191166
acqf._cache_root = True
192167
q = 3
@@ -233,7 +208,9 @@ def test_get_f_X_samples(self):
233208
self.assertTrue(samples.shape, torch.Size([1, q, 1]))
234209
# test HOGP
235210
hogp = HigherOrderGP(torch.zeros(2, 1), torch.zeros(2, 1, 1)).eval()
236-
acqf = constructor(model=hogp, sampler=sampler, cache_root=True)
211+
acqf = DummyCachedCholeskyAcqf(
212+
model=hogp, sampler=sampler, cache_root=True
213+
)
237214
mock_samples = torch.rand(5, 1, 1, **tkwargs)
238215
posterior = MockPosterior(
239216
mean=mean, variance=variance, samples=mock_samples

0 commit comments

Comments
 (0)