 import torch

 import gpytorch
+from gpytorch.settings import linalg_dtypes
 from gpytorch.utils.cholesky import CHOLESKY_METHOD

 from .base_test_case import BaseTestCase
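For orientation (this note and sketch are not part of the diff): the new `linalg_dtypes` import is used in `test_symeig` below as a context manager whose first positional argument is the dtype GPyTorch should run the dense eigendecomposition in. A minimal sketch of that usage, assuming a small SPD matrix wrapped with `gpytorch.lazy.lazify`:

import torch
from gpytorch.lazy import lazify
from gpytorch.settings import linalg_dtypes

# Small SPD example matrix wrapped as a LazyTensor.
a = torch.randn(4, 4)
lazy_spd = lazify(a @ a.t() + 4 * torch.eye(4))

# Ask GPyTorch to carry out the decomposition in double precision.
with linalg_dtypes(torch.double):
    evals, evecs = lazy_spd.symeig(eigenvectors=True)

# Note: LazyTensor.symeig does not guarantee sorted eigenvalues, which is why
# the test below sorts them before comparing against torch.linalg.eigh.
print(evals)
print(evecs.evaluate().shape)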
@@ -295,7 +296,7 @@ class LazyTensorTestCase(RectangularLazyTensorTestCase):
         "root_inv_decomposition": {"rtol": 0.05, "atol": 0.02},
         "sample": {"rtol": 0.3, "atol": 0.3},
         "sqrt_inv_matmul": {"rtol": 1e-4, "atol": 1e-3},
-        "symeig": {"rtol": 1e-4, "atol": 1e-3},
+        "symeig": {"double": {"rtol": 1e-4, "atol": 1e-3}, "float": {"rtol": 1e-3, "atol": 1e-2}},
         "svd": {"rtol": 1e-4, "atol": 1e-3},
     }

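Because the "symeig" entry is now a nested per-dtype dict, downstream LazyTensor test classes that override `tolerances` need to mirror that shape. A hypothetical override (class name, tolerance values, and the NonLazyTensor fixture are illustrative only, not taken from this diff):

import unittest

import torch

import gpytorch
from gpytorch.test.lazy_tensor_test_case import LazyTensorTestCase


class ExampleNonLazyTensorTest(LazyTensorTestCase, unittest.TestCase):
    seed = 0
    # Keep every inherited tolerance, but loosen only the float branch of symeig.
    tolerances = {
        **LazyTensorTestCase.tolerances,
        "symeig": {
            "double": {"rtol": 1e-4, "atol": 1e-3},
            "float": {"rtol": 5e-3, "atol": 5e-2},
        },
    }

    def create_lazy_tensor(self):
        mat = torch.randn(6, 6)
        mat = mat @ mat.t() + 6 * torch.eye(6)
        return gpytorch.lazy.NonLazyTensor(mat.requires_grad_(True))

    def evaluate_lazy_tensor(self, lazy_tensor):
        return lazy_tensor.tensor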
@@ -754,51 +755,56 @@ def test_sqrt_inv_matmul_no_lhs(self):
                 self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["sqrt_inv_matmul"])

     def test_symeig(self):
-        lazy_tensor = self.create_lazy_tensor().detach().requires_grad_(True)
-        lazy_tensor_copy = lazy_tensor.clone().detach().requires_grad_(True)
-        evaluated = self.evaluate_lazy_tensor(lazy_tensor_copy)
-
-        # Perform forward pass
-        evals_unsorted, evecs_unsorted = lazy_tensor.symeig(eigenvectors=True)
-        evecs_unsorted = evecs_unsorted.evaluate()
-
-        # since LazyTensor.symeig does not sort evals, we do this here for the check
-        evals, idxr = torch.sort(evals_unsorted, dim=-1, descending=False)
-        evecs = torch.gather(evecs_unsorted, dim=-1, index=idxr.unsqueeze(-2).expand(evecs_unsorted.shape))
-
-        evals_actual, evecs_actual = torch.linalg.eigh(evaluated.double())
-        evals_actual = evals_actual.to(dtype=evaluated.dtype)
-        evecs_actual = evecs_actual.to(dtype=evaluated.dtype)
-
-        # Check forward pass
-        self.assertAllClose(evals, evals_actual, **self.tolerances["symeig"])
-        lt_from_eigendecomp = evecs @ torch.diag_embed(evals) @ evecs.transpose(-1, -2)
-        self.assertAllClose(lt_from_eigendecomp, evaluated, **self.tolerances["symeig"])
-
-        # if there are repeated evals, we'll skip checking the eigenvectors for those
-        any_evals_repeated = False
-        evecs_abs, evecs_actual_abs = evecs.abs(), evecs_actual.abs()
-        for idx in itertools.product(*[range(b) for b in evals_actual.shape[:-1]]):
-            eval_i = evals_actual[idx]
-            if torch.unique(eval_i.detach()).shape[-1] == eval_i.shape[-1]:  # detach to avoid pytorch/pytorch#41389
-                self.assertAllClose(evecs_abs[idx], evecs_actual_abs[idx], **self.tolerances["symeig"])
-            else:
-                any_evals_repeated = True
+        dtypes = {"double": torch.double, "float": torch.float}
+        for name, dtype in dtypes.items():
+            tolerances = self.tolerances["symeig"][name]
+
+            lazy_tensor = self.create_lazy_tensor().detach().requires_grad_(True)
+            lazy_tensor_copy = lazy_tensor.clone().detach().requires_grad_(True)
+            evaluated = self.evaluate_lazy_tensor(lazy_tensor_copy)
+
+            # Perform forward pass
+            with linalg_dtypes(dtype):
+                evals_unsorted, evecs_unsorted = lazy_tensor.symeig(eigenvectors=True)
+                evecs_unsorted = evecs_unsorted.evaluate()
+
+            # since LazyTensor.symeig does not sort evals, we do this here for the check
+            evals, idxr = torch.sort(evals_unsorted, dim=-1, descending=False)
+            evecs = torch.gather(evecs_unsorted, dim=-1, index=idxr.unsqueeze(-2).expand(evecs_unsorted.shape))
+
+            evals_actual, evecs_actual = torch.linalg.eigh(evaluated.type(dtype))
+            evals_actual = evals_actual.to(dtype=evaluated.dtype)
+            evecs_actual = evecs_actual.to(dtype=evaluated.dtype)
+
+            # Check forward pass
+            self.assertAllClose(evals, evals_actual, **tolerances)
+            lt_from_eigendecomp = evecs @ torch.diag_embed(evals) @ evecs.transpose(-1, -2)
+            self.assertAllClose(lt_from_eigendecomp, evaluated, **tolerances)
+
+            # if there are repeated evals, we'll skip checking the eigenvectors for those
+            any_evals_repeated = False
+            evecs_abs, evecs_actual_abs = evecs.abs(), evecs_actual.abs()
+            for idx in itertools.product(*[range(b) for b in evals_actual.shape[:-1]]):
+                eval_i = evals_actual[idx]
+                if torch.unique(eval_i.detach()).shape[-1] == eval_i.shape[-1]:  # detach to avoid pytorch/pytorch#41389
+                    self.assertAllClose(evecs_abs[idx], evecs_actual_abs[idx], **tolerances)
+                else:
+                    any_evals_repeated = True

-        # Perform backward pass
-        symeig_grad = torch.randn_like(evals)
-        ((evals * symeig_grad).sum()).backward()
-        ((evals_actual * symeig_grad).sum()).backward()
+            # Perform backward pass
+            symeig_grad = torch.randn_like(evals)
+            ((evals * symeig_grad).sum()).backward()
+            ((evals_actual * symeig_grad).sum()).backward()

-        # Check grads if there were no repeated evals
-        if not any_evals_repeated:
-            for arg, arg_copy in zip(lazy_tensor.representation(), lazy_tensor_copy.representation()):
-                if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None:
-                    self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["symeig"])
+            # Check grads if there were no repeated evals
+            if not any_evals_repeated:
+                for arg, arg_copy in zip(lazy_tensor.representation(), lazy_tensor_copy.representation()):
+                    if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None:
+                        self.assertAllClose(arg.grad, arg_copy.grad, **tolerances)

-        # Test with eigenvectors=False
-        _, evecs = lazy_tensor.symeig(eigenvectors=False)
-        self.assertIsNone(evecs)
+            # Test with eigenvectors=False
+            _, evecs = lazy_tensor.symeig(eigenvectors=False)
+            self.assertIsNone(evecs)

     def test_svd(self):
         lazy_tensor = self.create_lazy_tensor().detach().requires_grad_(True)
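The per-dtype loop above, and the looser float tolerances, reflect ordinary floating-point behaviour rather than anything GPyTorch-specific: a float32 eigendecomposition reconstructs the original matrix noticeably less accurately than a float64 one. A standalone, pure-PyTorch illustration (matrix size and diagonal shift are arbitrary):

import torch

# Compare eigh reconstruction error in double vs. float for the same SPD matrix.
a = torch.randn(50, 50, dtype=torch.double)
mat = a @ a.transpose(-1, -2) + 50 * torch.eye(50, dtype=torch.double)

for dtype in (torch.double, torch.float):
    m = mat.to(dtype)
    evals, evecs = torch.linalg.eigh(m)
    recon = evecs @ torch.diag_embed(evals) @ evecs.transpose(-1, -2)
    print(dtype, (recon - m).abs().max().item())
# Expect errors near machine epsilon for each dtype, roughly 1e-13 vs. 1e-5 in
# this setting (exact values vary from run to run).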