
Commit 0907c95

Implement variance reduction in SLQ logdet backward pass. (#1836)
* Pivoted Cholesky is an autograd function
* Identity lazy tensor
* Change return type of _preconditioner function
* Differentiate through preconditioner logdet for variance reduction
* Deprecation warning for deterministic probes
* Add log and exp functions to IdentityLazyTensor
* Revert "Change return type of _preconditioner function" (this reverts commit d2bff48)
1 parent 5a0ff6b commit 0907c95

22 files changed: +1056 additions, -562 deletions
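For context on the variance reduction named in the title: GPyTorch's SLQ log-determinant estimator uses a partial pivoted Cholesky preconditioner, and the commit message suggests that gradients now also flow through the preconditioner's exact log-determinant term. Below is a hedged sketch of the underlying split, an assumption based on the commit message and the GPyTorch paper (Gardner et al., 2018), not code taken from this diff:

```latex
% Sketch (assumed form): split of the log-determinant under a preconditioner P.
% \hat{K} is the covariance matrix, P its pivoted-Cholesky-based preconditioner.
\log\det\hat{\mathbf K}
  \;=\; \underbrace{\log\det\mathbf P}_{\text{exact, differentiable in closed form}}
  \;+\; \underbrace{\log\det\bigl(\mathbf P^{-1}\hat{\mathbf K}\bigr)}_{\text{estimated with SLQ}}
```

The closer :math:`\mathbf P` is to :math:`\hat{\mathbf K}`, the smaller the stochastically estimated term and the lower the variance of the estimate and its gradient, which is presumably why this commit makes pivoted Cholesky an autograd function and differentiates through the preconditioner logdet.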

docs/source/functions.rst

Lines changed: 2 additions & 0 deletions
```diff
@@ -30,6 +30,8 @@ Functions
 
 .. autofunction:: log_normal_cdf
 
+.. autofunction:: pivoted_cholesky
+
 .. autofunction:: root_decomposition
 
 .. autofunction:: root_inv_decomposition
```

docs/source/utils.rst

Lines changed: 2 additions & 2 deletions
```diff
@@ -19,10 +19,10 @@ Lanczos Utilities
 .. automodule:: gpytorch.utils.lanczos
    :members:
 
-Pivoted Cholesky Utilities
+Permutation Utilities
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-.. automodule:: gpytorch.utils.pivoted_cholesky
+.. automodule:: gpytorch.utils.permutation
    :members:
 
 Quadrature Utilities
```

gpytorch/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -24,6 +24,7 @@
     log_normal_cdf,
     logdet,
     matmul,
+    pivoted_cholesky,
     root_decomposition,
     root_inv_decomposition,
 )
@@ -62,6 +63,7 @@
     "logdet",
     "log_normal_cdf",
     "matmul",
+    "pivoted_cholesky",
     "root_decomposition",
     "root_inv_decomposition",
     # Context managers
```

gpytorch/functions/__init__.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -167,6 +167,36 @@ def logdet(mat):
     return res
 
 
+def pivoted_cholesky(mat, rank, error_tol=None, return_pivots=None):
+    r"""
+    Performs a partial pivoted Cholesky factorization of the (positive definite) matrix:
+    :math:`\mathbf L \mathbf L^\top = \mathbf K`.
+    The partial pivoted Cholesky factor :math:`\mathbf L \in \mathbb R^{N \times \text{rank}}`
+    forms a low rank approximation to the matrix.
+
+    The pivots are selected greedily, corresponding to the maximum diagonal element in the
+    residual after each Cholesky iteration. See `Harbrecht et al., 2012`_.
+
+    :param mat: The matrix :math:`\mathbf K` to decompose
+    :type mat: ~gpytorch.lazy.LazyTensor or ~torch.Tensor
+    :param int rank: The size of the partial pivoted Cholesky factor.
+    :param error_tol: Defines an optional stopping criterion.
+        If the residual of the factorization is less than :attr:`error_tol`, then the
+        factorization will exit early. This will result in a :math:`\leq \text{rank}` factor.
+    :type error_tol: float, optional
+    :param bool return_pivots: (default: False) Whether or not to return the pivots alongside
+        the partial pivoted Cholesky factor.
+    :return: the `... x N x rank` factor (and optionally the `... x N` pivots)
+    :rtype: torch.Tensor or tuple(torch.Tensor, torch.Tensor)
+
+    .. _Harbrecht et al., 2012:
+        https://www.sciencedirect.com/science/article/pii/S0168927411001814
+    """
+    from ..lazy import lazify
+
+    return lazify(mat).pivoted_cholesky(rank=rank, error_tol=error_tol, return_pivots=return_pivots)
+
+
 def root_decomposition(mat):
     """
     Returns a (usually low-rank) root decomposition lazy tensor of a PSD matrix.
@@ -201,6 +231,7 @@ def root_inv_decomposition(mat, initial_vectors=None, test_vectors=None):
     "log_normal_cdf",
     "matmul",
     "normal_cdf",
+    "pivoted_cholesky",
     "root_decomposition",
     "root_inv_decomposition",
     # Deprecated
```
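As a quick illustration of the new public entry point documented above, here is a minimal usage sketch. It relies only on the signature and docstring in this hunk; the random test matrix and rank value are illustrative assumptions, not part of the diff:

```python
import torch
import gpytorch

# Build an arbitrary positive-definite test matrix (illustrative only).
A = torch.randn(100, 100)
K = A @ A.transpose(-1, -2) + 1e-2 * torch.eye(100)

# Rank-10 partial pivoted Cholesky factor: L has shape 100 x 10.
L = gpytorch.pivoted_cholesky(K, rank=10)

# Optionally also return the greedy pivot ordering (shape 100).
L, pivots = gpytorch.pivoted_cholesky(K, rank=10, return_pivots=True)

# L @ L^T is a low-rank approximation of K.
K_approx = L @ L.transpose(-1, -2)
```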

gpytorch/functions/_inv_quad_log_det.py

Lines changed: 0 additions & 268 deletions
This file was deleted.
