Skip to content

Commit cee0f64

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
Fixing model_list_to_batched ignoring the covar_module of the input models (#1419)
Summary: Pull Request resolved: #1419 This diff fixes the bug reported [here](#1411). Reviewed By: Balandat Differential Revision: D39781851 fbshipit-source-id: f1a9b0b12438248036d6476c063346128052b87d
1 parent 15f23cc commit cee0f64

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

botorch/models/converter.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,18 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
150150
raise UnsupportedError("All models must have the same fidelity parameters.")
151151
kwargs.update(init_args)
152152

153+
# add batched kernel, except if the model type is SingleTaskMultiFidelityGP,
154+
# which does not have a `covar_module`
155+
if not isinstance(models[0], SingleTaskMultiFidelityGP):
156+
batch_length = len(models)
157+
covar_module = _batched_kernel(models[0].covar_module, batch_length)
158+
kwargs["covar_module"] = covar_module
159+
153160
# construct the batched GP model
154161
input_transform = getattr(models[0], "input_transform", None)
155162
if input_transform is not None:
156163
input_transform.train()
164+
157165
batch_gp = models[0].__class__(input_transform=input_transform, **kwargs)
158166
adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
159167
batch_state_dict=batch_gp.state_dict(), input_transform=input_transform
@@ -196,6 +204,46 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
196204
return batch_gp
197205

198206

207+
def _batched_kernel(kernel, batch_length: int):
208+
"""Adds a batch dimension of size `batch_length` to all non-scalar
209+
Tensor parameters that govern the kernel function `kernel`.
210+
NOTE: prior or constraint parameters are excluded from batching.
211+
"""
212+
# copy just in case there are non-tensor parameters that are passed by reference
213+
kernel = deepcopy(kernel)
214+
search_str = "raw_outputscale"
215+
for key, attr in kernel.state_dict().items():
216+
if isinstance(attr, Tensor) and (
217+
attr.ndim > 0 or (search_str == key.rpartition(".")[-1])
218+
):
219+
attr = attr.unsqueeze(0).expand(batch_length, *attr.shape).clone()
220+
set_attribute(kernel, key, torch.nn.Parameter(attr))
221+
return kernel
222+
223+
224+
# two helper functions for `batched_kernel`
225+
# like `setattr` and `getattr` for object hierarchies
226+
def set_attribute(obj, attr: str, val):
227+
"""Like `setattr` but works with hierarchical attribute specification.
228+
E.g. if obj=Zoo(), and attr="tiger.age", set_attribute(obj, attr, 3),
229+
would set the Zoo's tiger's age to three.
230+
"""
231+
path_to_leaf, _, attr_name = attr.rpartition(".")
232+
leaf = get_attribute(obj, path_to_leaf) if path_to_leaf else obj
233+
setattr(leaf, attr_name, val)
234+
235+
236+
def get_attribute(obj, attr: str):
237+
"""Like `getattr` but works with hierarchical attribute specification.
238+
E.g. if obj=Zoo(), and attr="tiger.age", get_attribute(obj, attr),
239+
would return the Zoo's tiger's age.
240+
"""
241+
attr_names = attr.split(".")
242+
while attr_names:
243+
obj = getattr(obj, attr_names.pop(0))
244+
return obj
245+
246+
199247
def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> ModelListGP:
200248
"""Convert a BatchedMultiOutputGPyTorchModel to a ModelListGP.
201249

test/models/test_converter.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from botorch.models.transforms.input import AppendFeatures, Normalize
2323
from botorch.models.transforms.outcome import Standardize
2424
from botorch.utils.testing import BotorchTestCase
25+
from gpytorch.kernels import RBFKernel
2526
from gpytorch.likelihoods import GaussianLikelihood
2627

2728
from .test_gpytorch import SimpleGPyTorchModel
@@ -145,6 +146,18 @@ def test_model_list_to_batched(self):
145146
gp2 = SingleTaskGP(train_X, train_Y2, likelihood=GaussianLikelihood())
146147
with self.assertRaises(NotImplementedError):
147148
model_list_to_batched(ModelListGP(gp2))
149+
# test non-default kernel
150+
gp1 = SingleTaskGP(train_X, train_Y1, covar_module=RBFKernel())
151+
gp2 = SingleTaskGP(train_X, train_Y2, covar_module=RBFKernel())
152+
list_gp = ModelListGP(gp1, gp2)
153+
batch_gp = model_list_to_batched(list_gp)
154+
self.assertEqual(type(batch_gp.covar_module), RBFKernel)
155+
# test error when component GPs have different kernel types
156+
gp1 = SingleTaskGP(train_X, train_Y1, covar_module=RBFKernel())
157+
gp2 = SingleTaskGP(train_X, train_Y2)
158+
list_gp = ModelListGP(gp1, gp2)
159+
with self.assertRaises(UnsupportedError):
160+
model_list_to_batched(list_gp)
148161
# test FixedNoiseGP
149162
train_X = torch.rand(10, 2, device=self.device, dtype=dtype)
150163
train_Y1 = train_X.sum(dim=-1, keepdim=True)

0 commit comments

Comments
 (0)