 
 from __future__ import annotations
 
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 
 from ..utils.getitem import _is_noop_index
 from ..utils.memoize import cached
 from ._linear_operator import LinearOperator
 from .diag_linear_operator import ConstantDiagLinearOperator
+from .triangular_linear_operator import TriangularLinearOperator
 from .zero_linear_operator import ZeroLinearOperator
 
 
 class IdentityLinearOperator(ConstantDiagLinearOperator):
-    def __init__(self, diag_shape, batch_shape=torch.Size([]), dtype=None, device=None):
-        """
-        Identity matrix lazy tensor. Supports arbitrary batch sizes.
-
-        Args:
-            :attr:`diag` (Tensor):
-                A `b1 x ... x bk x n` Tensor, representing a `b1 x ... x bk`-sized batch
-                of `n x n` identity matrices
-        """
+    """
+    Identity linear operator. Supports arbitrary batch sizes.
+
+    :param diag_shape: The size of the identity matrix (i.e. :math:`N`).
+    :param batch_shape: The size of the batch dimensions. It may be useful to set these dimensions for broadcasting.
+    :param dtype: Dtype that the LinearOperator will be operating on. (Default: :meth:`torch.get_default_dtype()`).
+    :param device: Device that the LinearOperator will be operating on. (Default: CPU).
+    """
+
+    def __init__(
+        self,
+        diag_shape: int,
+        batch_shape: Optional[torch.Size] = torch.Size([]),
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+    ):
         one = torch.tensor(1.0, dtype=dtype, device=device)
         LinearOperator.__init__(self, diag_shape=diag_shape, batch_shape=batch_shape, dtype=dtype, device=device)
         self.diag_values = one.expand(torch.Size([*batch_shape, 1]))
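
A quick sanity check of the new constructor signature (a minimal sketch, not part of this commit; the import path below is an assumption about the package layout):

    import torch
    from linear_operator.operators import IdentityLinearOperator  # assumed import path

    # A batch of three 5 x 5 identity matrices in float64.
    eye = IdentityLinearOperator(diag_shape=5, batch_shape=torch.Size([3]), dtype=torch.float64)
    assert eye.shape == torch.Size([3, 5, 5])
    assert eye.dtype == torch.float64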
@@ -33,40 +41,42 @@ def __init__(self, diag_shape, batch_shape=torch.Size([]), dtype=None, device=None):
         self._device = device
 
     @property
-    def batch_shape(self):
-        """
-        Returns the shape over which the tensor is batched.
-        """
+    def batch_shape(self) -> torch.Size:
         return self._batch_shape
 
     @property
-    def dtype(self):
+    def dtype(self) -> torch.dtype:
         return self._dtype
 
     @property
-    def device(self):
+    def device(self) -> torch.device:
         return self._device
 
-    def _maybe_reshape_rhs(self, rhs):
+    def _maybe_reshape_rhs(self, rhs: torch.Tensor) -> torch.Tensor:
         if self._batch_shape != rhs.shape[:-2]:
             batch_shape = torch.broadcast_shapes(rhs.shape[:-2], self._batch_shape)
             return rhs.expand(*batch_shape, *rhs.shape[-2:])
         else:
             return rhs
 
     @cached(name="cholesky", ignore_args=True)
-    def _cholesky(self, upper=False):
+    def _cholesky(self, upper: Optional[bool] = False) -> TriangularLinearOperator:
         return self
 
-    def _cholesky_solve(self, rhs):
+    def _cholesky_solve(self, rhs: torch.Tensor) -> torch.Tensor:
         return self._maybe_reshape_rhs(rhs)
 
-    def _expand_batch(self, batch_shape):
+    def _expand_batch(self, batch_shape: torch.Size) -> LinearOperator:
         return IdentityLinearOperator(
             diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self.dtype, device=self.device
         )
 
-    def _getitem(self, row_index, col_index, *batch_indices):
+    def _getitem(
+        self,
+        row_index: Union[slice, torch.LongTensor],
+        col_index: Union[slice, torch.LongTensor],
+        *batch_indices: Tuple[Union[int, slice, torch.LongTensor], ...],
+    ) -> LinearOperator:
         # Special case: if both row and col are not indexed, then we are done
         if _is_noop_index(row_index) and _is_noop_index(col_index):
             if len(batch_indices):
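
A brief note on the broadcasting that `_maybe_reshape_rhs` implements: matmul with the identity leaves values untouched, but the result is expanded to the broadcast of the operator's and the right-hand side's batch shapes. A hypothetical usage sketch, not part of this commit (import path assumed as above):

    import torch
    from linear_operator.operators import IdentityLinearOperator  # assumed import path

    eye = IdentityLinearOperator(diag_shape=4, batch_shape=torch.Size([3]))
    rhs = torch.randn(4, 2)  # no batch dimensions
    res = eye.matmul(rhs)
    # The values are unchanged, but the batch dimension is broadcast in.
    assert res.shape == torch.Size([3, 4, 2])
    assert torch.equal(res[0], rhs)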
@@ -80,35 +90,39 @@ def _getitem(self, row_index, col_index, *batch_indices):
 
         return super()._getitem(row_index, col_index, *batch_indices)
 
-    def _matmul(self, rhs):
+    def _matmul(self, rhs: torch.Tensor) -> torch.Tensor:
         return self._maybe_reshape_rhs(rhs)
 
-    def _mul_constant(self, constant):
-        return ConstantDiagLinearOperator(self.diag_values * constant, diag_shape=self.diag_shape)
+    def _mul_constant(self, other: Union[float, torch.Tensor]) -> LinearOperator:
+        return ConstantDiagLinearOperator(self.diag_values * other, diag_shape=self.diag_shape)
 
-    def _mul_matrix(self, other):
+    def _mul_matrix(self, other: Union[torch.Tensor, LinearOperator]) -> LinearOperator:
         return other
 
-    def _permute_batch(self, *dims):
+    def _permute_batch(self, *dims: Tuple[int, ...]) -> LinearOperator:
         batch_shape = self.diag_values.permute(*dims, -1).shape[:-1]
         return IdentityLinearOperator(
             diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self._dtype, device=self._device
         )
 
-    def _prod_batch(self, dim):
+    def _prod_batch(self, dim: int) -> LinearOperator:
         batch_shape = list(self.batch_shape)
         del batch_shape[dim]
         return IdentityLinearOperator(
             diag_shape=self.diag_shape, batch_shape=torch.Size(batch_shape), dtype=self.dtype, device=self.device
         )
 
-    def _root_decomposition(self):
+    def _root_decomposition(self) -> LinearOperator:
         return self.sqrt()
 
-    def _root_inv_decomposition(self, initial_vectors=None):
+    def _root_inv_decomposition(
+        self,
+        initial_vectors: Optional[torch.Tensor] = None,
+        test_vectors: Optional[torch.Tensor] = None,
+    ) -> LinearOperator:
         return self.inverse().sqrt()
 
-    def _size(self):
+    def _size(self) -> torch.Size:
         return torch.Size([*self._batch_shape, self.diag_shape, self.diag_shape])
 
     @cached(name="svd")
@@ -118,10 +132,10 @@ def _svd(self) -> Tuple[LinearOperator, Tensor, LinearOperator]:
     def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional[LinearOperator]]:
         return self._diag, self
 
-    def _t_matmul(self, rhs):
+    def _t_matmul(self, rhs: torch.Tensor) -> torch.Tensor:
         return self._maybe_reshape_rhs(rhs)
 
-    def _transpose_nonbatch(self):
+    def _transpose_nonbatch(self) -> LinearOperator:
         return self
 
     def _unsqueeze_batch(self, dim: int) -> IdentityLinearOperator:
@@ -132,16 +146,18 @@ def _unsqueeze_batch(self, dim: int) -> IdentityLinearOperator:
             diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self.dtype, device=self.device
         )
 
-    def abs(self):
+    def abs(self) -> LinearOperator:
         return self
 
-    def exp(self):
+    def exp(self) -> LinearOperator:
         return self
 
-    def inverse(self):
+    def inverse(self) -> LinearOperator:
         return self
 
-    def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
+    def inv_quad_logdet(
+        self, inv_quad_rhs: Optional[torch.Tensor] = None, logdet: bool = False, reduce_inv_quad: bool = True
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # TODO: Use proper batching for inv_quad_rhs (prepend to shape rather than append)
         if inv_quad_rhs is None:
             inv_quad_term = torch.empty(0, dtype=self.dtype, device=self.device)
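
Since a solve against the identity returns the right-hand side unchanged, the inverse quadratic form collapses to a squared norm and the log-determinant is exactly zero. A minimal sketch, not part of this commit (import path assumed as above):

    import torch
    from linear_operator.operators import IdentityLinearOperator  # assumed import path

    eye = IdentityLinearOperator(diag_shape=3)
    rhs = torch.randn(3, 2)
    inv_quad, logdet = eye.inv_quad_logdet(inv_quad_rhs=rhs, logdet=True)
    assert torch.allclose(inv_quad, rhs.pow(2).sum())  # rhs^T I^{-1} rhs, reduced
    assert logdet.item() == 0.0  # log det(I) = 0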
@@ -158,12 +174,12 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
 
         return inv_quad_term, logdet_term
 
-    def log(self):
+    def log(self) -> LinearOperator:
         return ZeroLinearOperator(
             *self._batch_shape, self.diag_shape, self.diag_shape, dtype=self._dtype, device=self._device
         )
 
-    def matmul(self, other):
+    def matmul(self, other: Union[torch.Tensor, LinearOperator]) -> Union[torch.Tensor, LinearOperator]:
         is_vec = False
         if other.dim() == 1:
             is_vec = True
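
The vector path here treats a 1D input as a column vector and squeezes the result back (see the squeeze in the next hunk), so vectors map to vectors. A quick sketch, not part of this commit (import path assumed as above):

    import torch
    from linear_operator.operators import IdentityLinearOperator  # assumed import path

    eye = IdentityLinearOperator(diag_shape=4)
    vec = torch.randn(4)
    out = eye.matmul(vec)
    assert out.shape == torch.Size([4])  # still a vector
    assert torch.equal(out, vec)  # identity leaves values unchanged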
@@ -173,31 +189,28 @@ def matmul(self, other):
             res = res.squeeze(-1)
         return res
 
-    def solve(self, right_tensor, left_tensor=None):
+    def solve(self, right_tensor: torch.Tensor, left_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
         res = self._maybe_reshape_rhs(right_tensor)
         if left_tensor is not None:
             res = left_tensor @ res
         return res
 
-    def sqrt(self):
+    def sqrt(self) -> LinearOperator:
         return self
 
-    def sqrt_inv_matmul(self, rhs, lhs=None):
+    def sqrt_inv_matmul(
+        self, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if lhs is None:
             return self._maybe_reshape_rhs(rhs)
         else:
             sqrt_inv_matmul = lhs @ rhs
             inv_quad = lhs.pow(2).sum(dim=-1)
             return sqrt_inv_matmul, inv_quad
 
-    def type(self, dtype):
-        """
-        This method operates similarly to :func:`torch.Tensor.type`.
-        """
+    def type(self, dtype: torch.dtype) -> LinearOperator:
         return IdentityLinearOperator(
             diag_shape=self.diag_shape, batch_shape=self.batch_shape, dtype=dtype, device=self.device
         )
 
-    def zero_mean_mvn_samples(self, num_samples):
+    def zero_mean_mvn_samples(self, num_samples: int) -> torch.Tensor:
         base_samples = torch.randn(num_samples, *self.shape[:-1], dtype=self.dtype, device=self.device)
         return base_samples
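
Finally, zero-mean MVN sampling under an identity covariance reduces to drawing i.i.d. standard normals with the operator's batch and event shape. A minimal sketch, not part of this commit (import path assumed as above):

    import torch
    from linear_operator.operators import IdentityLinearOperator  # assumed import path

    eye = IdentityLinearOperator(diag_shape=2, batch_shape=torch.Size([3]))
    samples = eye.zero_mean_mvn_samples(num_samples=5)
    # num_samples is prepended to the operator's batch/event shape.
    assert samples.shape == torch.Size([5, 3, 2])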