11#!/usr/bin/env python3
22
3- from typing import Optional , Tuple , Union
3+ from __future__ import annotations
44
55import linear_operator
66import torch
2828from .mlls import ExactMarginalLogLikelihood
2929from .module import Module
3030
31- Anysor = Union [ LinearOperator , Tensor ]
31+ Anysor = LinearOperator | Tensor
3232
3333
3434def add_diagonal (input : Anysor , diag : Tensor ) -> LinearOperator :
@@ -58,7 +58,7 @@ def add_jitter(input: Anysor, jitter_val: float = 1e-3) -> Anysor:
5858 return linear_operator .add_jitter (input = input , jitter_val = jitter_val )
5959
6060
61- def diagonalization (input : Anysor , method : Optional [ str ] = None ) -> Tuple [Tensor , Tensor ]:
61+ def diagonalization (input : Anysor , method : str | None = None ) -> tuple [Tensor , Tensor ]:
6262 r"""
6363 Returns a (usually partial) diagonalization of a symmetric positive definite matrix (or batch of matrices).
6464 :math:`\mathbf A`.
@@ -74,7 +74,7 @@ def diagonalization(input: Anysor, method: Optional[str] = None) -> Tuple[Tensor
7474
7575
7676def dsmm (
77- sparse_mat : Union [ torch .sparse .HalfTensor , torch .sparse .FloatTensor , torch .sparse .DoubleTensor ] ,
77+ sparse_mat : torch .sparse .HalfTensor | torch .sparse .FloatTensor | torch .sparse .DoubleTensor ,
7878 dense_mat : Tensor ,
7979) -> Tensor :
8080 r"""
@@ -117,10 +117,10 @@ def inv_quad(input: Anysor, inv_quad_rhs: Tensor, reduce_inv_quad: bool = True)
117117
118118def inv_quad_logdet (
119119 input : Anysor ,
120- inv_quad_rhs : Optional [ Tensor ] = None ,
120+ inv_quad_rhs : Tensor | None = None ,
121121 logdet : bool = False ,
122122 reduce_inv_quad : bool = True ,
123- ) -> Tuple [Tensor , Tensor ]:
123+ ) -> tuple [Tensor , Tensor ]:
124124 r"""
125125 Calls both :func:`inv_quad_logdet` and :func:`logdet` on a positive definite matrix (or batch) :math:`\mathbf A`.
126126 However, calling this method is far more efficient and stable than calling each method independently.
@@ -146,9 +146,9 @@ def inv_quad_logdet(
146146def pivoted_cholesky (
147147 input : Anysor ,
148148 rank : int ,
149- error_tol : Optional [ float ] = None ,
149+ error_tol : float | None = None ,
150150 return_pivots : bool = False ,
151- ) -> Union [ Tensor , Tuple [Tensor , Tensor ] ]:
151+ ) -> Tensor | tuple [Tensor , Tensor ]:
152152 r"""
153153 Performs a partial pivoted Cholesky factorization of a positive definite matrix (or batch of matrices).
154154 :math:`\mathbf L \mathbf L^\top = \mathbf A`.
@@ -173,7 +173,7 @@ def pivoted_cholesky(
173173 return linear_operator .pivoted_cholesky (input = input , rank = rank , return_pivots = return_pivots )
174174
175175
176- def root_decomposition (input : Anysor , method : Optional [ str ] = None ) -> LinearOperator :
176+ def root_decomposition (input : Anysor , method : str | None = None ) -> LinearOperator :
177177 r"""
178178 Returns a (usually low-rank) root decomposition linear operator of the
179179 positive definite matrix (or batch of matrices) :math:`\mathbf A`.
@@ -190,9 +190,9 @@ def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOpe
190190
191191def root_inv_decomposition (
192192 input : Anysor ,
193- initial_vectors : Optional [ Tensor ] = None ,
194- test_vectors : Optional [ Tensor ] = None ,
195- method : Optional [ str ] = None ,
193+ initial_vectors : Tensor | None = None ,
194+ test_vectors : Tensor | None = None ,
195+ method : str | None = None ,
196196) -> LinearOperator :
197197 r"""
198198 Returns a (usually low-rank) inverse root decomposition linear operator
@@ -217,7 +217,7 @@ def root_inv_decomposition(
217217 )
218218
219219
220- def solve (input : Anysor , rhs : Tensor , lhs : Optional [ Tensor ] = None ) -> Tensor :
220+ def solve (input : Anysor , rhs : Tensor , lhs : Tensor | None = None ) -> Tensor :
221221 r"""
222222 Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`,
223223 computes a linear solve with right hand side :math:`\mathbf R`:
@@ -249,7 +249,7 @@ def solve(input: Anysor, rhs: Tensor, lhs: Optional[Tensor] = None) -> Tensor:
249249 return linear_operator .solve (input = input , rhs = rhs , lhs = lhs )
250250
251251
252- def sqrt_inv_matmul (input : Anysor , rhs : Tensor , lhs : Optional [ Tensor ] = None ) -> Tensor :
252+ def sqrt_inv_matmul (input : Anysor , rhs : Tensor , lhs : Tensor | None = None ) -> Tensor :
253253 r"""
254254 Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`
255255 and a right hand size :math:`\mathbf R`,
0 commit comments