Skip to content

Commit 11ed41c

Browse files
Merge pull request #2692 from kayween/fast-variational
Speed Up Variational Strategy
2 parents 60be953 + f29ed7c commit 11ed41c

15 files changed

+539
-40
lines changed

gpytorch/variational/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
IndependentMultitaskVariationalStrategy,
1313
MultitaskVariationalStrategy,
1414
)
15+
from .large_batch_variational_strategy import LargeBatchVariationalStrategy
1516
from .lmc_variational_strategy import LMCVariationalStrategy
1617
from .mean_field_variational_distribution import MeanFieldVariationalDistribution
1718
from .natural_variational_distribution import _NaturalVariationalDistribution, NaturalVariationalDistribution
@@ -29,6 +30,7 @@
2930
"GridInterpolationVariationalStrategy",
3031
"IndependentMultitaskVariationalStrategy",
3132
"LMCVariationalStrategy",
33+
"LargeBatchVariationalStrategy",
3234
"MultitaskVariationalStrategy",
3335
"OrthogonallyDecoupledVariationalStrategy",
3436
"VariationalStrategy",

gpytorch/variational/_variational_strategy.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def forward(
129129
inducing_points: Tensor,
130130
inducing_values: Tensor,
131131
variational_inducing_covar: Optional[LinearOperator] = None,
132+
diag: bool = True,
132133
**kwargs,
133134
) -> MultivariateNormal:
134135
r"""
@@ -146,6 +147,12 @@ def forward(
146147
the distribution :math:`q(\mathbf u)` is
147148
Gaussian, then this variable is the covariance matrix of that Gaussian.
148149
Otherwise, it will be None.
150+
:param diag: If true and this module is in train mode, this method is allowed to skip the off-diagonal entries
151+
in the predictive covariance and only compute the predictive variance, whenever it's deemed more efficient
152+
by the underlying implementation. In that case, the off-diagonal entries in the covariance matrix of the
153+
returned :class:`~gpytorch.distributions.MultivariateNormal` could be arbitrary dummy values. If this
154+
argument is false, then this method computes the full covariance matrix even in train mode. This argument
155+
is ignored if this module is in eval mode, in which case the full covariance matrix is always computed.
149156
150157
:rtype: :obj:`~gpytorch.distributions.MultivariateNormal`
151158
:return: The distribution :math:`q( \mathbf f(\mathbf X))`
@@ -320,14 +327,21 @@ def get_fantasy_model(
320327
fantasy_model.prediction_strategy = fant_pred_strat
321328
return fantasy_model
322329

323-
def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
330+
def __call__(self, x: Tensor, prior: bool = False, diag: bool = True, **kwargs) -> MultivariateNormal:
324331
# If we're in prior mode, then we're done!
325332
if prior:
326-
return self.model.forward(x, **kwargs)
333+
if isinstance(self.model, _VariationalStrategy):
334+
# If the model is itself a variational strategy, we need to force it to compute the full covariance in
335+
# case that the model is in train mode.
336+
return self.model.forward(x, diag=False, **kwargs)
337+
else:
338+
# Otherwise, the model is `ApproximateGP`. So we can just call forward.
339+
return self.model.forward(x, **kwargs)
327340

328341
# Delete previously cached items from the training distribution
329342
if self.training:
330343
self._clear_cache()
344+
331345
# (Maybe) initialize variational distribution
332346
if not self.variational_params_initialized.item():
333347
prior_dist = self.prior_distribution
@@ -349,11 +363,17 @@ def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNorm
349363
inducing_points,
350364
inducing_values=variational_dist_u.mean,
351365
variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
366+
diag=diag,
352367
**kwargs,
353368
)
354369
elif isinstance(variational_dist_u, Delta):
355370
return super().__call__(
356-
x, inducing_points, inducing_values=variational_dist_u.mean, variational_inducing_covar=None, **kwargs
371+
x,
372+
inducing_points,
373+
inducing_values=variational_dist_u.mean,
374+
variational_inducing_covar=None,
375+
diag=diag,
376+
**kwargs,
357377
)
358378
else:
359379
raise RuntimeError(

gpytorch/variational/additive_grid_interpolation_variational_strategy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def forward(
6060
inducing_points: Tensor,
6161
inducing_values: Tensor,
6262
variational_inducing_covar: Optional[LinearOperator] = None,
63-
*params,
63+
diag: bool = True,
6464
**kwargs,
6565
) -> MultivariateNormal:
6666
if x.ndimension() == 1:
@@ -72,7 +72,7 @@ def forward(
7272
if num_dim != self.num_dim:
7373
raise RuntimeError("The number of dims should match the number specified.")
7474

75-
output = super().forward(x, inducing_points, inducing_values, variational_inducing_covar)
75+
output = super().forward(x, inducing_points, inducing_values, variational_inducing_covar, diag=diag)
7676
if self.sum_output:
7777
if variational_inducing_covar is not None:
7878
mean = output.mean.sum(0)

gpytorch/variational/batch_decoupled_variational_strategy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def forward(
181181
inducing_points: Tensor,
182182
inducing_values: Tensor,
183183
variational_inducing_covar: Optional[LinearOperator] = None,
184+
diag: bool = True,
184185
**kwargs,
185186
) -> MultivariateNormal:
186187
# We'll compute the covariance, and cross-covariance terms for both the

gpytorch/variational/ciq_variational_strategy.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,12 @@ def forward(
195195
inducing_points: torch.Tensor,
196196
inducing_values: torch.Tensor,
197197
variational_inducing_covar: Optional[LinearOperator] = None,
198-
*params,
198+
diag: bool = True,
199199
**kwargs,
200200
) -> MultivariateNormal:
201201
# Compute full prior distribution
202202
full_inputs = torch.cat([inducing_points, x], dim=-2)
203-
full_output = self.model.forward(full_inputs, *params, **kwargs)
203+
full_output = self.model.forward(full_inputs, **kwargs)
204204
full_covar = full_output.lazy_covariance_matrix
205205

206206
# Covariance terms
@@ -275,7 +275,7 @@ def kl_divergence(self) -> Tensor:
275275
else:
276276
return super().kl_divergence()
277277

278-
def __call__(self, x: torch.Tensor, prior: bool = False, *params, **kwargs) -> MultivariateNormal:
278+
def __call__(self, x: torch.Tensor, prior: bool = False, diag: bool = True, **kwargs) -> MultivariateNormal:
279279
# This is mostly the same as _VariationalStrategy.__call__()
280280
# but with special rules for natural gradient descent (to prevent O(M^3) computation)
281281

@@ -313,7 +313,7 @@ def __call__(self, x: torch.Tensor, prior: bool = False, *params, **kwargs) -> M
313313
inducing_points,
314314
inducing_values=None,
315315
variational_inducing_covar=None,
316-
*params,
316+
diag=diag,
317317
**kwargs,
318318
)
319319
else:
@@ -327,6 +327,7 @@ def __call__(self, x: torch.Tensor, prior: bool = False, *params, **kwargs) -> M
327327
inducing_points,
328328
inducing_values=variational_dist_u.mean,
329329
variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
330+
diag=diag,
330331
**kwargs,
331332
)
332333
elif isinstance(variational_dist_u, Delta):
@@ -336,6 +337,7 @@ def __call__(self, x: torch.Tensor, prior: bool = False, *params, **kwargs) -> M
336337
inducing_points,
337338
inducing_values=variational_dist_u.mean,
338339
variational_inducing_covar=None,
340+
diag=diag,
339341
**kwargs,
340342
)
341343
else:

gpytorch/variational/grid_interpolation_variational_strategy.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
#!/usr/bin/env python3
22

3+
from typing import Optional
4+
35
import torch
6+
from linear_operator import LinearOperator
47
from linear_operator.operators import InterpolatedLinearOperator
58
from linear_operator.utils.interpolation import left_interp
9+
from torch import Tensor
610

711
from ..distributions import MultivariateNormal
812
from ..utils.interpolation import Interpolation
@@ -77,7 +81,15 @@ def prior_distribution(self):
7781
res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(1e-3))
7882
return res
7983

80-
def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None):
84+
def forward(
85+
self,
86+
x: Tensor,
87+
inducing_points: Tensor,
88+
inducing_values: Tensor,
89+
variational_inducing_covar: Optional[LinearOperator] = None,
90+
diag: bool = True,
91+
**kwargs,
92+
):
8193
if variational_inducing_covar is None:
8294
raise RuntimeError(
8395
"GridInterpolationVariationalStrategy is only compatible with Gaussian variational "

gpytorch/variational/independent_multitask_variational_strategy.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#!/usr/bin/env python3
22

33
import warnings
4+
from typing import Optional
45

56
import torch
67
from linear_operator.operators import RootLinearOperator
8+
from torch import LongTensor, Tensor
79

810
from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
911
from ..module import Module
@@ -49,11 +51,19 @@ def variational_params_initialized(self):
4951
def kl_divergence(self):
5052
return super().kl_divergence().sum(dim=-1)
5153

52-
def __call__(self, x, task_indices=None, prior=False, **kwargs):
54+
def __call__(
55+
self,
56+
x: Tensor,
57+
task_indices: Optional[LongTensor] = None,
58+
prior: bool = False,
59+
diag: bool = True,
60+
**kwargs,
61+
):
5362
r"""
5463
See :class:`LMCVariationalStrategy`.
5564
"""
56-
function_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
65+
# Compute the full covariance because we might use the off-diagonal entries below
66+
function_dist = self.base_variational_strategy(x, prior=prior, diag=False, **kwargs)
5767

5868
if task_indices is None:
5969
# Every data point will get an output for each task
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
from typing import Optional, Tuple

import torch
from linear_operator.operators import DiagLinearOperator, LinearOperator, MatmulLinearOperator
from torch import Tensor

from gpytorch.variational.variational_strategy import VariationalStrategy
7+
8+
9+
class QuadFormDiagonal(torch.autograd.Function):
    r"""Autograd function returning ``torch.diag(B' A B)`` for a symmetric matrix ``A``.

    Compared with PyTorch's default autograd engine, the hand-written backward pass skips one
    large matmul, which pays off when ``B`` has far more columns than rows.
    """

    @staticmethod
    def forward(ctx, matrix: Tensor, rhs: Tensor):
        r"""Compute the quadratic-form diagonal without materializing ``B' A B``.

        :param matrix: A symmetric matrix of size `(..., M, M)`.
        :param rhs: The right-hand side of size `(..., M, N)`.

        :return: The diagonal of the quadratic form, of size `(..., N)`.
        """
        # diag(B' A B)_j = b_j' (A b_j): only column-wise dot products are formed.
        matrix_rhs = matrix @ rhs

        # `matrix` itself is never needed by the backward pass; stash only these two.
        ctx.save_for_backward(rhs, matrix_rhs)

        return (rhs * matrix_rhs).sum(dim=-2)

    @staticmethod
    def backward(ctx, d_diag: Tensor):
        rhs, matrix_rhs = ctx.saved_tensors

        # d/dA of sum_j w_j b_j' A b_j is B diag(w) B', where w is the incoming gradient.
        grad_matrix = rhs @ (d_diag.unsqueeze(-1) * rhs.mT)
        # d/dB: column j receives 2 w_j (A b_j) — valid because `A` is assumed symmetric.
        grad_rhs = 2.0 * matrix_rhs * d_diag.unsqueeze(-2)

        return grad_matrix, grad_rhs
40+
41+
42+
class LargeBatchVariationalStrategy(VariationalStrategy):
    r"""A fast variational strategy implementation optimized for large batch stochastic training on data center GPUs.

    This implementation has two assumptions on the use case:

    1. FP64 operations (in particular triangular solve and matmul) on data center GPUs are not much slower than FP32;
    2. The batch size is very large while the number of inducing points is moderate.

    This implementation speeds up the standard `VariationalStrategy` in two ways:

    1. Group the middle term `K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}` when computing the predictive covariance, which saves a
       large triangular solve in the forward pass;
    2. Use a custom autograd function computing the diagonal of `K_XZ @ middle_term @ K_ZX` in train mode, which saves
       a large matmul in the backward pass.

    NOTE: Grouping the middle term is not numerically friendly, and thus we have to use double precision to stabilize
    the computation. As a result, this implementation is expected to be slow on CPUs and consumer GPUs. Those who use
    CPUs and consumer cards should use `VariationalStrategy` instead.
    """

    def _compute_predictive_updates(
        self,
        chol: LinearOperator,
        induc_data_covar: Tensor,
        inducing_values: Tensor,
        # `Optional[...]` (rather than PEP 604 `X | None`) keeps this annotation consistent with the
        # other variational strategies and importable without `from __future__ import annotations`.
        variational_inducing_covar: Optional[LinearOperator],
        prior_covar: LinearOperator,
        diag: bool = True,
    ) -> Tuple[Tensor, LinearOperator]:
        r"""Compute the updates to the prior predictive mean and covariance.

        :param chol: The lower-triangular Cholesky factor `K_ZZ^{1/2}` of size `(..., M, M)`.
        :param induc_data_covar: The inducing/data cross covariance `K_ZX` of size `(..., M, N)`.
        :param inducing_values: The variational mean `m` of size `(..., M)`.
        :param variational_inducing_covar: The variational covariance `S`, or None (e.g. for a Delta
            variational distribution), in which case `S` is treated as zero below.
        :param prior_covar: The covariance subtracted from `S` in the grouped middle term — the `I`
            in `K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}`.
        :param diag: If true and this module is in train mode, only the predictive variance is
            computed (via :class:`QuadFormDiagonal`); the returned covariance update is then a
            `DiagLinearOperator` whose off-diagonal entries are meaningless. Otherwise the full
            covariance update is computed.

        :return: A tuple of the mean update `K_XZ K_ZZ^{-T/2} m` of size `(..., N)` and the
            covariance update `K_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} K_ZX` as a `LinearOperator`.
        """
        dtype = induc_data_covar.dtype

        # Make `K_ZZ^{1/2}` dense because `TriangularLinearOperator` does not support solve with `left=False`.
        # All intermediates are promoted to FP64 because grouping the middle term is numerically delicate.
        chol = chol.to_dense().type(torch.float64)

        induc_data_covar = induc_data_covar.type(torch.float64)
        inducing_values = inducing_values.type(torch.float64)

        # The mean update `K_XZ K_ZZ^{-T/2} m`
        inv_chol_t_inducing_values = torch.linalg.solve_triangular(
            chol.mT, inducing_values.unsqueeze(-1), upper=True, left=True
        )
        mean_update = (induc_data_covar.mT @ inv_chol_t_inducing_values).squeeze(-1).type(dtype)

        # The grouped middle term `K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2}`
        middle_term = prior_covar.mul(-1).to_dense()
        if variational_inducing_covar is not None:
            middle_term = variational_inducing_covar.to_dense() + middle_term
        middle_term = middle_term.type(torch.float64)

        # Apply `K_ZZ^{-1/2}` on the right, then `K_ZZ^{-T/2}` on the left, via one-sided triangular solves.
        middle_term = torch.linalg.solve_triangular(chol, middle_term, upper=False, left=False)
        middle_term = torch.linalg.solve_triangular(chol.mT, middle_term, upper=True, left=True)

        # The covariance update `K_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} K_ZX`
        if diag and self.training:
            # The custom autograd function has a faster backward pass, but it doesn't compute the off-diagonal entries.
            variance_update = QuadFormDiagonal.apply(middle_term, induc_data_covar)
            covar_update = DiagLinearOperator(diag=variance_update.type(dtype))
        else:
            covar_update = MatmulLinearOperator(
                induc_data_covar.mT.type(dtype), (middle_term @ induc_data_covar).type(dtype)
            )

        return mean_update, covar_update

gpytorch/variational/lmc_variational_strategy.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,13 @@ def kl_divergence(self) -> Tensor:
161161
return super().kl_divergence().sum(dim=self.latent_dim)
162162

163163
def __call__(
164-
self, x: Tensor, prior: bool = False, task_indices: Optional[LongTensor] = None, **kwargs
164+
self,
165+
x: Tensor,
166+
*,
167+
task_indices: Optional[LongTensor] = None,
168+
prior: bool = False,
169+
diag: bool = True,
170+
**kwargs,
165171
) -> Union[MultitaskMultivariateNormal, MultivariateNormal]:
166172
r"""
167173
Computes the variational (or prior) distribution
@@ -194,7 +200,7 @@ def __call__(
194200
:rtype: ~gpytorch.distributions.MultitaskMultivariateNormal (... x N x num_tasks)
195201
or ~gpytorch.distributions.MultivariateNormal (... x N)
196202
"""
197-
latent_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
203+
latent_dist = self.base_variational_strategy(x, prior=prior, diag=False, **kwargs)
198204
num_batch = len(latent_dist.batch_shape)
199205
latent_dim = num_batch + self.latent_dim
200206

gpytorch/variational/nearest_neighbor_variational_strategy.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def _cholesky_factor(
133133
return TriangularLinearOperator(L)
134134

135135
def __call__(
136-
self, x: Float[Tensor, "... N D"], prior: bool = False, **kwargs: Any
136+
self, x: Float[Tensor, "... N D"], prior: bool = False, diag: bool = True, **kwargs: Any
137137
) -> Float[MultivariateNormal, "... N"]:
138138
# If we're in prior mode, then we're done!
139139
if prior:
@@ -180,8 +180,11 @@ def forward(
180180
inducing_points: Float[Tensor, "... M D"],
181181
inducing_values: Float[Tensor, "... M"],
182182
variational_inducing_covar: Optional[Float[LinearOperator, "... M M"]] = None,
183+
diag: bool = True,
183184
**kwargs: Any,
184185
) -> Float[MultivariateNormal, "... N"]:
186+
# TODO: This method needs to return the full covariance in eval mode, not just the predictive variance.
187+
# TODO: Use `diag` to control when to compute the variance vs. covariance in train mode.
185188
if self.training:
186189
# In training mode, note that the full inducing points set = full training dataset
187190
# Users have the option to choose input None or a tensor of training data for x

0 commit comments

Comments
 (0)