
Commit e4579ed

wjmaddox and Balandat authored
Add dtype registry for symeig (#1725)
* add dtype registry for symeig
* move to a linalg_dtypes operator
* Update gpytorch/test/lazy_tensor_test_case.py
* rename linalg dtypes

Co-authored-by: Max Balandat <[email protected]>
1 parent a6c5b02 commit e4579ed

File tree

5 files changed: +104, -61 lines

gpytorch/lazy/kronecker_product_added_diag_lazy_tensor.py

Lines changed: 9 additions & 8 deletions
@@ -118,16 +118,17 @@ def _solve(self, rhs, preconditioner=None, num_tridiag=0):

         rhs_dtype = rhs.dtype

+        # we perform the solve in double for numerical stability issues
+        symeig_dtype = settings._linalg_dtype_symeig.value()
+
         # if the diagonal is constant, we can solve this using the Kronecker-structured eigendecomposition
         # and performing a spectral shift of its eigenvalues
         if self._diag_is_constant:
-            # we perform the solve in double for numerical stability issues
-            # TODO: Use fp64 registry once #1213 is addressed
-            evals, q_matrix = self.lazy_tensor.to(torch.double).diagonalization()
-            evals_plus_diagonal = evals + self.diag_tensor.diag().double()
+            evals, q_matrix = self.lazy_tensor.to(symeig_dtype).diagonalization()
+            evals_plus_diagonal = evals + self.diag_tensor.diag().to(symeig_dtype)
             evals_root = evals_plus_diagonal.pow(0.5)
             inv_mat_sqrt = DiagLazyTensor(evals_root.reciprocal())
-            res = q_matrix.transpose(-2, -1).matmul(rhs.double())
+            res = q_matrix.transpose(-2, -1).matmul(rhs.to(symeig_dtype))
             res2 = inv_mat_sqrt.matmul(res)
             lazy_lhs = q_matrix.matmul(inv_mat_sqrt)
             return lazy_lhs.matmul(res2).type(rhs_dtype)
@@ -154,9 +155,9 @@ def _solve(self, rhs, preconditioner=None, num_tridiag=0):

         # again we perform the solve in double precision for numerical stability issues
         # TODO: Use fp64 registry once #1213 is addressed
-        rhs = rhs.double()
-        lt = self.lazy_tensor.to(torch.double)
-        dlt = self.diag_tensor.to(torch.double)
+        rhs = rhs.to(symeig_dtype)
+        lt = self.lazy_tensor.to(symeig_dtype)
+        dlt = self.diag_tensor.to(symeig_dtype)

         # If each of the diagonal factors is constant, life gets a little easier
         # as we can reuse the eigendecomposition
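For illustration, here is a minimal standalone sketch (not the library method) of what the constant-diagonal branch above computes, now in the registered symeig dtype. The helper name spectral_shift_solve and its arguments are hypothetical; evals and q_matrix stand in for a previously computed eigendecomposition K = Q diag(e) Q^T, and diag_const is the constant diagonal shift c.

    import torch
    from gpytorch import settings

    def spectral_shift_solve(evals, q_matrix, diag_const, rhs):
        # (K + c I)^{-1} rhs = Q diag(1 / (e + c)) Q^T rhs, computed in the symeig dtype
        symeig_dtype = settings._linalg_dtype_symeig.value()
        rhs_dtype = rhs.dtype
        inv_evals = (evals.to(symeig_dtype) + diag_const).reciprocal()
        q = q_matrix.to(symeig_dtype)
        res = q.transpose(-2, -1) @ rhs.to(symeig_dtype)
        # scale by the shifted inverse eigenvalues and rotate back, then cast to the caller's dtype
        return (q @ (inv_evals.unsqueeze(-1) * res)).type(rhs_dtype)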

gpytorch/lazy/lazy_tensor.py

Lines changed: 5 additions & 4 deletions
@@ -2170,10 +2170,11 @@ def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional["LazyTen
         if settings.verbose_linalg.on():
             settings.verbose_linalg.logger.debug(f"Running symeig on a matrix of size {self.shape}.")

-        dtype = self.dtype  # perform decomposition in double precision for numerical stability
-        # TODO: Use fp64 registry once #1213 is addressed
-        evals, evecs = torch.linalg.eigh(self.evaluate().to(dtype=torch.double))
-        # chop any negative eigenvalues. TODO: warn if evals are significantly negative
+        # potentially perform decomposition in double precision for numerical stability
+        dtype = self.dtype
+        evals, evecs = torch.linalg.eigh(self.evaluate().to(dtype=settings._linalg_dtype_symeig.value()))
+        # chop any negative eigenvalues.
+        # TODO: warn if evals are significantly negative
         evals = evals.clamp_min(0.0).to(dtype=dtype)
         if eigenvectors:
             evecs = NonLazyTensor(evecs.to(dtype=dtype))
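A hedged standalone sketch of the pattern _symeig now follows (the helper name is hypothetical, and it operates on a plain tensor rather than a LazyTensor): run torch.linalg.eigh in the registered dtype, clamp any small negative eigenvalues that arise from round-off, and cast back to the input dtype.

    import torch
    from gpytorch import settings

    def symeig_in_registry_dtype(mat: torch.Tensor):
        out_dtype = mat.dtype
        # decompose in the registered (by default double) dtype for stability
        evals, evecs = torch.linalg.eigh(mat.to(dtype=settings._linalg_dtype_symeig.value()))
        # chop any negative eigenvalues and cast results back to the caller's dtype
        evals = evals.clamp_min(0.0).to(dtype=out_dtype)
        return evals, evecs.to(dtype=out_dtype)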

gpytorch/settings.py

Lines changed: 35 additions & 1 deletion
@@ -671,7 +671,7 @@ class skip_logdet_forward(_feature_flag):
    pass will skip certain computations (i.e. the logdet computation), and will therefore
    be improper estimates.

-    If you're using SGD (or a varient) to optimize parameters, you probably
+    If you're using SGD (or a variant) to optimize parameters, you probably
    don't need an accurate MLL estimate; you only need accurate gradients. So
    this setting may give your model a performance boost.
@@ -681,6 +681,40 @@ class skip_logdet_forward(_feature_flag):
    _default = False


+class _linalg_dtype_symeig(_value_context):
+    _global_value = torch.double
+
+
+class _linalg_dtype_cholesky(_value_context):
+    _global_value = torch.double
+
+
+class linalg_dtypes:
+    """
+    Whether to perform less stable linalg calls in double precision or in a lower precision.
+    Currently, the default is to apply all symeig calls and cholesky calls within variational
+    methods in double precision.
+
+    (Default: torch.double)
+    """
+
+    def __init__(self, default=torch.double, symeig=None, cholesky=None):
+        symeig = default if symeig is None else symeig
+        cholesky = default if cholesky is None else cholesky
+
+        self.symeig = _linalg_dtype_symeig(symeig)
+        self.cholesky = _linalg_dtype_cholesky(cholesky)
+
+    def __enter__(self):
+        self.symeig.__enter__()
+        self.cholesky.__enter__()
+
+    def __exit__(self, *args):
+        self.symeig.__exit__()
+        self.cholesky.__exit__()
+        return False
+
+
 class terminate_cg_by_size(_feature_flag):
     """
     If set to true, cg will terminate after n iterations for an n x n matrix.
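Assuming _value_context.value() reports the currently active global value (as the calls elsewhere in this commit suggest), a small usage sketch of the new context manager might look like the following; the dtypes noted in the comments are expectations, not verified output.

    import torch
    from gpytorch import settings
    from gpytorch.settings import linalg_dtypes

    # Default: both registries report torch.double (torch.float64).
    print(settings._linalg_dtype_symeig.value())
    print(settings._linalg_dtype_cholesky.value())

    # Run the registered linalg calls in single precision within the block.
    with linalg_dtypes(torch.float):
        print(settings._linalg_dtype_symeig.value())    # expected: torch.float32
        print(settings._linalg_dtype_cholesky.value())  # expected: torch.float32

    # Mix precisions per operation: keep symeig in double, relax cholesky to float.
    with linalg_dtypes(symeig=torch.double, cholesky=torch.float):
        print(settings._linalg_dtype_cholesky.value())  # expected: torch.float32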

gpytorch/test/lazy_tensor_test_case.py

Lines changed: 49 additions & 43 deletions
@@ -9,6 +9,7 @@
 import torch

 import gpytorch
+from gpytorch.settings import linalg_dtypes
 from gpytorch.utils.cholesky import CHOLESKY_METHOD

 from .base_test_case import BaseTestCase
@@ -295,7 +296,7 @@ class LazyTensorTestCase(RectangularLazyTensorTestCase):
         "root_inv_decomposition": {"rtol": 0.05, "atol": 0.02},
         "sample": {"rtol": 0.3, "atol": 0.3},
         "sqrt_inv_matmul": {"rtol": 1e-4, "atol": 1e-3},
-        "symeig": {"rtol": 1e-4, "atol": 1e-3},
+        "symeig": {"double": {"rtol": 1e-4, "atol": 1e-3}, "float": {"rtol": 1e-3, "atol": 1e-2}},
         "svd": {"rtol": 1e-4, "atol": 1e-3},
     }
@@ -754,51 +755,56 @@ def test_sqrt_inv_matmul_no_lhs(self):
            self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["sqrt_inv_matmul"])

     def test_symeig(self):
-        lazy_tensor = self.create_lazy_tensor().detach().requires_grad_(True)
-        lazy_tensor_copy = lazy_tensor.clone().detach().requires_grad_(True)
-        evaluated = self.evaluate_lazy_tensor(lazy_tensor_copy)
-
-        # Perform forward pass
-        evals_unsorted, evecs_unsorted = lazy_tensor.symeig(eigenvectors=True)
-        evecs_unsorted = evecs_unsorted.evaluate()
-
-        # since LazyTensor.symeig does not sort evals, we do this here for the check
-        evals, idxr = torch.sort(evals_unsorted, dim=-1, descending=False)
-        evecs = torch.gather(evecs_unsorted, dim=-1, index=idxr.unsqueeze(-2).expand(evecs_unsorted.shape))
-
-        evals_actual, evecs_actual = torch.linalg.eigh(evaluated.double())
-        evals_actual = evals_actual.to(dtype=evaluated.dtype)
-        evecs_actual = evecs_actual.to(dtype=evaluated.dtype)
-
-        # Check forward pass
-        self.assertAllClose(evals, evals_actual, **self.tolerances["symeig"])
-        lt_from_eigendecomp = evecs @ torch.diag_embed(evals) @ evecs.transpose(-1, -2)
-        self.assertAllClose(lt_from_eigendecomp, evaluated, **self.tolerances["symeig"])
-
-        # if there are repeated evals, we'll skip checking the eigenvectors for those
-        any_evals_repeated = False
-        evecs_abs, evecs_actual_abs = evecs.abs(), evecs_actual.abs()
-        for idx in itertools.product(*[range(b) for b in evals_actual.shape[:-1]]):
-            eval_i = evals_actual[idx]
-            if torch.unique(eval_i.detach()).shape[-1] == eval_i.shape[-1]:  # detach to avoid pytorch/pytorch#41389
-                self.assertAllClose(evecs_abs[idx], evecs_actual_abs[idx], **self.tolerances["symeig"])
-            else:
-                any_evals_repeated = True
+        dtypes = {"double": torch.double, "float": torch.float}
+        for name, dtype in dtypes.items():
+            tolerances = self.tolerances["symeig"][name]
+
+            lazy_tensor = self.create_lazy_tensor().detach().requires_grad_(True)
+            lazy_tensor_copy = lazy_tensor.clone().detach().requires_grad_(True)
+            evaluated = self.evaluate_lazy_tensor(lazy_tensor_copy)
+
+            # Perform forward pass
+            with linalg_dtypes(dtype):
+                evals_unsorted, evecs_unsorted = lazy_tensor.symeig(eigenvectors=True)
+                evecs_unsorted = evecs_unsorted.evaluate()
+
+            # since LazyTensor.symeig does not sort evals, we do this here for the check
+            evals, idxr = torch.sort(evals_unsorted, dim=-1, descending=False)
+            evecs = torch.gather(evecs_unsorted, dim=-1, index=idxr.unsqueeze(-2).expand(evecs_unsorted.shape))
+
+            evals_actual, evecs_actual = torch.linalg.eigh(evaluated.type(dtype))
+            evals_actual = evals_actual.to(dtype=evaluated.dtype)
+            evecs_actual = evecs_actual.to(dtype=evaluated.dtype)
+
+            # Check forward pass
+            self.assertAllClose(evals, evals_actual, **tolerances)
+            lt_from_eigendecomp = evecs @ torch.diag_embed(evals) @ evecs.transpose(-1, -2)
+            self.assertAllClose(lt_from_eigendecomp, evaluated, **tolerances)
+
+            # if there are repeated evals, we'll skip checking the eigenvectors for those
+            any_evals_repeated = False
+            evecs_abs, evecs_actual_abs = evecs.abs(), evecs_actual.abs()
+            for idx in itertools.product(*[range(b) for b in evals_actual.shape[:-1]]):
+                eval_i = evals_actual[idx]
+                if torch.unique(eval_i.detach()).shape[-1] == eval_i.shape[-1]:  # detach to avoid pytorch/pytorch#41389
+                    self.assertAllClose(evecs_abs[idx], evecs_actual_abs[idx], **tolerances)
+                else:
+                    any_evals_repeated = True

-        # Perform backward pass
-        symeig_grad = torch.randn_like(evals)
-        ((evals * symeig_grad).sum()).backward()
-        ((evals_actual * symeig_grad).sum()).backward()
+            # Perform backward pass
+            symeig_grad = torch.randn_like(evals)
+            ((evals * symeig_grad).sum()).backward()
+            ((evals_actual * symeig_grad).sum()).backward()

-        # Check grads if there were no repeated evals
-        if not any_evals_repeated:
-            for arg, arg_copy in zip(lazy_tensor.representation(), lazy_tensor_copy.representation()):
-                if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None:
-                    self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["symeig"])
+            # Check grads if there were no repeated evals
+            if not any_evals_repeated:
+                for arg, arg_copy in zip(lazy_tensor.representation(), lazy_tensor_copy.representation()):
+                    if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None:
+                        self.assertAllClose(arg.grad, arg_copy.grad, **tolerances)

-        # Test with eigenvectors=False
-        _, evecs = lazy_tensor.symeig(eigenvectors=False)
-        self.assertIsNone(evecs)
+            # Test with eigenvectors=False
+            _, evecs = lazy_tensor.symeig(eigenvectors=False)
+            self.assertIsNone(evecs)

     def test_svd(self):
         lazy_tensor = self.create_lazy_tensor().detach().requires_grad_(True)

gpytorch/variational/variational_strategy.py

Lines changed: 6 additions & 5 deletions
@@ -6,7 +6,7 @@

 from ..distributions import MultivariateNormal
 from ..lazy import DiagLazyTensor, MatmulLazyTensor, RootLazyTensor, SumLazyTensor, TriangularLazyTensor, delazify
-from ..settings import trace_mode
+from ..settings import _linalg_dtype_cholesky, trace_mode
 from ..utils.cholesky import psd_safe_cholesky
 from ..utils.errors import CachingError
 from ..utils.memoize import cached, clear_cache_hook, pop_from_cache_ignore_args
@@ -69,7 +69,7 @@ def __init__(self, model, inducing_points, variational_distribution, learn_induc

     @cached(name="cholesky_factor", ignore_args=True)
     def _cholesky_factor(self, induc_induc_covar):
-        L = psd_safe_cholesky(delazify(induc_induc_covar).double())
+        L = psd_safe_cholesky(delazify(induc_induc_covar).type(_linalg_dtype_cholesky.value()))
         return TriangularLazyTensor(L)

     @property
@@ -109,7 +109,7 @@ def forward(self, x, inducing_points, inducing_values, variational_inducing_cova
         except CachingError:
             pass
         L = self._cholesky_factor(induc_induc_covar)
-        interp_term = L.inv_matmul(induc_data_covar.double()).to(full_inputs.dtype)
+        interp_term = L.inv_matmul(induc_data_covar.type(_linalg_dtype_cholesky.value())).to(full_inputs.dtype)

         # Compute the mean of q(f)
         # k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z) + \mu_X
@@ -149,9 +149,10 @@ def __call__(self, x, prior=False, **kwargs):

         # Change the variational parameters to be whitened
         variational_dist = self.variational_distribution
-        mean_diff = (variational_dist.loc - prior_mean).unsqueeze(-1).double()
+        mean_diff = (variational_dist.loc - prior_mean).unsqueeze(-1).type(_linalg_dtype_cholesky.value())
         whitened_mean = L.inv_matmul(mean_diff).squeeze(-1).to(variational_dist.loc.dtype)
-        covar_root = variational_dist.lazy_covariance_matrix.root_decomposition().root.evaluate().double()
+        covar_root = variational_dist.lazy_covariance_matrix.root_decomposition().root.evaluate()
+        covar_root = covar_root.type(_linalg_dtype_cholesky.value())
         whitened_covar = RootLazyTensor(L.inv_matmul(covar_root).to(variational_dist.loc.dtype))
         whitened_variational_distribution = variational_dist.__class__(whitened_mean, whitened_covar)
         self._variational_distribution.initialize_variational_distribution(whitened_variational_distribution)
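As a rough illustration of the casting pattern used in _cholesky_factor and the whitening code, here is a hypothetical standalone helper that solves against a plain lower-triangular Cholesky factor in the registered Cholesky dtype. It assumes a recent PyTorch with torch.linalg.solve_triangular; the library itself goes through TriangularLazyTensor.inv_matmul rather than this function.

    import torch
    from gpytorch.settings import _linalg_dtype_cholesky

    def solve_against_cholesky_factor(L: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
        # cast both operands to the registered Cholesky dtype before the triangular solve
        solve_dtype = _linalg_dtype_cholesky.value()
        out = torch.linalg.solve_triangular(L.to(solve_dtype), rhs.to(solve_dtype), upper=False)
        # cast the solve back to the caller's dtype
        return out.to(rhs.dtype)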
