
Commit 3e85dc8

Updates from gpytorch 1.6.0-1.8.1
1 parent 2c3dd59 commit 3e85dc8

File tree
8 files changed: +66, -16 lines changed

linear_operator/functions/_pivoted_cholesky.py
Lines changed: 5 additions & 1 deletion

@@ -103,7 +103,11 @@ def backward(ctx, grad_output, _):

         with torch.enable_grad():
             # Create a new set of matrix args that we can backpropagate through
-            matrix_args = [matrix_arg.detach().requires_grad_(True) for matrix_arg in _matrix_args]
+            matrix_args = []
+            for matrix_arg in _matrix_args:
+                if matrix_arg.dtype in (torch.float, torch.double, torch.half):
+                    matrix_arg = matrix_arg.detach().requires_grad_(True)
+                matrix_args.append(matrix_arg)

             # Create new linear operator using new matrix args
             matrix = ctx.representation_tree(*matrix_args)

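The dtype check exists because PyTorch only allows `requires_grad_(True)` on floating-point (and complex) tensors; an integer argument such as an index tensor would raise a `RuntimeError`. A minimal standalone sketch of the same filtering pattern (the helper name is hypothetical, not part of the library):

```python
import torch

def detach_and_enable_grad(args):
    # Hypothetical helper mirroring the patch: only floating-point tensors can
    # have gradients enabled; integer tensors (e.g. index tensors) pass through.
    out = []
    for arg in args:
        if arg.dtype in (torch.float, torch.double, torch.half):
            arg = arg.detach().requires_grad_(True)
        out.append(arg)
    return out

dense = torch.randn(3, 3)
index = torch.tensor([0, 2, 1])  # calling .requires_grad_(True) on this would raise
dense_g, index_g = detach_and_enable_grad([dense, index])
assert dense_g.requires_grad and not index_g.requires_grad
```
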
linear_operator/operators/_linear_operator.py
Lines changed: 7 additions & 2 deletions

@@ -723,7 +723,8 @@ def _set_requires_grad(self, val: bool) -> None:
             arg.requires_grad_(val)
         for arg in self._kwargs.values():
             if hasattr(arg, "requires_grad"):
-                arg.requires_grad_(val)
+                if arg.dtype in (torch.float, torch.double, torch.half):
+                    arg.requires_grad_(val)

     def _solve(self, rhs: torch.Tensor, preconditioner: Callable, num_tridiag: int = 0) -> torch.Tensor:
         r"""
@@ -1759,6 +1760,10 @@ def mul(self, other: Union[float, torch.Tensor, "LinearOperator"]) -> LinearOper

         return self._mul_matrix(to_linear_operator(other))

+    @property
+    def ndim(self) -> int:
+        return self.ndimension()
+
     def ndimension(self) -> int:
         """
         Returns the number of dimensions.
@@ -2687,7 +2692,7 @@ def __torch_function__(
         if kwargs is None:
             kwargs = {}

-        if not isinstance(args[0], LinearOperator):
+        if not isinstance(args[0], cls):
             if func not in _HANDLED_SECOND_ARG_FUNCTIONS or not all(
                 issubclass(t, (torch.Tensor, LinearOperator)) for t in types
             ):

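The new `ndim` property mirrors `torch.Tensor.ndim` by delegating to `ndimension()`, so tensor-style code keeps working when handed a LinearOperator. A usage sketch, assuming `DiagLinearOperator` is importable from `linear_operator.operators` as elsewhere in this commit:

```python
import torch
from linear_operator.operators import DiagLinearOperator

# After this commit, `ndim` is available alongside `ndimension()`.
op = DiagLinearOperator(torch.rand(2, 5))    # a batch of two 5x5 diagonal operators
assert op.ndim == op.ndimension() == 3
assert op.ndim == torch.zeros(2, 5, 5).ndim  # same convention as a dense tensor
```
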
linear_operator/operators/block_diag_linear_operator.py
Lines changed: 13 additions & 1 deletion

@@ -12,7 +12,7 @@


 # metaclass of BlockDiagLinearOperator, overwrites behavior of constructor call
-# _MetaBlockDiagLinearOperator(base_linear_op, block_dim=-3) to return a DiagLazyTensor
+# _MetaBlockDiagLinearOperator(base_linear_op, block_dim=-3) to return a DiagLinearOperator
 # if base_linear_op is a DiagLinearOperator itself
 class _MetaBlockDiagLinearOperator(ABCMeta):
     def __call__(cls, base_linear_op: LinearOperator, block_dim=-3):
@@ -136,6 +136,18 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
             logdet_res = logdet_res.view(*logdet_res.shape).sum(-1)
         return inv_quad_res, logdet_res

+    def matmul(self, other):
+        from .diag_linear_operator import DiagLinearOperator
+
+        # this is trivial if we multiply two BlockDiagLinearOperator
+        if isinstance(other, BlockDiagLinearOperator):
+            return BlockDiagLinearOperator(self.base_linear_op @ other.base_linear_op)
+        # special case if we have a DiagLinearOperator
+        if isinstance(other, DiagLinearOperator):
+            diag_reshape = other._diag.view(*self.base_linear_op.shape[:-2], 1, -1)
+            return BlockDiagLinearOperator(self.base_linear_op * diag_reshape)
+        return super().matmul(other)
+
     @cached(name="svd")
     def _svd(self) -> Tuple["LinearOperator", Tensor, "LinearOperator"]:
         U, S, V = self.base_linear_op.svd()

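The new fast path relies on block-diagonal structure being closed under these products: blockdiag(A_1, ..., A_k) @ blockdiag(B_1, ..., B_k) = blockdiag(A_1 B_1, ..., A_k B_k), and a diagonal right operand merely rescales each block's columns, which is what the `view(..., 1, -1)` reshape of the diagonal implements. A plain-PyTorch sanity check of both identities, independent of the library API:

```python
import torch

A = torch.randn(2, 3, 3)  # two 3x3 blocks
B = torch.randn(2, 3, 3)
d = torch.rand(2 * 3)

# blockdiag(A1, A2) @ blockdiag(B1, B2) == blockdiag(A1 @ B1, A2 @ B2)
assert torch.allclose(torch.block_diag(*A) @ torch.block_diag(*B),
                      torch.block_diag(*(A @ B)))

# A diagonal on the right only rescales each block's columns
assert torch.allclose(torch.block_diag(*A) @ torch.diag(d),
                      torch.block_diag(*(A * d.view(2, 1, 3))))
```
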
linear_operator/operators/cat_linear_operator.py
Lines changed: 4 additions & 1 deletion

@@ -209,7 +209,10 @@ def _get_indices(self, row_index, col_index, *batch_indices):
         if len(res_list) == 1:
             return res_list[0].view(target_shape).to(self.device)
         else:
-            return torch.cat(res_list).view(target_shape).to(self.device)
+            # Explicitly move tensors to one device as torch.cat no longer moves tensors:
+            # https://github.com/pytorch/pytorch/issues/35045
+            res_list = [linear_op.to(self.device) for linear_op in res_list]
+            return torch.cat(res_list).view(target_shape)

     def _getitem(
         self,

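Recent PyTorch versions require every input to `torch.cat` to live on the same device instead of silently moving them (see the linked issue), so the pieces are moved explicitly before concatenation. A minimal sketch of the same pattern; the helper name is hypothetical and the second branch assumes a GPU is available:

```python
import torch

def cat_on(device, tensors):
    # torch.cat requires all inputs on one device; move them explicitly first.
    return torch.cat([t.to(device) for t in tensors])

if torch.cuda.is_available():  # assumption: a GPU is present for the cross-device case
    parts = [torch.randn(2, 3), torch.randn(2, 3, device="cuda")]
    out = cat_on("cuda", parts)
    assert out.device.type == "cuda" and out.shape == (4, 3)
```
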
linear_operator/operators/diag_linear_operator.py
Lines changed: 6 additions & 3 deletions

@@ -8,6 +8,7 @@
 from .. import settings
 from ..utils.memoize import cached
 from ._linear_operator import LinearOperator
+from .block_diag_linear_operator import BlockDiagLinearOperator
 from .dense_linear_operator import DenseLinearOperator
 from .triangular_linear_operator import TriangularLinearOperator

@@ -163,15 +164,17 @@ def log(self) -> "DiagLinearOperator":
         return self.__class__(self._diag.log())

     def matmul(self, other: Union[Tensor, LinearOperator]) -> Union[Tensor, LinearOperator]:
-        from .triangular_linear_operator import TriangularLinearOperator
-
         # this is trivial if we multiply two DiagLinearOperators
         if isinstance(other, DiagLinearOperator):
             return DiagLinearOperator(self._diag * other._diag)
         # special case if we have a DenseLinearOperator
         if isinstance(other, DenseLinearOperator):
             return DenseLinearOperator(self._diag.unsqueeze(-1) * other.tensor)
-        # and if we have a triangular one
+        # special case if we have a BlockDiagLinearOperator
+        if isinstance(other, BlockDiagLinearOperator):
+            diag_reshape = self._diag.view(*other.base_linear_op.shape[:-1], 1)
+            return BlockDiagLinearOperator(diag_reshape * other.base_linear_op)
+        # special case if we have a TriangularLinearOperator
         if isinstance(other, TriangularLinearOperator):
             return TriangularLinearOperator(self._diag.unsqueeze(-1) * other._tensor, upper=other.upper)
         return super().matmul(other)

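This is the companion special case to the one added in block_diag_linear_operator.py: a diagonal on the left rescales each block's rows, so the product stays block diagonal. A plain-PyTorch check of the identity behind the `view(..., 1)` reshape:

```python
import torch

B = torch.randn(2, 3, 3)  # two 3x3 blocks
d = torch.rand(2 * 3)

# diag(d) @ blockdiag(B1, B2) rescales each block's rows
assert torch.allclose(torch.diag(d) @ torch.block_diag(*B),
                      torch.block_diag(*(d.view(2, 3, 1) * B)))
```
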
linear_operator/operators/interpolated_linear_operator.py
Lines changed: 12 additions & 0 deletions

@@ -12,6 +12,7 @@
 from ..utils.interpolation import left_interp, left_t_interp
 from ._linear_operator import LinearOperator
 from .dense_linear_operator import DenseLinearOperator, to_linear_operator
+from .diag_linear_operator import DiagLinearOperator
 from .root_linear_operator import RootLinearOperator


@@ -409,6 +410,17 @@ def matmul(self, tensor):
         # what we get from the function factory.
         # The _matmul_closure is optimized for repeated calls, such as for _solve

+        if isinstance(tensor, DiagLinearOperator):
+            # if we know the rhs is diagonal this is easy
+            new_right_interp_values = self.right_interp_values * tensor._diag.unsqueeze(-1)
+            return InterpolatedLinearOperator(
+                base_linear_op=self.base_linear_op,
+                left_interp_indices=self.left_interp_indices,
+                left_interp_values=self.left_interp_values,
+                right_interp_indices=self.right_interp_indices,
+                right_interp_values=new_right_interp_values,
+            )
+
         if tensor.ndimension() == 1:
             is_vector = True
             tensor = tensor.unsqueeze(-1)

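An InterpolatedLinearOperator represents a product of the form W_left @ K @ W_right.T, so a diagonal right-hand side can be folded into the right interpolation values: (W_left K W_right^T) diag(d) = W_left K (W_right * d[:, None])^T. A plain-PyTorch check of that identity with small dense stand-ins for the (normally sparse) interpolation matrices:

```python
import torch

W_left, K, W_right = torch.randn(4, 3), torch.randn(3, 3), torch.randn(5, 3)
d = torch.rand(5)

assert torch.allclose(W_left @ K @ W_right.T @ torch.diag(d),
                      W_left @ K @ (W_right * d.unsqueeze(-1)).T)
```
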
linear_operator/test/linear_operator_test_case.py
Lines changed: 14 additions & 5 deletions

@@ -8,6 +8,7 @@
 import torch

 import linear_operator
+from linear_operator.operators import DiagLinearOperator, to_dense
 from linear_operator.settings import linalg_dtypes
 from linear_operator.utils.errors import CachingError
 from linear_operator.utils.memoize import get_from_cache
@@ -34,14 +35,16 @@ def _test_matmul(self, rhs):
         linear_op = self.create_linear_op().detach().requires_grad_(True)
         linear_op_copy = torch.clone(linear_op).detach().requires_grad_(True)
         evaluated = self.evaluate_linear_op(linear_op_copy)
+        rhs_evaluated = to_dense(rhs)

         # Test operator
         res = linear_op @ rhs
-        actual = evaluated.matmul(rhs)
-        self.assertAllClose(res, actual)
+        actual = evaluated.matmul(rhs_evaluated)
+        res_evaluated = to_dense(res)
+        self.assertAllClose(res_evaluated, actual)

-        grad = torch.randn_like(res)
-        res.backward(gradient=grad)
+        grad = torch.randn_like(res_evaluated)
+        res_evaluated.backward(gradient=grad)
         actual.backward(gradient=grad)
         for arg, arg_copy in zip(linear_op.representation(), linear_op_copy.representation()):
             if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None:
@@ -50,7 +53,7 @@ def _test_matmul(self, rhs):
         # Test __torch_function__
         res = torch.matmul(linear_op, rhs)
         actual = evaluated.matmul(rhs)
-        self.assertAllClose(res, actual)
+        self.assertAllClose(to_dense(res), actual)

     def _test_rmatmul(self, lhs):
         linear_op = self.create_linear_op().detach().requires_grad_(True)
@@ -305,6 +308,12 @@ def test_rmatmul_matrix(self):
         lhs = torch.randn(*linear_op.batch_shape, 4, linear_op.size(-2))
         return self._test_rmatmul(lhs)

+    def test_matmul_diag_matrix(self):
+        linear_op = self.create_linear_op()
+        diag = torch.rand(*linear_op.batch_shape, linear_op.size(-1))
+        rhs = DiagLinearOperator(diag)
+        return self._test_matmul(rhs)
+
     def test_matmul_matrix_broadcast(self):
         linear_op = self.create_linear_op()

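The base suite now also exercises products with a DiagLinearOperator right-hand side; since such a product may itself come back as a LinearOperator rather than a dense tensor, both sides are compared through `to_dense`. A sketch of the same comparison outside the test harness, using `DenseLinearOperator` as a stand-in for whatever `create_linear_op()` returns (assuming it is exported from `linear_operator.operators` alongside the names imported above):

```python
import torch
from linear_operator.operators import DenseLinearOperator, DiagLinearOperator, to_dense

op = DenseLinearOperator(torch.randn(4, 5))
rhs = DiagLinearOperator(torch.rand(5))
res = op @ rhs  # may itself be a LinearOperator, not a Tensor
assert torch.allclose(to_dense(res), to_dense(op) @ to_dense(rhs))
```
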
test/operators/test_identity_linear_operator.py
Lines changed: 5 additions & 3 deletions

@@ -4,7 +4,7 @@

 import torch

-from linear_operator.operators import IdentityLinearOperator
+from linear_operator.operators import IdentityLinearOperator, to_dense
 from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase


@@ -13,10 +13,12 @@ def _test_matmul(self, rhs):
         linear_op = self.create_linear_op().detach().requires_grad_(True)
         linear_op_copy = linear_op.clone().detach().requires_grad_(True)
         evaluated = self.evaluate_linear_op(linear_op_copy)
+        rhs_evaluated = to_dense(rhs)

         res = linear_op.matmul(rhs)
-        actual = evaluated.matmul(rhs)
-        self.assertAllClose(res, actual)
+        actual = evaluated.matmul(rhs_evaluated)
+        res_evaluated = to_dense(res)
+        self.assertAllClose(res_evaluated, actual)

     def _test_rmatmul(self, lhs):
         linear_op = self.create_linear_op().detach().requires_grad_(True)

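The identity-operator test overrides `_test_matmul`, so it receives the same treatment: the right-hand side may now be a DiagLinearOperator, and `to_dense` bridges the two cases, materializing operators while (as these tests rely on) leaving plain tensors unchanged. A brief sketch of that behaviour:

```python
import torch
from linear_operator.operators import DiagLinearOperator, to_dense

t = torch.randn(3, 3)
diag = torch.rand(3)

assert torch.equal(to_dense(t), t)                                   # dense input passes through
assert torch.allclose(to_dense(DiagLinearOperator(diag)), torch.diag(diag))  # operator is materialized
```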