
Commit 394609e

Balandat authored and facebook-github-bot committed
ModelList <-> BatchedModel converters (#187)
Summary:
Pull Request resolved: #187

Adds converters between `BatchedMultiOutputGPyTorchModel` and `ModelListGP`, in both directions (where applicable). This is useful e.g. for fitting batched multi-output models with a large number of outputs, where jointly fitting the model can result in inferior model fits (due to the size of the resulting optimization problem). See stacked diff.

This currently does **not** support the following:
- `HeteroskedasticSingleTaskGP`
- custom likelihoods for `SingleTaskGP`

Reviewed By: sdaulton

Differential Revision: D15982128

fbshipit-source-id: d966b8007144e8ca27c83483a2ff4639ebfe2304
1 parent 62c763a commit 394609e
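For context, a minimal usage sketch of the workflow the summary describes (hypothetical data; `fit_gpytorch_model` and `SumMarginalLogLikelihood` are assumed available from `botorch.fit` and `gpytorch.mlls` at this revision):

    import torch
    from botorch.fit import fit_gpytorch_model  # assumed available at this revision
    from botorch.models import ModelListGP, SingleTaskGP
    from botorch.models.converter import batched_to_model_list, model_list_to_batched
    from gpytorch.mlls import SumMarginalLogLikelihood

    # hypothetical training data: one set of inputs, two scalar outputs
    train_X = torch.rand(20, 3)
    gp1 = SingleTaskGP(train_X, train_X.sum(dim=-1))
    gp2 = SingleTaskGP(train_X, train_X[:, 0] - train_X[:, 1])

    # fit the outputs independently (smaller optimization problems) ...
    list_gp = ModelListGP(gp1, gp2)
    fit_gpytorch_model(SumMarginalLogLikelihood(list_gp.likelihood, list_gp))

    # ... then convert to a batched multi-output model, and back if needed
    batch_gp = model_list_to_batched(list_gp)
    list_gp_again = batched_to_model_list(batch_gp)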

File tree

2 files changed, +326 -0 lines changed


botorch/models/converter.py

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
#! /usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

r"""
Utilities for converting between different models.
"""

from copy import deepcopy

import torch
from torch.nn import Module, ModuleList

from ..exceptions import UnsupportedError
from .gp_regression import FixedNoiseGP, HeteroskedasticSingleTaskGP
from .gpytorch import BatchedMultiOutputGPyTorchModel
from .model_list_gp_regression import ModelListGP


def _get_module(module: Module, name: str) -> Module:
    """Recursively get a sub-module from a module.

    Args:
        module: A `torch.nn.Module`.
        name: The name of the submodule to return, in the form of a period-delimited
            string: `sub_module.subsub_module.[...].leaf_module`.

    Returns:
        The requested sub-module.

    Example:
        >>> gp = SingleTaskGP(train_X, train_Y)
        >>> noise_prior = _get_module(gp, "likelihood.noise_covar.noise_prior")
    """
    current = module
    if name != "":
        for a in name.split("."):
            current = getattr(current, a)
    return current


def _check_compatibility(models: ModuleList) -> None:
    """Check if the models of a ModelListGP can be converted."""
    # check that all submodules are of the same type
    for modn, mod in models[0].named_modules():
        mcls = mod.__class__
        if not all(isinstance(_get_module(m, modn), mcls) for m in models[1:]):
            raise UnsupportedError(
                "Sub-modules must be of the same type across models."
            )

    # check that each model is a BatchedMultiOutputGPyTorchModel
    if not all(isinstance(m, BatchedMultiOutputGPyTorchModel) for m in models):
        raise UnsupportedError(
            "All models must be of type BatchedMultiOutputGPyTorchModel."
        )

    # TODO: Add support for HeteroskedasticSingleTaskGP
    if any(isinstance(m, HeteroskedasticSingleTaskGP) for m in models):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP is currently unsupported."
        )

    # check that each model is single-output
    if not all(m._num_outputs == 1 for m in models):
        raise UnsupportedError("All models must be single-output.")

    # if the list has only one model, the cross-model checks below are unnecessary
    if len(models) == 1:
        return

    # check that training inputs are the same
    if not all(
        torch.equal(ti, tj)
        for m in models[1:]
        for ti, tj in zip(models[0].train_inputs, m.train_inputs)
    ):
        raise UnsupportedError("Training inputs must agree for all sub-models.")


def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorchModel:
    """Convert a ModelListGP to a BatchedMultiOutputGPyTorchModel.

    Args:
        model_list: The `ModelListGP` to be converted to the appropriate
            `BatchedMultiOutputGPyTorchModel`. All sub-models must be of the same
            type and have the same shape (batch shape and number of training
            inputs).

    Returns:
        The model converted into a `BatchedMultiOutputGPyTorchModel`.

    Example:
        >>> list_gp = ModelListGP(gp1, gp2)
        >>> batch_gp = model_list_to_batched(list_gp)
    """
    models = model_list.models
    _check_compatibility(models)

    # construct inputs
    train_X = deepcopy(models[0].train_inputs[0])
    train_Y = torch.stack([m.train_targets.clone() for m in models], dim=-1)
    kwargs = {"train_X": train_X, "train_Y": train_Y}
    if isinstance(models[0], FixedNoiseGP):
        kwargs["train_Yvar"] = torch.stack(
            [m.likelihood.noise_covar.noise.clone() for m in models], dim=-1
        )

    # construct the batched GP model
    batch_gp = models[0].__class__(**kwargs)

    tensors = {n for n, p in batch_gp.state_dict().items() if len(p.shape) > 0}
    scalars = set(batch_gp.state_dict()) - tensors
    input_batch_dims = len(models[0]._input_batch_shape)

    # ensure scalars agree (TODO: Allow different priors for different outputs)
    for n in scalars:
        v0 = _get_module(models[0], n)
        if not all(torch.equal(_get_module(m, n), v0) for m in models[1:]):
            raise UnsupportedError("All scalars must have the same value.")

    # ensure dimensions of all tensors agree
    for n in tensors:
        shape0 = _get_module(models[0], n).shape
        if not all(_get_module(m, n).shape == shape0 for m in models[1:]):
            raise UnsupportedError("All tensors must have the same shape.")

    # now construct the batched state dict
    scalar_state_dict = {
        s: p.clone() for s, p in models[0].state_dict().items() if s in scalars
    }
    tensor_state_dict = {
        t: torch.stack(
            [m.state_dict()[t].clone() for m in models], dim=input_batch_dims
        )
        for t in tensors
    }
    batch_state_dict = {**scalar_state_dict, **tensor_state_dict}

    # load the state dict into the new model
    batch_gp.load_state_dict(batch_state_dict)

    return batch_gp


def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> ModelListGP:
    """Convert a BatchedMultiOutputGPyTorchModel to a ModelListGP.

    Args:
        batch_model: The `BatchedMultiOutputGPyTorchModel` to be converted to a
            `ModelListGP`.

    Returns:
        The model converted into a `ModelListGP`.

    Example:
        >>> train_X = torch.rand(5, 2)
        >>> train_Y = torch.rand(5, 2)
        >>> batch_gp = SingleTaskGP(train_X, train_Y)
        >>> list_gp = batched_to_model_list(batch_gp)
    """
    # TODO: Add support for HeteroskedasticSingleTaskGP
    if isinstance(batch_model, HeteroskedasticSingleTaskGP):
        raise NotImplementedError(
            "Conversion of HeteroskedasticSingleTaskGP is currently not supported."
        )

    batch_sd = batch_model.state_dict()

    tensors = {n for n, p in batch_sd.items() if len(p.shape) > 0}
    scalars = set(batch_sd) - tensors
    input_bdims = len(batch_model._input_batch_shape)

    models = []

    for i in range(batch_model._num_outputs):
        scalar_sd = {s: batch_sd[s].clone() for s in scalars}
        tensor_sd = {t: batch_sd[t].select(input_bdims, i).clone() for t in tensors}
        sd = {**scalar_sd, **tensor_sd}
        kwargs = {
            "train_X": batch_model.train_inputs[0].select(input_bdims, i).clone(),
            "train_Y": batch_model.train_targets.select(input_bdims, i).clone(),
        }
        if isinstance(batch_model, FixedNoiseGP):
            noise_covar = batch_model.likelihood.noise_covar
            kwargs["train_Yvar"] = noise_covar.noise.select(input_bdims, i).clone()
        model = batch_model.__class__(**kwargs)
        model.load_state_dict(sd)
        models.append(model)

    return ModelListGP(*models)
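
The round trip above is essentially `torch.stack` and `Tensor.select` applied to state-dict entries along the input batch dimension; a minimal standalone sketch (hypothetical parameter names):

    import torch

    # per-model hyperparameter tensors with identical shapes (hypothetical names)
    sd1 = {"covar_module.lengthscale": torch.tensor([0.5, 1.0])}
    sd2 = {"covar_module.lengthscale": torch.tensor([0.7, 0.9])}

    # model_list_to_batched: stack each entry along a new batch dimension 0
    batched = {k: torch.stack([sd1[k], sd2[k]], dim=0) for k in sd1}
    assert batched["covar_module.lengthscale"].shape == torch.Size([2, 2])

    # batched_to_model_list: select output i along that same dimension
    recovered = [
        {k: v.select(0, i).clone() for k, v in batched.items()} for i in range(2)
    ]
    assert torch.equal(
        recovered[0]["covar_module.lengthscale"], sd1["covar_module.lengthscale"]
    )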

test/models/test_converter.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
#! /usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest

import torch
from botorch.exceptions import UnsupportedError
from botorch.models import (
    FixedNoiseGP,
    HeteroskedasticSingleTaskGP,
    ModelListGP,
    SingleTaskGP,
)
from botorch.models.converter import batched_to_model_list, model_list_to_batched

from .test_gpytorch import SimpleGPyTorchModel


class TestConverters(unittest.TestCase):
    def test_batched_to_model_list(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            # test SingleTaskGP
            train_X = torch.rand(10, 2, device=device, dtype=dtype)
            train_Y1 = train_X.sum(dim=-1)
            train_Y2 = train_X[:, 0] - train_X[:, 1]
            train_Y = torch.stack([train_Y1, train_Y2], dim=-1)
            batch_gp = SingleTaskGP(train_X, train_Y)
            list_gp = batched_to_model_list(batch_gp)
            self.assertIsInstance(list_gp, ModelListGP)
            # test FixedNoiseGP
            batch_gp = FixedNoiseGP(train_X, train_Y, torch.rand_like(train_Y))
            list_gp = batched_to_model_list(batch_gp)
            self.assertIsInstance(list_gp, ModelListGP)
            # test HeteroskedasticSingleTaskGP
            batch_gp = HeteroskedasticSingleTaskGP(
                train_X, train_Y, torch.rand_like(train_Y)
            )
            with self.assertRaises(NotImplementedError):
                batched_to_model_list(batch_gp)

    def test_batched_to_model_list_cuda(self):
        if torch.cuda.is_available():
            self.test_batched_to_model_list(cuda=True)

    def test_model_list_to_batched(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            # basic test
            train_X = torch.rand(10, 2, device=device, dtype=dtype)
            train_Y1 = train_X.sum(dim=-1)
            train_Y2 = train_X[:, 0] - train_X[:, 1]
            gp1 = SingleTaskGP(train_X, train_Y1)
            gp2 = SingleTaskGP(train_X, train_Y2)
            list_gp = ModelListGP(gp1, gp2)
            batch_gp = model_list_to_batched(list_gp)
            self.assertIsInstance(batch_gp, SingleTaskGP)
            # test degenerate (single model)
            batch_gp = model_list_to_batched(ModelListGP(gp1))
            self.assertEqual(batch_gp._num_outputs, 1)
            # test different model classes
            gp2 = FixedNoiseGP(train_X, train_Y1, torch.ones_like(train_Y1))
            with self.assertRaises(UnsupportedError):
                model_list_to_batched(ModelListGP(gp1, gp2))
            # test non-batched models
            gp1_ = SimpleGPyTorchModel(train_X, train_Y1)
            gp2_ = SimpleGPyTorchModel(train_X, train_Y2)
            with self.assertRaises(UnsupportedError):
                model_list_to_batched(ModelListGP(gp1_, gp2_))
            # test list of multi-output models
            train_Y = torch.stack([train_Y1, train_Y2], dim=-1)
            gp2 = SingleTaskGP(train_X, train_Y)
            with self.assertRaises(UnsupportedError):
                model_list_to_batched(ModelListGP(gp1, gp2))
            # test different training inputs
            gp2 = SingleTaskGP(2 * train_X, train_Y2)
            with self.assertRaises(UnsupportedError):
                model_list_to_batched(ModelListGP(gp1, gp2))
            # check scalar agreement
            gp2 = SingleTaskGP(train_X, train_Y2)
            gp2.likelihood.noise_covar.noise_prior.rate.fill_(1.0)
            with self.assertRaises(UnsupportedError):
                model_list_to_batched(ModelListGP(gp1, gp2))
            # check tensor shape agreement
            gp2 = SingleTaskGP(train_X, train_Y2)
            gp2.covar_module.raw_outputscale = torch.nn.Parameter(
                torch.tensor([0.0], device=device, dtype=dtype)
            )
            with self.assertRaises(UnsupportedError):
                model_list_to_batched(ModelListGP(gp1, gp2))
            # test HeteroskedasticSingleTaskGP
            gp2 = HeteroskedasticSingleTaskGP(
                train_X, train_Y1, torch.ones_like(train_Y1)
            )
            with self.assertRaises(NotImplementedError):
                model_list_to_batched(ModelListGP(gp2))
            # test FixedNoiseGP
            train_X = torch.rand(10, 2, device=device, dtype=dtype)
            train_Y1 = train_X.sum(dim=-1)
            train_Y2 = train_X[:, 0] - train_X[:, 1]
            gp1_ = FixedNoiseGP(train_X, train_Y1, torch.rand_like(train_Y1))
            gp2_ = FixedNoiseGP(train_X, train_Y2, torch.rand_like(train_Y2))
            list_gp = ModelListGP(gp1_, gp2_)
            batch_gp = model_list_to_batched(list_gp)
            self.assertIsInstance(batch_gp, FixedNoiseGP)

    def test_model_list_to_batched_cuda(self):
        if torch.cuda.is_available():
            self.test_model_list_to_batched(cuda=True)

    def test_roundtrip(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            train_X = torch.rand(10, 2, device=device, dtype=dtype)
            train_Y1 = train_X.sum(dim=-1)
            train_Y2 = train_X[:, 0] - train_X[:, 1]
            train_Y = torch.stack([train_Y1, train_Y2], dim=-1)
            # SingleTaskGP
            batch_gp = SingleTaskGP(train_X, train_Y)
            list_gp = batched_to_model_list(batch_gp)
            batch_gp_recov = model_list_to_batched(list_gp)
            sd_orig = batch_gp.state_dict()
            sd_recov = batch_gp_recov.state_dict()
            self.assertTrue(set(sd_orig) == set(sd_recov))
            self.assertTrue(all(torch.equal(sd_orig[k], sd_recov[k]) for k in sd_orig))
            # FixedNoiseGP
            batch_gp = FixedNoiseGP(train_X, train_Y, torch.rand_like(train_Y))
            list_gp = batched_to_model_list(batch_gp)
            batch_gp_recov = model_list_to_batched(list_gp)
            sd_orig = batch_gp.state_dict()
            sd_recov = batch_gp_recov.state_dict()
            self.assertTrue(set(sd_orig) == set(sd_recov))
            self.assertTrue(all(torch.equal(sd_orig[k], sd_recov[k]) for k in sd_orig))

    def test_roundtrip_cuda(self):
        if torch.cuda.is_available():
            self.test_roundtrip(cuda=True)
