Skip to content

Commit ce4900c

Browse files
dme65facebook-github-bot
authored andcommitted
Fix train_Yvar reshaping in SAASBO (#1183)
Summary: Pull Request resolved: #1183 The fixed noise should have shape `(b) x N` but was passed in as `N x 1`. Reviewed By: Balandat Differential Revision: D35702904 fbshipit-source-id: 5ee1887f36a0e83971e1085d8e9ac9175e0b0d36
1 parent 8a60162 commit ce4900c

File tree

3 files changed

+785
-756
lines changed

3 files changed

+785
-756
lines changed

botorch/models/fully_bayesian.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,11 @@ def load_mcmc_samples(
228228
).to(**tkwargs)
229229
if self.train_Yvar is not None:
230230
likelihood = FixedNoiseGaussianLikelihood(
231-
noise=self.train_Yvar, batch_shape=batch_shape
231+
# Reshape to shape `num_mcmc_samples x N`
232+
noise=self.train_Yvar.squeeze(-1).expand(
233+
num_mcmc_samples, len(self.train_Yvar)
234+
),
235+
batch_shape=batch_shape,
232236
).to(**tkwargs)
233237
else:
234238
likelihood = GaussianLikelihood(

test/models/test_fully_bayesian.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,12 @@ def test_fit_model(self):
200200
if infer_noise:
201201
self.assertEqual(model.likelihood.noise.shape, torch.Size([3, 1]))
202202
else:
203-
self.assertEqual(model.likelihood.noise.shape, torch.Size([n, 1]))
203+
self.assertEqual(model.likelihood.noise.shape, torch.Size([3, n]))
204204
self.assertTrue(
205205
torch.allclose(
206-
train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL),
206+
train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL)
207+
.squeeze(-1)
208+
.repeat(3, 1),
207209
model.likelihood.noise,
208210
)
209211
)
@@ -460,7 +462,9 @@ def test_load_samples(self):
460462
self.assertTrue(
461463
torch.allclose(
462464
model.likelihood.noise_covar.noise,
463-
train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL),
465+
train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL)
466+
.squeeze(-1)
467+
.repeat(3, 1),
464468
)
465469
)
466470

0 commit comments

Comments
 (0)