Skip to content

Commit c426b0a

Browse files
Balandatfacebook-github-bot
authored andcommitted
Add subset_output functionality to (most) models (#324)
Summary: Pull Request resolved: #324 In some cases we want to be able to subset models along the output dimension. For instance, if we fit a multi-output model with a number of metrics, we may want to optimize an acquisition function with an objective that only involves a subset of the outputs. By subsetting the model prior to that, we can save a lot of compute. This diff adds a `subset_output` function to the model API. Calling this on a model with a list of indices will return a new model object that is restricted to the desired outputs. For some models (e.g `AffineDeterministicModel` or `ModelListGP`) the implementation is trivial. For others it's a little more involved but doable.The main challenge is with things like passing in generic covariance modules - we really don't have any way of knowing what dimensions of the respective buffers and parameters we need to subset / rescale in this case. Reviewed By: sdaulton Differential Revision: D18668985 fbshipit-source-id: 41479203f23e8bcfa08bbe5f025ed12f0124a091
1 parent d6d8591 commit c426b0a

File tree

11 files changed

+213
-3
lines changed

11 files changed

+213
-3
lines changed

botorch/models/deterministic.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,21 @@ def __init__(self, f: Callable[[Tensor], Tensor], num_outputs: int = 1) -> None:
7373
self._f = f
7474
self._num_outputs = num_outputs
7575

76+
def subset_output(self, idcs: List[int]) -> "GenericDeterministicModel":
77+
r"""Subset the model along the output dimension.
78+
79+
Args:
80+
idcs: The output indices to subset the model to.
81+
82+
Returns:
83+
The current model, subset to the specified output indices.
84+
"""
85+
86+
def f_subset(X: Tensor) -> Tensor:
87+
return self._f(X)[..., idcs]
88+
89+
return self.__class__(f=f_subset)
90+
7691
def forward(self, X: Tensor) -> Tensor:
7792
r"""Compute the (deterministic) model output at X.
7893
@@ -113,5 +128,18 @@ def __init__(self, a: Tensor, b: Union[Tensor, float] = 0.01) -> None:
113128
self.register_buffer("b", b.expand(a.size(-1)))
114129
self._num_outputs = a.size(-1)
115130

131+
def subset_output(self, idcs: List[int]) -> "AffineDeterministicModel":
132+
r"""Subset the model along the output dimension.
133+
134+
Args:
135+
idcs: The output indices to subset the model to.
136+
137+
Returns:
138+
The current model, subset to the specified output indices.
139+
"""
140+
a_sub = self.a.detach()[..., idcs].clone()
141+
b_sub = self.b.detach()[..., idcs].clone()
142+
return self.__class__(a=a_sub, b=b_sub)
143+
116144
def forward(self, X: Tensor) -> Tensor:
117145
return self.b + torch.einsum("...d,dm", X, self.a)

botorch/models/gp_regression.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
Gaussian Process Regression models based on GPyTorch models.
99
"""
1010

11-
from typing import Any, Optional, Union
11+
from typing import Any, List, Optional, Union
1212

1313
import torch
1414
from gpytorch.constraints.constraints import GreaterThan
@@ -117,8 +117,15 @@ def __init__(
117117
batch_shape=self._aug_batch_shape,
118118
outputscale_prior=GammaPrior(2.0, 0.15),
119119
)
120+
self._subset_batch_dict = {
121+
"likelihood.noise_covar.raw_noise": -2,
122+
"mean_module.constant": -2,
123+
"covar_module.raw_outputscale": -1,
124+
"covar_module.base_kernel.raw_lengthscale": -3,
125+
}
120126
else:
121127
self.covar_module = covar_module
128+
# TODO: Allow subsetting of other covar modules
122129
if outcome_transform is not None:
123130
self.outcome_transform = outcome_transform
124131
self.to(train_X)
@@ -192,6 +199,11 @@ def __init__(
192199
)
193200
if outcome_transform is not None:
194201
self.outcome_transform = outcome_transform
202+
self._subset_batch_dict = {
203+
"mean_module.constant": -2,
204+
"covar_module.raw_outputscale": -1,
205+
"covar_module.base_kernel.raw_lengthscale": -3,
206+
}
195207
self.to(train_X)
196208

197209
def fantasize(
@@ -242,6 +254,21 @@ def forward(self, x: Tensor) -> MultivariateNormal:
242254
covar_x = self.covar_module(x)
243255
return MultivariateNormal(mean_x, covar_x)
244256

257+
def subset_output(self, idcs: List[int]) -> "BatchedMultiOutputGPyTorchModel":
258+
r"""Subset the model along the output dimension.
259+
260+
Args:
261+
idcs: The output indices to subset the model to.
262+
263+
Returns:
264+
The current model, subset to the specified output indices.
265+
"""
266+
new_model = super().subset_output(idcs=idcs)
267+
full_noise = new_model.likelihood.noise_covar.noise
268+
new_noise = full_noise[..., idcs if len(idcs) > 1 else idcs[0], :]
269+
new_model.likelihood.noise_covar.noise = new_noise
270+
return new_model
271+
245272

246273
class HeteroskedasticSingleTaskGP(SingleTaskGP):
247274
r"""A single-task exact GP model using a heteroskeastic noise model.
@@ -311,3 +338,6 @@ def condition_on_observations(
311338
self, X: Tensor, Y: Tensor, **kwargs: Any
312339
) -> "HeteroskedasticSingleTaskGP":
313340
raise NotImplementedError
341+
342+
def subset_output(self, idcs: List[int]) -> "HeteroskedasticSingleTaskGP":
343+
raise NotImplementedError

botorch/models/gpytorch.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
GPyTorch Model class such as an ExactGP.
1212
"""
1313

14+
import itertools
1415
import warnings
1516
from abc import ABC
17+
from copy import deepcopy
1618
from typing import Any, Iterator, List, Optional, Tuple, Union
1719

1820
import torch
@@ -26,7 +28,12 @@
2628
from ..posteriors.gpytorch import GPyTorchPosterior
2729
from ..utils.transforms import gpt_posterior_settings
2830
from .model import Model
29-
from .utils import _make_X_full, add_output_dim, multioutput_to_batch_mode_transform
31+
from .utils import (
32+
_make_X_full,
33+
add_output_dim,
34+
mod_batch_shape,
35+
multioutput_to_batch_mode_transform,
36+
)
3037

3138

3239
class GPyTorchModel(Model, ABC):
@@ -358,6 +365,50 @@ def condition_on_observations(
358365
fantasy_model._aug_batch_shape = fantasy_model.train_targets.shape[:-1]
359366
return fantasy_model
360367

368+
def subset_output(self, idcs: List[int]) -> "BatchedMultiOutputGPyTorchModel":
369+
r"""Subset the model along the output dimension.
370+
371+
Args:
372+
idcs: The output indices to subset the model to.
373+
374+
Returns:
375+
The current model, subset to the specified output indices.
376+
"""
377+
try:
378+
subset_batch_dict = self._subset_batch_dict
379+
except AttributeError:
380+
raise NotImplementedError(
381+
"subset_output requires the model to define a `_subset_dict` attribute"
382+
)
383+
384+
m = len(idcs)
385+
tidxr = torch.tensor(idcs)
386+
idxr = tidxr if m > 1 else idcs[0]
387+
new_tail_bs = torch.Size([m]) if m > 1 else torch.Size()
388+
new_model = deepcopy(self)
389+
390+
new_model._num_outputs = m
391+
new_model._aug_batch_shape = new_model._aug_batch_shape[:-1] + new_tail_bs
392+
new_model.train_inputs = tuple(
393+
ti[..., idxr, :, :] for ti in new_model.train_inputs
394+
)
395+
new_model.train_targets = new_model.train_targets[..., idxr, :]
396+
397+
# adjust batch shapes of parameters/buffers if necessary
398+
for full_name, p in itertools.chain(
399+
new_model.named_parameters(), new_model.named_buffers()
400+
):
401+
if full_name in subset_batch_dict:
402+
idx = subset_batch_dict[full_name]
403+
new_data = p.index_select(idx, tidxr)
404+
if m == 1:
405+
new_data = new_data.squeeze(idx)
406+
p.data = new_data
407+
mod_name = full_name.split(".")[:-1]
408+
mod_batch_shape(new_model, mod_name, m if m > 1 else 0)
409+
410+
return new_model
411+
361412

362413
class ModelListGPyTorchModel(GPyTorchModel, ABC):
363414
r"""Abstract base class for models based on multi-output GPyTorch models.

botorch/models/model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ def num_outputs(self) -> int:
5555
cls_name = self.__class__.__name__
5656
raise NotImplementedError(f"{cls_name} does not define num_outputs property")
5757

58+
def subset_output(self, idcs: List[int]) -> "Model":
59+
r"""Subset the model along the output dimension.
60+
61+
Args:
62+
idcs: The output indices to subset the model to.
63+
64+
Returns:
65+
A `Model` object of the same type and with the same parameters as
66+
the current model, subset to the specified output indices.
67+
"""
68+
raise NotImplementedError
69+
5870
def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> "Model":
5971
r"""Condition the model on new observations.
6072

botorch/models/model_list_gp_regression.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
Model List GP Regression models.
99
"""
1010

11-
from typing import Any
11+
from copy import deepcopy
12+
from typing import Any, List
1213

1314
from gpytorch.models import IndependentModelList
1415
from torch import Tensor
@@ -89,3 +90,14 @@ def condition_on_observations(
8990
else:
9091
kwargs_ = kwargs
9192
return super().get_fantasy_model(inputs, targets, **kwargs_)
93+
94+
def subset_output(self, idcs: List[int]) -> "ModelListGP":
95+
r"""Subset the model along the output dimension.
96+
97+
Args:
98+
idcs: The output indices to subset the model to.
99+
100+
Returns:
101+
The current model, subset to the specified output indices.
102+
"""
103+
return self.__class__(*[deepcopy(self.models[i]) for i in idcs])

botorch/models/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import List, Optional, Tuple
1313

1414
import torch
15+
from gpytorch.module import Module
1516
from gpytorch.utils.broadcasting import _mul_broadcast_shape
1617
from torch import Tensor
1718

@@ -222,3 +223,25 @@ def validate_input_scaling(
222223
raise InputDataError("Input data contains negative variances.")
223224
check_min_max_scaling(X=train_X, raise_on_fail=raise_on_fail)
224225
check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)
226+
227+
228+
def mod_batch_shape(module: Module, names: List[str], b: int) -> None:
229+
r"""Recursive helper to modify gpytorch modules' batch shape attribute.
230+
231+
Modifies the module in-place.
232+
233+
Args:
234+
module: The module to be modified.
235+
names: The list of names to access the attribute. If the full name of
236+
the module is `"module.sub_module.leaf_module"`, this will be
237+
`["sub_module", "leaf_module"]`.
238+
b: The new size of the last element of the module's `batch_shape`
239+
attribute.
240+
"""
241+
if len(names) == 0:
242+
return
243+
m = getattr(module, names[0])
244+
if len(names) == 1 and hasattr(m, "batch_shape") and len(m.batch_shape) > 0:
245+
m.batch_shape = m.batch_shape[:-1] + torch.Size([b] if b > 0 else [])
246+
else:
247+
mod_batch_shape(module=m, names=names[1:], b=b)

test/models/test_deterministic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ def f(X):
3939
self.assertEqual(model.num_outputs, 2)
4040
p = model.posterior(X, output_indices=[0])
4141
self.assertTrue(torch.equal(p.mean, X[..., [0]]))
42+
# test subset output
43+
subset_model = model.subset_output([0])
44+
self.assertIsInstance(subset_model, GenericDeterministicModel)
45+
p_sub = subset_model.posterior(X)
46+
self.assertTrue(torch.equal(p_sub.mean, X[..., [0]]))
4247

4348
def test_AffineDeterministicModel(self):
4449
# test error on bad shape of a
@@ -65,3 +70,10 @@ def test_AffineDeterministicModel(self):
6570
p = model.posterior(X)
6671
mean_exp = model.b + (X.unsqueeze(-1) * a).sum(dim=-2)
6772
self.assertTrue(torch.equal(p.mean, mean_exp))
73+
# test subset output
74+
X = torch.rand(4, 3)
75+
subset_model = model.subset_output([0])
76+
self.assertIsInstance(subset_model, AffineDeterministicModel)
77+
p = model.posterior(X)
78+
p_sub = subset_model.posterior(X)
79+
self.assertTrue(torch.equal(p_sub.mean, p.mean[..., [0]]))

test/models/test_gp_regression.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,27 @@ def test_fantasize(self):
250250
fm = model.fantasize(X=X_f, sampler=sampler, observation_noise=False)
251251
self.assertIsInstance(fm, model.__class__)
252252

253+
def test_subset_model(self):
254+
for batch_shape, dtype in itertools.product(
255+
(torch.Size(), torch.Size([2])), (torch.float, torch.double)
256+
):
257+
tkwargs = {"device": self.device, "dtype": dtype}
258+
model, model_kwargs = self._get_model_and_data(
259+
batch_shape=batch_shape, m=2, **tkwargs
260+
)
261+
subset_model = model.subset_output([0])
262+
X = torch.rand(torch.Size(batch_shape + torch.Size([3, 1])), **tkwargs)
263+
p = model.posterior(X)
264+
p_sub = subset_model.posterior(X)
265+
self.assertTrue(
266+
torch.allclose(p_sub.mean, p.mean[..., [0]], atol=1e-4, rtol=1e-4)
267+
)
268+
self.assertTrue(
269+
torch.allclose(
270+
p_sub.variance, p.variance[..., [0]], atol=1e-4, rtol=1e-4
271+
)
272+
)
273+
253274

254275
class TestFixedNoiseGP(TestSingleTaskGP):
255276
def _get_model_and_data(self, batch_shape, m, outcome_transform=None, **tkwargs):
@@ -324,6 +345,10 @@ def test_fantasize(self):
324345
with self.assertRaises(NotImplementedError):
325346
super().test_fantasize()
326347

348+
def test_subset_model(self):
349+
with self.assertRaises(NotImplementedError):
350+
super().test_subset_model()
351+
327352

328353
def _get_pvar_expected(posterior, model, X, m):
329354
lh_kwargs = {}

test/models/test_gpytorch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ def test_gpytorch_model(self):
116116
)
117117
self.assertIsInstance(cm, SimpleGPyTorchModel)
118118
self.assertEqual(cm.train_targets.shape, torch.Size([7]))
119+
# test subset_output
120+
with self.assertRaises(NotImplementedError):
121+
model.subset_output([0])
119122
# test fantasize
120123
sampler = SobolQMCNormalSampler(num_samples=2)
121124
cm = model.fantasize(torch.rand(2, 1, **tkwargs), sampler=sampler)
@@ -191,6 +194,9 @@ def test_batched_multi_output_gpytorch_model(self):
191194
)
192195
self.assertIsInstance(posterior, GPyTorchPosterior)
193196
self.assertEqual(posterior.mean.shape, torch.Size([2, 2]))
197+
# test subset_output
198+
with self.assertRaises(NotImplementedError):
199+
model.subset_output([0])
194200
# test conditioning on observations
195201
cm = model.condition_on_observations(
196202
torch.rand(2, 1, **tkwargs), torch.rand(2, 2, **tkwargs)

test/models/test_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,5 @@ def test_not_so_abstract_base_model(self):
2424
model.condition_on_observations(None, None)
2525
with self.assertRaises(NotImplementedError):
2626
model.num_outputs
27+
with self.assertRaises(NotImplementedError):
28+
model.subset_output([0])

0 commit comments

Comments
 (0)