
Commit 987df55

Update docs for AddedDiag, Identity, Zero
Parent: 3e85dc8

4 files changed: +132, -92 lines

docs/source/data_sparse_operators.rst

Lines changed: 6 additions & 0 deletions
@@ -30,6 +30,12 @@ Data-Sparse LinearOperators
 .. autoclass:: linear_operator.operators.DiagLinearOperator
    :members:

+:hidden:`IdentityLinearOperator`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: linear_operator.operators.IdentityLinearOperator
+   :members:
+
 :hidden:`RootLinearOperator`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

linear_operator/operators/added_diag_linear_operator.py

Lines changed: 3 additions & 2 deletions
@@ -18,8 +18,9 @@
 
 class AddedDiagLinearOperator(SumLinearOperator):
     """
-    A SumLinearOperator, but of only two linear operators, the second of which must be
-    a DiagLinearOperator.
+    A :class:`~linear_operator.operators.SumLinearOperator`, but of only two
+    linear operators, the second of which must be a
+    :class:`~linear_operator.operators.DiagLinearOperator`.
 
     :param linear_ops: The LinearOperator, and the DiagLinearOperator to add to it.
     :param preconditioner_override: A preconditioning method to be used with conjugate gradients.

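For illustration, a minimal usage sketch of the class documented in this hunk. The DenseLinearOperator import and all shapes are assumptions for the example, not part of this commit:

import torch
from linear_operator.operators import (
    AddedDiagLinearOperator,
    DenseLinearOperator,
    DiagLinearOperator,
)

# An arbitrary PSD base operator plus a diagonal term (hypothetical shapes)
mat = torch.randn(5, 5)
base_op = DenseLinearOperator(mat @ mat.mT)    # the LinearOperator
diag_op = DiagLinearOperator(torch.rand(5))    # the DiagLinearOperator to add to it
summed = AddedDiagLinearOperator(base_op, diag_op)
print(summed.shape)  # torch.Size([5, 5])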
linear_operator/operators/identity_linear_operator.py

Lines changed: 59 additions & 46 deletions
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -11,19 +11,27 @@
 from ..utils.memoize import cached
 from ._linear_operator import LinearOperator
 from .diag_linear_operator import ConstantDiagLinearOperator
+from .triangular_linear_operator import TriangularLinearOperator
 from .zero_linear_operator import ZeroLinearOperator
 
 
 class IdentityLinearOperator(ConstantDiagLinearOperator):
-    def __init__(self, diag_shape, batch_shape=torch.Size([]), dtype=None, device=None):
-        """
-        Identity matrix lazy tensor. Supports arbitrary batch sizes.
-
-        Args:
-            :attr:`diag` (Tensor):
-                A `b1 x ... x bk x n` Tensor, representing a `b1 x ... x bk`-sized batch
-                of `n x n` identity matrices
-        """
+    """
+    Identity linear operator. Supports arbitrary batch sizes.
+
+    :param diag_shape: The size of the identity matrix (i.e. :math:`N`).
+    :param batch_shape: The size of the batch dimensions. It may be useful to set these dimensions for broadcasting.
+    :param dtype: Dtype that the LinearOperator will be operating on. (Default: :meth:`torch.get_default_dtype()`).
+    :param device: Device that the LinearOperator will be operating on. (Default: CPU).
+    """
+
+    def __init__(
+        self,
+        diag_shape: int,
+        batch_shape: Optional[torch.Size] = torch.Size([]),
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+    ):
         one = torch.tensor(1.0, dtype=dtype, device=device)
         LinearOperator.__init__(self, diag_shape=diag_shape, batch_shape=batch_shape, dtype=dtype, device=device)
         self.diag_values = one.expand(torch.Size([*batch_shape, 1]))
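A sketch of the new constructor signature with hypothetical values; to_dense() is assumed from the base LinearOperator API:

import torch
from linear_operator.operators import IdentityLinearOperator

# A batch of two 3 x 3 identity matrices, in double precision
eye = IdentityLinearOperator(diag_shape=3, batch_shape=torch.Size([2]), dtype=torch.float64)
print(eye.shape)       # torch.Size([2, 3, 3])
print(eye.to_dense())  # two stacked 3 x 3 identity matrices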
@@ -33,40 +41,42 @@ def __init__(self, diag_shape, batch_shape=torch.Size([]), dtype=None, device=None):
         self._device = device
 
     @property
-    def batch_shape(self):
-        """
-        Returns the shape over which the tensor is batched.
-        """
+    def batch_shape(self) -> torch.Size:
         return self._batch_shape
 
     @property
-    def dtype(self):
+    def dtype(self) -> torch.dtype:
         return self._dtype
 
     @property
-    def device(self):
+    def device(self) -> torch.device:
         return self._device
 
-    def _maybe_reshape_rhs(self, rhs):
+    def _maybe_reshape_rhs(self, rhs: torch.Tensor) -> torch.Tensor:
         if self._batch_shape != rhs.shape[:-2]:
             batch_shape = torch.broadcast_shapes(rhs.shape[:-2], self._batch_shape)
             return rhs.expand(*batch_shape, *rhs.shape[-2:])
         else:
             return rhs
 
     @cached(name="cholesky", ignore_args=True)
-    def _cholesky(self, upper=False):
+    def _cholesky(self, upper: Optional[bool] = False) -> TriangularLinearOperator:
         return self
 
-    def _cholesky_solve(self, rhs):
+    def _cholesky_solve(self, rhs: torch.Tensor) -> torch.Tensor:
         return self._maybe_reshape_rhs(rhs)
 
-    def _expand_batch(self, batch_shape):
+    def _expand_batch(self, batch_shape: torch.Size) -> LinearOperator:
         return IdentityLinearOperator(
             diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self.dtype, device=self.device
         )
 
-    def _getitem(self, row_index, col_index, *batch_indices):
+    def _getitem(
+        self,
+        row_index: Union[slice, torch.LongTensor],
+        col_index: Union[slice, torch.LongTensor],
+        *batch_indices: Tuple[Union[int, slice, torch.LongTensor], ...],
+    ) -> LinearOperator:
         # Special case: if both row and col are not indexed, then we are done
         if _is_noop_index(row_index) and _is_noop_index(col_index):
             if len(batch_indices):
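The broadcasting that _maybe_reshape_rhs implements can be seen through the public matmul; a sketch with hypothetical shapes:

import torch
from linear_operator.operators import IdentityLinearOperator

eye = IdentityLinearOperator(diag_shape=3, batch_shape=torch.Size([2]))
rhs = torch.randn(3, 4)  # no batch dimensions of its own
out = eye.matmul(rhs)
print(out.shape)  # torch.Size([2, 3, 4]): rhs expanded to the broadcast batch shape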
@@ -80,35 +90,39 @@ def _getitem(self, row_index, col_index, *batch_indices):
 
         return super()._getitem(row_index, col_index, *batch_indices)
 
-    def _matmul(self, rhs):
+    def _matmul(self, rhs: torch.Tensor) -> torch.Tensor:
         return self._maybe_reshape_rhs(rhs)
 
-    def _mul_constant(self, constant):
-        return ConstantDiagLinearOperator(self.diag_values * constant, diag_shape=self.diag_shape)
+    def _mul_constant(self, other: Union[float, torch.Tensor]) -> LinearOperator:
+        return ConstantDiagLinearOperator(self.diag_values * other, diag_shape=self.diag_shape)
 
-    def _mul_matrix(self, other):
+    def _mul_matrix(self, other: Union[torch.Tensor, LinearOperator]) -> LinearOperator:
         return other
 
-    def _permute_batch(self, *dims):
+    def _permute_batch(self, *dims: Tuple[int, ...]) -> LinearOperator:
         batch_shape = self.diag_values.permute(*dims, -1).shape[:-1]
         return IdentityLinearOperator(
             diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self._dtype, device=self._device
         )
 
-    def _prod_batch(self, dim):
+    def _prod_batch(self, dim: int) -> LinearOperator:
         batch_shape = list(self.batch_shape)
         del batch_shape[dim]
         return IdentityLinearOperator(
             diag_shape=self.diag_shape, batch_shape=torch.Size(batch_shape), dtype=self.dtype, device=self.device
         )
 
-    def _root_decomposition(self):
+    def _root_decomposition(self) -> LinearOperator:
         return self.sqrt()
 
-    def _root_inv_decomposition(self, initial_vectors=None):
+    def _root_inv_decomposition(
+        self,
+        initial_vectors: Optional[torch.Tensor] = None,
+        test_vectors: Optional[torch.Tensor] = None,
+    ) -> LinearOperator:
         return self.inverse().sqrt()
 
-    def _size(self):
+    def _size(self) -> torch.Size:
         return torch.Size([*self._batch_shape, self.diag_shape, self.diag_shape])
 
     @cached(name="svd")
@@ -118,10 +132,10 @@ def _svd(self) -> Tuple[LinearOperator, Tensor, LinearOperator]:
     def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional[LinearOperator]]:
         return self._diag, self
 
-    def _t_matmul(self, rhs):
+    def _t_matmul(self, rhs: torch.Tensor) -> LinearOperator:
         return self._maybe_reshape_rhs(rhs)
 
-    def _transpose_nonbatch(self):
+    def _transpose_nonbatch(self) -> LinearOperator:
         return self
 
     def _unsqueeze_batch(self, dim: int) -> IdentityLinearOperator:
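Since _transpose_nonbatch returns self, transposition is a no-op; a small sketch, assuming the mT property from the base LinearOperator API:

import torch
from linear_operator.operators import IdentityLinearOperator

eye = IdentityLinearOperator(diag_shape=3)
print(torch.equal(eye.mT.to_dense(), eye.to_dense()))  # expected True: I^T = I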
@@ -132,16 +146,18 @@ def _unsqueeze_batch(self, dim: int) -> IdentityLinearOperator:
             diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self.dtype, device=self.device
         )
 
-    def abs(self):
+    def abs(self) -> LinearOperator:
         return self
 
-    def exp(self):
+    def exp(self) -> LinearOperator:
         return self
 
-    def inverse(self):
+    def inverse(self) -> LinearOperator:
         return self
 
-    def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
+    def inv_quad_logdet(
+        self, inv_quad_rhs: Optional[torch.Tensor] = None, logdet: bool = False, reduce_inv_quad: bool = True
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # TODO: Use proper batching for inv_quad_rhs (prepend to shape rather than append)
         if inv_quad_rhs is None:
             inv_quad_term = torch.empty(0, dtype=self.dtype, device=self.device)
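For the identity, inv_quad_logdet should reduce to a sum of squares plus a zero log-determinant; a hedged sketch assuming the default reduce_inv_quad=True semantics:

import torch
from linear_operator.operators import IdentityLinearOperator

eye = IdentityLinearOperator(diag_shape=4)
rhs = torch.randn(4, 2)
inv_quad, logdet = eye.inv_quad_logdet(inv_quad_rhs=rhs, logdet=True)
print(torch.allclose(inv_quad, rhs.pow(2).sum()))  # expected True, since I^{-1} = I
print(logdet)                                      # expected tensor(0.)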
@@ -158,12 +174,12 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
 
         return inv_quad_term, logdet_term
 
-    def log(self):
+    def log(self) -> LinearOperator:
         return ZeroLinearOperator(
             *self._batch_shape, self.diag_shape, self.diag_shape, dtype=self._dtype, device=self._device
         )
 
-    def matmul(self, other):
+    def matmul(self, other: Union[torch.Tensor, LinearOperator]) -> Union[torch.Tensor, LinearOperator]:
         is_vec = False
         if other.dim() == 1:
             is_vec = True
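Two behaviors from this hunk, sketched with hypothetical inputs: log of the identity returns a ZeroLinearOperator, and matmul round-trips vectors unchanged:

import torch
from linear_operator.operators import IdentityLinearOperator

eye = IdentityLinearOperator(diag_shape=3)
print(type(eye.log()).__name__)       # expected: ZeroLinearOperator
v = torch.randn(3)
print(torch.equal(eye.matmul(v), v))  # expected True: the vector comes back unchanged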
@@ -173,31 +189,28 @@ def matmul(self, other):
         res = res.squeeze(-1)
         return res
 
-    def solve(self, right_tensor, left_tensor=None):
+    def solve(self, right_tensor: torch.Tensor, left_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
         res = self._maybe_reshape_rhs(right_tensor)
         if left_tensor is not None:
             res = left_tensor @ res
         return res
 
-    def sqrt(self):
+    def sqrt(self) -> LinearOperator:
         return self
 
-    def sqrt_inv_matmul(self, rhs, lhs=None):
+    def sqrt_inv_matmul(self, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None) -> torch.Tensor:
         if lhs is None:
             return self._maybe_reshape_rhs(rhs)
         else:
             sqrt_inv_matmul = lhs @ rhs
             inv_quad = lhs.pow(2).sum(dim=-1)
             return sqrt_inv_matmul, inv_quad
 
-    def type(self, dtype):
-        """
-        This method operates similarly to :func:`torch.Tensor.type`.
-        """
+    def type(self, dtype: torch.dtype) -> LinearOperator:
         return IdentityLinearOperator(
             diag_shape=self.diag_shape, batch_shape=self.batch_shape, dtype=dtype, device=self.device
         )
 
-    def zero_mean_mvn_samples(self, num_samples):
+    def zero_mean_mvn_samples(self, num_samples: int) -> torch.Tensor:
         base_samples = torch.randn(num_samples, *self.shape[:-1], dtype=self.dtype, device=self.device)
         return base_samples
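Finally, a sketch of solve and zero_mean_mvn_samples as typed above: solving against the identity returns the (optionally left-multiplied) right-hand side, and samples with identity covariance are standard normal draws. Shapes are hypothetical:

import torch
from linear_operator.operators import IdentityLinearOperator

eye = IdentityLinearOperator(diag_shape=3)
rhs = torch.randn(3, 2)
print(torch.equal(eye.solve(rhs), rhs))  # expected True
samples = eye.zero_mean_mvn_samples(num_samples=10)
print(samples.shape)  # torch.Size([10, 3])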
