
Commit 77bd929

Balandat authored and facebook-github-bot committed
Validate output shape of models upon instantiating acq. functions (#331)
Summary:
Pull Request resolved: #331
Addresses #328

Reviewed By: liangshi7

Differential Revision: D18739840

fbshipit-source-id: 4b888d17a24de211315d36a09ce4e879d294973e
Parent: 8f7221f · Commit: 77bd929
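The net effect of the commit is that acquisition functions now validate the model's output shape when they are constructed, instead of failing only when they are first evaluated. A minimal sketch of the new behavior follows; the toy two-output SingleTaskGP setup and the ScalarizedObjective weights are illustrative assumptions, not part of the commit.

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition import ExpectedImprovement
from botorch.acquisition.objective import ScalarizedObjective
from botorch.exceptions import UnsupportedError

# a toy two-output model (model.num_outputs == 2)
train_X = torch.rand(10, 2)
train_Y = torch.rand(10, 2)
model = SingleTaskGP(train_X, train_Y)

# without an objective, construction now fails immediately
try:
    ExpectedImprovement(model=model, best_f=0.0)
except UnsupportedError:
    pass

# a ScalarizedObjective makes the multi-output model usable analytically
EI = ExpectedImprovement(
    model=model,
    best_f=0.0,
    objective=ScalarizedObjective(weights=torch.tensor([1.0, -1.0])),
)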

File tree: 9 files changed, +259 −193 lines


botorch/acquisition/analytic.py

Lines changed: 16 additions & 15 deletions
@@ -41,37 +41,36 @@ def __init__(
             objective: A ScalarizedObjective (optional).
         """
         super().__init__(model=model)
-        if objective is not None and not isinstance(objective, ScalarizedObjective):
+        if objective is None:
+            if model.num_outputs != 1:
+                raise UnsupportedError(
+                    "Must specify an objective when using a multi-output model."
+                )
+        elif not isinstance(objective, ScalarizedObjective):
             raise UnsupportedError(
                 "Only objectives of type ScalarizedObjective are supported for "
                 "analytic acquisition functions."
             )
         self.objective = objective
 
-    def _get_posterior(self, X: Tensor, check_single_output: bool = True) -> Posterior:
+    def _get_posterior(self, X: Tensor) -> Posterior:
         r"""Compute the posterior at the input candidate set X.
 
         Applies the objective if provided.
 
         Args:
-            X: The input candidate set
-            check_single_output: If True, Raise an error if the posterior is not
-                single-output.
+            X: The input candidate set.
 
         Returns:
-            The posterior at X.
+            The posterior at X. If a ScalarizedObjective is defined, this
+            posterior can be single-output even if the underlying model is a
+            multi-output model.
         """
         posterior = self.model.posterior(X)
         if self.objective is not None:
-            # Unlike MCAcquisitionObjectives (which transform samples), this
+            # Unlike MCAcquisitionObjective (which transform samples), this
             # transforms the posterior
             posterior = self.objective(posterior)
-        if check_single_output:
-            if posterior.event_shape[-1] != 1:
-                raise UnsupportedError(
-                    "Multi-Output posteriors are not supported for acquisition "
-                    f"functions of type {self.__class__.__name__}"
-                )
         return posterior
 
     def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
@@ -349,7 +348,9 @@ def __init__(
                 bounds on that output (resp. interpreted as -Inf / Inf if None)
             maximize: If True, consider the problem a maximization problem.
         """
-        super().__init__(model=model)
+        # use AcquisitionFunction constructor to avoid check for objective
+        super(AnalyticAcquisitionFunction, self).__init__(model=model)
+        self.objective = None
         self.maximize = maximize
         self.objective_index = objective_index
         self.constraints = constraints
@@ -369,7 +370,7 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b)`-dim Tensor of Expected Improvement values at the given
             design points `X`.
         """
-        posterior = self._get_posterior(X=X, check_single_output=False)
+        posterior = self._get_posterior(X=X)
         means = posterior.mean.squeeze(dim=-2)  # (b) x m
         sigmas = posterior.variance.squeeze(dim=-2).sqrt().clamp_min(1e-9)  # (b) x m

botorch/acquisition/max_value_entropy_search.py

Lines changed: 0 additions & 7 deletions
@@ -89,13 +89,6 @@ def __init__(
             sampler = SobolQMCNormalSampler(num_y_samples)
         super().__init__(model=model, sampler=sampler)
 
-        # The following check is specific to batched models
-        # @TODO: model-agnostic num_outputs check: #295
-        if self.model._num_outputs > 1:
-            raise NotImplementedError(
-                "Models with > 1 outcomes are not yet supported by qMaxValueEntropy!"
-            )
-
         # Batch GP models (e.g. fantasized models) are not currently supported
         if self.model.train_inputs[0].ndim > 2:
             raise NotImplementedError(
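With this block removed, a multi-output model passed to qMaxValueEntropy now trips the validation added to the MCAcquisitionFunction constructor (see monte_carlo.py below) rather than the deleted NotImplementedError. A hedged sketch, assuming a toy two-output SingleTaskGP; the model setup and candidate set size are illustrative only:

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy
from botorch.exceptions import UnsupportedError

model = SingleTaskGP(torch.rand(10, 2), torch.rand(10, 2))  # two outputs
candidate_set = torch.rand(100, 2)
try:
    qMaxValueEntropy(model=model, candidate_set=candidate_set)
except UnsupportedError:
    # raised by the base-class check in MCAcquisitionFunction.__init__
    pass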

botorch/acquisition/monte_carlo.py

Lines changed: 4 additions & 0 deletions
@@ -65,6 +65,10 @@ def __init__(
             sampler = SobolQMCNormalSampler(num_samples=512, collapse_batch_dims=True)
         self.add_module("sampler", sampler)
         if objective is None:
+            if model.num_outputs != 1:
+                raise UnsupportedError(
+                    "Must specify an objective when using a multi-output model."
+                )
             objective = IdentityMCObjective()
         elif not isinstance(objective, MCAcquisitionObjective):
             raise UnsupportedError(
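Monte-Carlo acquisition functions remain usable with multi-output models as long as an MCAcquisitionObjective scalarizes the samples. A hedged sketch, assuming a toy two-output SingleTaskGP and a GenericMCObjective with illustrative weights (at the time of this commit the objective callable takes only the sample tensor):

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.acquisition.objective import GenericMCObjective

model = SingleTaskGP(torch.rand(10, 2), torch.rand(10, 2))  # two outputs
weights = torch.tensor([1.0, -1.0])
objective = GenericMCObjective(lambda samples: (samples * weights).sum(dim=-1))

# without `objective`, this constructor would now raise UnsupportedError
qEI = qExpectedImprovement(model=model, best_f=0.0, objective=objective)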

test/acquisition/test_analytic.py

Lines changed: 18 additions & 14 deletions
@@ -38,10 +38,21 @@
 ]
 
 
+class DummyAnalyticAcquisitionFunction(AnalyticAcquisitionFunction):
+    def forward(self, X):
+        pass
+
+
 class TestAnalyticAcquisitionFunction(BotorchTestCase):
     def test_abstract_raises(self):
         with self.assertRaises(TypeError):
             AnalyticAcquisitionFunction()
+        # raise if model is multi-output, but no objective is given
+        mean = torch.zeros(1, 2)
+        variance = torch.ones(1, 2)
+        mm = MockModel(MockPosterior(mean=mean, variance=variance))
+        with self.assertRaises(UnsupportedError):
+            DummyAnalyticAcquisitionFunction(model=mm)
 
 
 class TestExpectedImprovement(BotorchTestCase):
@@ -113,9 +124,8 @@ def test_expected_improvement_batch(self):
             mean2 = torch.rand(3, 1, 2, device=self.device, dtype=dtype)
             variance2 = torch.rand(3, 1, 2, device=self.device, dtype=dtype)
             mm2 = MockModel(MockPosterior(mean=mean2, variance=variance2))
-            module2 = ExpectedImprovement(model=mm2, best_f=0.0)
             with self.assertRaises(UnsupportedError):
-                module2(X)
+                ExpectedImprovement(model=mm2, best_f=0.0)
 
             # test objective (single-output)
             mean = torch.tensor([[[0.5]], [[0.25]]], device=self.device, dtype=dtype)
@@ -172,9 +182,8 @@ def test_posterior_mean(self):
             # check for proper error if multi-output model
             mean2 = torch.rand(1, 2, device=self.device, dtype=dtype)
             mm2 = MockModel(MockPosterior(mean=mean2))
-            module2 = PosteriorMean(model=mm2)
             with self.assertRaises(UnsupportedError):
-                module2(X)
+                PosteriorMean(model=mm2)
 
     def test_posterior_mean_batch(self):
         for dtype in (torch.float, torch.double):
@@ -189,9 +198,8 @@ def test_posterior_mean_batch(self):
             # check for proper error if multi-output model
             mean2 = torch.rand(3, 1, 2, device=self.device, dtype=dtype)
             mm2 = MockModel(MockPosterior(mean=mean2))
-            module2 = PosteriorMean(model=mm2)
             with self.assertRaises(UnsupportedError):
-                module2(X)
+                PosteriorMean(model=mm2)
 
 
 class TestProbabilityOfImprovement(BotorchTestCase):
@@ -217,9 +225,8 @@ def test_probability_of_improvement(self):
             mean2 = torch.rand(1, 2, device=self.device, dtype=dtype)
             variance2 = torch.ones_like(mean2)
             mm2 = MockModel(MockPosterior(mean=mean2, variance=variance2))
-            module2 = ProbabilityOfImprovement(model=mm2, best_f=0.0)
             with self.assertRaises(UnsupportedError):
-                module2(X)
+                ProbabilityOfImprovement(model=mm2, best_f=0.0)
 
     def test_probability_of_improvement_batch(self):
         for dtype in (torch.float, torch.double):
@@ -237,9 +244,8 @@ def test_probability_of_improvement_batch(self):
             mean2 = torch.rand(3, 1, 2, device=self.device, dtype=dtype)
             variance2 = torch.ones_like(mean2)
             mm2 = MockModel(MockPosterior(mean=mean2, variance=variance2))
-            module2 = ProbabilityOfImprovement(model=mm2, best_f=0.0)
             with self.assertRaises(UnsupportedError):
-                module2(X)
+                ProbabilityOfImprovement(model=mm2, best_f=0.0)
 
 
 class TestUpperConfidenceBound(BotorchTestCase):
class TestUpperConfidenceBound(BotorchTestCase):
@@ -265,9 +271,8 @@ def test_upper_confidence_bound(self):
265271
mean2 = torch.rand(1, 2, device=self.device, dtype=dtype)
266272
variance2 = torch.rand(1, 2, device=self.device, dtype=dtype)
267273
mm2 = MockModel(MockPosterior(mean=mean2, variance=variance2))
268-
module2 = UpperConfidenceBound(model=mm2, beta=1.0)
269274
with self.assertRaises(UnsupportedError):
270-
module2(X)
275+
UpperConfidenceBound(model=mm2, beta=1.0)
271276

272277
def test_upper_confidence_bound_batch(self):
273278
for dtype in (torch.float, torch.double):
@@ -287,9 +292,8 @@ def test_upper_confidence_bound_batch(self):
             mean2 = torch.rand(3, 1, 2, device=self.device, dtype=dtype)
             variance2 = torch.rand(3, 1, 2, device=self.device, dtype=dtype)
             mm2 = MockModel(MockPosterior(mean=mean2, variance=variance2))
-            module2 = UpperConfidenceBound(model=mm2, beta=1.0)
             with self.assertRaises(UnsupportedError):
-                module2(X)
+                UpperConfidenceBound(model=mm2, beta=1.0)
 
 
 class TestConstrainedExpectedImprovement(BotorchTestCase):
