Skip to content

Commit 6e8df86

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Fix RFF bug when using FixedNoiseGP models (#1528)
Summary: Pull Request resolved: #1528 `_model.likelihood.noise` is a `... x d`-dim tensor for fixed noise models with `d`-dimensional `train_X`. This leads to errors when constructing the RFF. This updates the `sigma_sq` to use the mean across the input dimensions, which was used prior to #1336. Reviewed By: Balandat Differential Revision: D41568680 fbshipit-source-id: fef740ebd62e49d63397ebc4ea1d1dc883f798aa
1 parent 4f18642 commit 6e8df86

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

botorch/utils/gp_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def get_gp_samples(
462462
phi_X = basis(train_X)
463463
# Sample weights from bayesian linear model.
464464
# weights.sample().shape == (n_samples, batch_shape_input, num_rff_features)
465-
sigma_sq = _model.likelihood.noise
465+
sigma_sq = _model.likelihood.noise.mean(dim=-1, keepdim=True)
466466
if len(basis.kernel_batch_shape) > 0:
467467
sigma_sq = sigma_sq.unsqueeze(-2)
468468
mvn = get_weights_posterior(

test/utils/test_gp_sampling.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
from botorch.models.converter import batched_to_model_list
1313
from botorch.models.deterministic import DeterministicModel
14-
from botorch.models.gp_regression import SingleTaskGP
14+
from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
1515
from botorch.models.model import ModelList
1616
from botorch.models.multitask import MultiTaskGP
1717
from botorch.models.transforms.input import Normalize
@@ -644,3 +644,20 @@ def test_get_gp_samples(self):
644644
expected = torch.Size([13, 5, 3, m])
645645
Y_batched = gp_samples.posterior(test_X).mean
646646
self.assertEqual(Y_batched.shape, expected)
647+
648+
def test_with_fixed_noise(self):
649+
for n_samples in (1, 20):
650+
gp_samples = get_gp_samples(
651+
model=FixedNoiseGP(
652+
torch.rand(5, 3, dtype=torch.double),
653+
torch.randn(5, 1, dtype=torch.double),
654+
torch.rand(5, 1, dtype=torch.double) * 0.1,
655+
),
656+
num_outputs=1,
657+
n_samples=n_samples,
658+
)
659+
samples = gp_samples(torch.rand(2, 3))
660+
expected_shape = (
661+
torch.Size([2, 1]) if n_samples == 1 else torch.Size([n_samples, 2, 1])
662+
)
663+
self.assertEqual(samples.shape, expected_shape)

0 commit comments

Comments
 (0)