
Commit fc2053b

LMC multitask-SVGP models can output a single task per input. (#1769)
* LMC multitask-SVGP models can output a single task per input. If one defines an ApproximateGP model with an LMCVariationalStrategy, there are now two options for the return type:
  1. Calling `model(x)` returns a `... x N x num_tasks` MultitaskMultivariateNormal distribution.
  2. Calling `model(x, task_indices=i)` returns a `... x N` MultivariateNormal distribution, where `i` corresponds to the selected task index for each input.
  [Closes #1285, #1433] [Addresses #1743, #1765]
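For context, a minimal usage sketch of the two call modes (illustrative only, not part of the commit; it assumes a trained ApproximateGP `model` whose variational strategy is an LMCVariationalStrategy over 4 tasks, and the names `test_x` and `test_i` are hypothetical):

import torch

test_x = torch.randn(100, 2)                   # ... x N x D test inputs
test_i = torch.randint(0, 4, (100,))           # one task index per input

all_tasks = model(test_x)                      # MultitaskMultivariateNormal with event shape 100 x 4
one_task = model(test_x, task_indices=test_i)  # MultivariateNormal with event shape 100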
1 parent f06004e commit fc2053b

File tree

7 files changed: +527 -80 lines changed


docs/source/variational.rst

Lines changed: 9 additions & 5 deletions
@@ -98,19 +98,23 @@ These are special :obj:`~gpytorch.variational._VariationalStrategy` objects that
 :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distributions. Each of these objects
 acts on a batch of approximate GPs.
 
-
-:hidden:`IndependentMultitaskVariationalStrategy`
+:hidden:`LMCVariationalStrategy`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-.. autoclass:: IndependentMultitaskVariationalStrategy
+.. autoclass:: LMCVariationalStrategy
    :members:
 
-:hidden:`LMCVariationalStrategy`
+   .. automethod:: __call__
+
+
+:hidden:`IndependentMultitaskVariationalStrategy`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-.. autoclass:: LMCVariationalStrategy
+.. autoclass:: IndependentMultitaskVariationalStrategy
    :members:
 
+   .. automethod:: __call__
+
 
 Variational Distributions
 -----------------------------

examples/04_Variational_and_Approximate_GPs/SVGP_Multitask_GP_Regression.ipynb

Lines changed: 72 additions & 6 deletions
Large diffs are not rendered by default.

gpytorch/variational/independent_multitask_variational_strategy.py

Lines changed: 48 additions & 15 deletions
@@ -2,17 +2,23 @@
 
 import warnings
 
-from ..distributions import MultitaskMultivariateNormal
+import torch
+
+from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
+from ..lazy import RootLazyTensor
 from ..module import Module
 from ._variational_strategy import _VariationalStrategy
 
 
 class IndependentMultitaskVariationalStrategy(_VariationalStrategy):
     """
     IndependentMultitaskVariationalStrategy wraps an existing
-    :obj:`~gpytorch.variational.VariationalStrategy`
-    to produce a :obj:`~gpytorch.variational.MultitaskMultivariateNormal` distribution.
-    All outputs will be independent of one another.
+    :obj:`~gpytorch.variational.VariationalStrategy` to produce vector-valued (multi-task)
+    output distributions. Each task will be independent of one another.
+
+    The output will either be a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distribution
+    (if we wish to evaluate all tasks for each input) or a :obj:`~gpytorch.distributions.MultivariateNormal`
+    (if we wish to evaluate a single task for each input).
 
     The base variational strategy is assumed to operate on a batch of GPs. One of the batch
     dimensions corresponds to the multiple tasks.
@@ -43,19 +49,46 @@ def variational_params_initialized(self):
     def kl_divergence(self):
         return super().kl_divergence().sum(dim=-1)
 
-    def __call__(self, x, prior=False, **kwargs):
+    def __call__(self, x, task_indices=None, prior=False, **kwargs):
+        r"""
+        See :class:`LMCVariationalStrategy`.
+        """
         function_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
-        if (
-            self.task_dim > 0
-            and self.task_dim > len(function_dist.batch_shape)
-            or self.task_dim < 0
-            and self.task_dim + len(function_dist.batch_shape) < 0
-        ):
-            return MultitaskMultivariateNormal.from_repeated_mvn(function_dist, num_tasks=self.num_tasks)
+
+        if task_indices is None:
+            # Every data point will get an output for each task
+            if (
+                self.task_dim > 0
+                and self.task_dim > len(function_dist.batch_shape)
+                or self.task_dim < 0
+                and self.task_dim + len(function_dist.batch_shape) < 0
+            ):
+                return MultitaskMultivariateNormal.from_repeated_mvn(function_dist, num_tasks=self.num_tasks)
+            else:
+                function_dist = MultitaskMultivariateNormal.from_batch_mvn(function_dist, task_dim=self.task_dim)
+                assert function_dist.event_shape[-1] == self.num_tasks
+                return function_dist
+
         else:
-            function_dist = MultitaskMultivariateNormal.from_batch_mvn(function_dist, task_dim=self.task_dim)
-            assert function_dist.event_shape[-1] == self.num_tasks
-            return function_dist
+            # Each data point will get a single output corresponding to a single task
+
+            if self.task_dim > 0:
+                raise RuntimeError(f"task_dim must be a negative indexed batch dimension: got {self.task_dim}.")
+            num_batch = len(function_dist.batch_shape)
+            task_dim = num_batch + self.task_dim
+
+            # Reshape task_indices to match the batch shape of the function distribution
+            shape = list(function_dist.batch_shape + function_dist.event_shape)
+            shape[task_dim] = 1
+            task_indices = task_indices.expand(shape).squeeze(task_dim)
+
+            # Create a mask to choose specific task assignment
+            task_mask = torch.nn.functional.one_hot(task_indices, num_classes=self.num_tasks)
+            task_mask = task_mask.permute(*range(0, task_dim), *range(task_dim + 1, num_batch + 1), task_dim)
+
+            mean = (function_dist.mean * task_mask).sum(task_dim)
+            covar = (function_dist.lazy_covariance_matrix * RootLazyTensor(task_mask[..., None])).sum(task_dim)
+            return MultivariateNormal(mean, covar)
 
 
 class MultitaskVariationalStrategy(IndependentMultitaskVariationalStrategy):
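As a side note, a small self-contained sketch (illustrative only, not part of the commit) of the one-hot masking trick the single-task branch above relies on; shapes and names are made up for the example:

import torch

num_tasks, N = 4, 5
means = torch.randn(num_tasks, N)              # per-task means with the task dim first (task_dim = -2)
task_indices = torch.tensor([0, 3, 1, 1, 2])   # one task index per data point

# One-hot over tasks, then move the task dimension in front of the data dimension to match `means`
task_mask = torch.nn.functional.one_hot(task_indices, num_classes=num_tasks)  # N x num_tasks
task_mask = task_mask.permute(1, 0)                                           # num_tasks x N

# Multiplying by the mask and summing over the task dim picks out means[task_indices[n], n]
selected_mean = (means * task_mask).sum(dim=0)                                # shape: N
assert torch.equal(selected_mean, means[task_indices, torch.arange(N)])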

gpytorch/variational/lmc_variational_strategy.py

Lines changed: 116 additions & 36 deletions
@@ -2,12 +2,34 @@
 
 import torch
 
-from ..distributions import MultitaskMultivariateNormal
-from ..lazy import KroneckerProductLazyTensor, MatmulLazyTensor
+from .. import settings
+from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
+from ..lazy import KroneckerProductLazyTensor, RootLazyTensor
 from ..module import Module
+from ..utils.broadcasting import _mul_broadcast_shape
+from ..utils.interpolation import left_interp
 from ._variational_strategy import _VariationalStrategy
 
 
+def _select_lmc_coefficients(lmc_coefficients: torch.Tensor, indices: torch.LongTensor) -> torch.Tensor:
+    """
+    Given a list of indices for ... x N datapoints,
+    select the row from lmc_coefficients that corresponds to each datapoint
+
+    lmc_coefficients: torch.Tensor ... x num_latents x ... x num_tasks
+    indices: torch.Tensor ... x N
+    """
+    batch_shape = _mul_broadcast_shape(lmc_coefficients.shape[:-1], indices.shape[:-1])
+
+    # We will use the left_interp helper to do the indexing
+    lmc_coefficients = lmc_coefficients.expand(*batch_shape, lmc_coefficients.shape[-1])[..., None]
+    indices = indices.expand(*batch_shape, indices.shape[-1])[..., None]
+    res = left_interp(
+        indices, torch.ones(indices.shape, dtype=torch.long, device=indices.device), lmc_coefficients,
+    ).squeeze(-1)
+    return res
+
+
 class LMCVariationalStrategy(_VariationalStrategy):
     r"""
     LMCVariationalStrategy is an implementation of the "Linear Model of Coregionalization"
@@ -20,8 +42,11 @@ class LMCVariationalStrategy(_VariationalStrategy):
 
         f_{\text{task } i}( \mathbf x) = \sum_{q=1}^Q a_i^{(q)} g^{(q)} ( \mathbf x )
 
-    LMCVariationalStrategy wraps an existing :obj:`~gpytorch.variational.VariationalStrategy`
-    to produce a :obj:`~gpytorch.variational.MultitaskMultivariateNormal` distribution.
+    LMCVariationalStrategy wraps an existing :obj:`~gpytorch.variational.VariationalStrategy`.
+    The output will either be a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distribution
+    (if we wish to evaluate all tasks for each input) or a :obj:`~gpytorch.distributions.MultivariateNormal`
+    (if we wish to evaluate a single task for each input).
+
     The base variational strategy is assumed to operate on a multi-batch of GPs, where one
     of the batch dimensions corresponds to the latent function dimension.
 
@@ -35,13 +60,6 @@
     batch shape. This would correspond to each of the latent functions having different kernels
     or the same kernel, respectively.
 
-    :param ~gpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
-    :param int num_tasks: The total number of tasks (output functions)
-    :param int num_latents: The total number of latent functions in each group
-    :param latent_dim: (Default: -1) Which batch dimension corresponds to the latent function batch.
-        **Must be negative indexed**
-    :type latent_dim: `int` < 0
-
     Example:
         >>> class LMCMultitaskGP(gpytorch.models.ApproximateGP):
         >>>     '''
@@ -74,7 +92,13 @@
         >>>         batch_shape=torch.Size([3]),
         >>>     )
         >>>
-        >>> # Model output: n x 5
+
+    :param ~gpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
+    :param int num_tasks: The total number of tasks (output functions)
+    :param int num_latents: The total number of latent functions in each group
+    :param latent_dim: (Default: -1) Which batch dimension corresponds to the latent function batch.
+        **Must be negative indexed**
+    :type latent_dim: `int` < 0
     """
 
     def __init__(
@@ -120,28 +144,84 @@ def variational_params_initialized(self):
     def kl_divergence(self):
         return super().kl_divergence().sum(dim=self.latent_dim)
 
-    def __call__(self, x, prior=False, **kwargs):
-        function_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
-        lmc_coefficients = self.lmc_coefficients.expand(*function_dist.batch_shape, self.lmc_coefficients.size(-1))
-        num_batch = len(function_dist.batch_shape)
-        num_dim = num_batch + len(function_dist.event_shape)
-        latent_dim = num_batch + self.latent_dim if self.latent_dim is not None else None
-
-        # Mean
-        mean = function_dist.mean.permute(*range(0, latent_dim), *range(latent_dim + 1, num_dim), latent_dim)
-        mean = mean @ lmc_coefficients.permute(
-            *range(0, latent_dim), *range(latent_dim + 1, num_dim - 1), latent_dim, -1
-        )
-
-        # Covar
-        covar = function_dist.lazy_covariance_matrix
-        lmc_factor = MatmulLazyTensor(lmc_coefficients.unsqueeze(-1), lmc_coefficients.unsqueeze(-2))
-        covar = KroneckerProductLazyTensor(covar, lmc_factor)
-        covar = covar.sum(latent_dim)
-
-        # Add a bit of jitter to make the covar PD
-        covar = covar.add_jitter(1e-6)
-
-        # Done!
-        function_dist = MultitaskMultivariateNormal(mean, covar)
+    def __call__(self, x, task_indices=None, prior=False, **kwargs):
+        r"""
+        Computes the variational (or prior) distribution
+        :math:`q( \mathbf f \mid \mathbf X)` (or :math:`p( \mathbf f \mid \mathbf X)`).
+        There are two modes:
+
+        1. Compute **all tasks** for all inputs.
+           If this is the case, the :attr:`task_indices` attribute should be None.
+           The return type will be a (... x N x num_tasks)
+           :class:`~gpytorch.distributions.MultitaskMultivariateNormal`.
+        2. Compute **one task** per input.
+           If this is the case, the (... x N) :attr:`task_indices` tensor should contain
+           the indices of each input's assigned task.
+           The return type will be a (... x N)
+           :class:`~gpytorch.distributions.MultivariateNormal`.
+
+        :param x: Input locations to evaluate variational strategy
+        :type x: torch.Tensor (... x N x D)
+        :param task_indices: (Default: None) Task index associated with each input.
+            If this **is not** provided, then the returned distribution evaluates every input on every task
+            (returns :class:`~gpytorch.distributions.MultitaskMultivariateNormal`).
+            If this **is** provided, then the returned distribution evaluates each input only on its assigned task
+            (returns :class:`~gpytorch.distributions.MultivariateNormal`).
+        :type task_indices: torch.Tensor (... x N), optional
+        :param prior: (Default: False) If False, returns the variational distribution
+            :math:`q( \mathbf f \mid \mathbf X)`.
+            If True, returns the prior distribution
+            :math:`p( \mathbf f \mid \mathbf X)`.
+        :type prior: bool
+        :return: :math:`q( \mathbf f \mid \mathbf X)` (or the prior),
+            either for all tasks (if `task_indices == None`)
+            or for a specific task (if `task_indices != None`).
+        :rtype: ~gpytorch.distributions.MultitaskMultivariateNormal (... x N x num_tasks)
+            or ~gpytorch.distributions.MultivariateNormal (... x N)
+        """
+        latent_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
+        num_batch = len(latent_dist.batch_shape)
+        latent_dim = num_batch + self.latent_dim
+
+        if task_indices is None:
+            num_dim = num_batch + len(latent_dist.event_shape)
+
+            # Every data point will get an output for each task
+            # Therefore, we will set up the lmc_coefficients shape for a matmul
+            lmc_coefficients = self.lmc_coefficients.expand(*latent_dist.batch_shape, self.lmc_coefficients.size(-1))
+
+            # Mean: ... x N x num_tasks
+            latent_mean = latent_dist.mean.permute(*range(0, latent_dim), *range(latent_dim + 1, num_dim), latent_dim)
+            mean = latent_mean @ lmc_coefficients.permute(
+                *range(0, latent_dim), *range(latent_dim + 1, num_dim - 1), latent_dim, -1
+            )
+
+            # Covar: ... x (N x num_tasks) x (N x num_tasks)
+            latent_covar = latent_dist.lazy_covariance_matrix
+            lmc_factor = RootLazyTensor(lmc_coefficients.unsqueeze(-1))
+            covar = KroneckerProductLazyTensor(latent_covar, lmc_factor).sum(latent_dim)
+            # Add a bit of jitter to make the covar PD
+            covar = covar.add_jitter(settings.cholesky_jitter.value(dtype=mean.dtype))
+
+            # Done!
+            function_dist = MultitaskMultivariateNormal(mean, covar)
+
+        else:
+            # Each data point will get a single output corresponding to a single task
+            # Therefore, we will select the appropriate lmc coefficients for each task
+            lmc_coefficients = _select_lmc_coefficients(self.lmc_coefficients, task_indices)
+
+            # Mean: ... x N
+            mean = (latent_dist.mean * lmc_coefficients).sum(latent_dim)
+
+            # Covar: ... x N x N
+            latent_covar = latent_dist.lazy_covariance_matrix
+            lmc_factor = RootLazyTensor(lmc_coefficients.unsqueeze(-1))
+            covar = (latent_covar * lmc_factor).sum(latent_dim)
+            # Add a bit of jitter to make the covar PD
+            covar = covar.add_jitter(settings.cholesky_jitter.value(dtype=mean.dtype))
+
+            # Done!
+            function_dist = MultivariateNormal(mean, covar)
+
         return function_dist
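For intuition, a plain-tensor sketch (illustrative only, not part of the commit) of what the single-task branch computes: each data point's output mixes the latent functions with its task's row of the LMC coefficients. Shapes and names are made up for the example:

import torch

num_latents, num_tasks, N = 3, 4, 5
lmc_coefficients = torch.randn(num_latents, num_tasks)  # a_i^{(q)} for latent q and task i
latent_means = torch.randn(num_latents, N)              # means of g^{(q)}(x_n) from the base strategy
latent_covars = torch.eye(N).expand(num_latents, N, N)  # stand-in for the latent covariance matrices
task_indices = torch.tensor([0, 3, 1, 1, 2])            # task assignment per data point

# Select each data point's coefficient vector: num_latents x N
coeffs = lmc_coefficients[:, task_indices]

# Mean of f_{task i_n}(x_n) = sum_q a_{i_n}^{(q)} * E[g^{(q)}(x_n)]
single_task_mean = (latent_means * coeffs).sum(dim=0)   # shape: N

# Covariance: sum_q a_{i_m}^{(q)} a_{i_n}^{(q)} * Cov(g^{(q)}(x_m), g^{(q)}(x_n))
single_task_covar = (latent_covars * coeffs[:, :, None] * coeffs[:, None, :]).sum(dim=0)  # N x N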

test/examples/test_lmc_svgp_regression.py

Lines changed: 52 additions & 2 deletions
@@ -7,7 +7,7 @@
 
 import gpytorch
 import torch
-from gpytorch.likelihoods import MultitaskGaussianLikelihood
+from gpytorch.likelihoods import GaussianLikelihood, MultitaskGaussianLikelihood
 
 
 # Batch training test: Let's learn hyperparameters on a sine dataset, but test on a sine dataset and a cosine dataset
@@ -75,7 +75,6 @@ def tearDown(self):
         torch.set_rng_state(self.rng_state)
 
     def test_train_and_eval(self):
-        # We're manually going to set the hyperparameters to something they shouldn't be
         likelihood = MultitaskGaussianLikelihood(num_tasks=4)
         model = LMCModel()
 
@@ -132,6 +131,57 @@ def test_train_and_eval(self):
         self.assertEqual(lower.shape, train_y.shape)
         self.assertEqual(upper.shape, train_y.shape)
 
+    def test_indexed_train_and_eval(self):
+        likelihood = GaussianLikelihood()
+        model = LMCModel()
+
+        # Find optimal model hyperparameters
+        model.train()
+        likelihood.train()
+        optimizer = torch.optim.Adam([
+            {'params': model.parameters()},
+            {'params': likelihood.parameters()},
+        ], lr=0.01)
+
+        # Our loss object. We're using the VariationalELBO, which essentially just computes the ELBO
+        mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))
+
+        # Create some task indices
+        arange = torch.arange(train_x.size(0))
+        train_i = torch.rand(train_x.size(0)).mul(4).floor().long()
+
+        # We use more CG iterations here because the preconditioner introduced in the NeurIPS paper seems to be less
+        # effective for VI.
+        for i in range(400):
+            # Within each iteration, we will go over each minibatch of data
+            optimizer.zero_grad()
+            output = model(train_x, task_indices=train_i)
+            loss = -mll(output, train_y[arange, train_i])
+            loss.backward()
+            optimizer.step()
+
+        for param in model.parameters():
+            self.assertTrue(param.grad is not None)
+            self.assertGreater(param.grad.norm().item(), 0)
+        for param in likelihood.parameters():
+            self.assertTrue(param.grad is not None)
+            self.assertGreater(param.grad.norm().item(), 0)
+
+        # Test the model
+        model.eval()
+        likelihood.eval()
+
+        # Make predictions for the test points, and check the MAE
+        with torch.no_grad(), gpytorch.settings.max_eager_kernel_size(1):
+            predictions = likelihood(model(train_x, task_indices=train_i))
+            mean_abs_error = torch.mean(torch.abs(train_y[arange, train_i] - predictions.mean))
+            self.assertLess(mean_abs_error.squeeze().item(), 0.15)
+
+            # Smoke test for getting predictive uncertainties
+            lower, upper = predictions.confidence_region()
+            self.assertEqual(lower.shape, train_i.shape)
+            self.assertEqual(upper.shape, train_i.shape)
+
 
 if __name__ == "__main__":
     unittest.main()

0 commit comments
