1- import scipy
21import numpy as np
2+ import scipy
33from scipy .linalg import get_lapack_funcs
44
5- from pytensor .graph import Op , Apply
5+ from pytensor .graph import Apply , Op
66from pytensor .tensor .basic import as_tensor , diagonal
7- from pytensor .tensor .type import tensor , vector
87from pytensor .tensor .blockwise import Blockwise
98from pytensor .tensor .slinalg import Solve
9+ from pytensor .tensor .type import tensor , vector
1010
1111
1212class LUFactorTridiagonal (Op ):
1313 """Compute LU factorization of a tridiagonal matrix (lapack gttrf)"""
14- __props__ = ("overwrite_dl" , "overwrite_d" , "overwrite_du" ,)
14+
15+ __props__ = (
16+ "overwrite_dl" ,
17+ "overwrite_d" ,
18+ "overwrite_du" ,
19+ )
1520 gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"
1621
1722 def __init__ (self , overwrite_dl = False , overwrite_d = False , overwrite_du = False ):
@@ -29,11 +34,8 @@ def make_node(self, dl, d, du):
2934 ndl , nd , ndu = (inp .type .shape [- 1 ] for inp in (dl , d , du ))
3035 n = (
3136 ndl + 1
32- if ndl is not None else (
33- nd if nd is not None else (
34- ndu + 1 if ndu is not None else None
35- )
36- )
37+ if ndl is not None
38+ else (nd if nd is not None else (ndu + 1 if ndu is not None else None ))
3739 )
3840 dummy_arrays = [np .zeros ((), dtype = inp .type .dtype ) for inp in (dl , d , du )]
3941 out_dtype = get_lapack_funcs ("gttrf" , dummy_arrays ).dtype
@@ -63,6 +65,7 @@ def perform(self, node, inputs, output_storage):
6365
6466class SolveLUFactorTridiagonal (Op ):
6567 """Solve a system of linear equations with a tridiagonal coefficient matrix."""
68+
6669 __props__ = ("b_ndim" , "overwrite_b" )
6770
6871 def __init__ (self , b_ndim : int , overwrite_b = False ):
@@ -84,21 +87,30 @@ def make_node(self, dl, d, du, du2, ipiv, b):
8487 if not all (inp .type .ndim == 1 for inp in (dl , d , du , du2 , ipiv )):
8588 raise ValueError ("Inputs must be vectors" )
8689
87- ndl , nd , ndu , ndu2 , nipiv = (inp .type .shape [- 1 ] for inp in (dl , d , du , du2 , ipiv ))
90+ ndl , nd , ndu , ndu2 , nipiv = (
91+ inp .type .shape [- 1 ] for inp in (dl , d , du , du2 , ipiv )
92+ )
8893 nb = b .type .shape [0 ]
8994 n = (
9095 ndl + 1
91- if ndl is not None else (
92- nd if nd is not None else (
93- ndu + 1 if ndu is not None else (
94- ndu2 + 2 if ndu2 is not None else (
95- nipiv if nipiv is not None else nb
96- )
96+ if ndl is not None
97+ else (
98+ nd
99+ if nd is not None
100+ else (
101+ ndu + 1
102+ if ndu is not None
103+ else (
104+ ndu2 + 2
105+ if ndu2 is not None
106+ else (nipiv if nipiv is not None else nb )
97107 )
98108 )
99109 )
100110 )
101- dummy_arrays = [np .zeros ((), dtype = inp .type .dtype ) for inp in (dl , d , du , du2 , ipiv )]
111+ dummy_arrays = [
112+ np .zeros ((), dtype = inp .type .dtype ) for inp in (dl , d , du , du2 , ipiv )
113+ ]
102114 # Seems to always be float64?
103115 out_dtype = get_lapack_funcs ("gttrs" , dummy_arrays ).dtype
104116 if self .b_ndim == 1 :
@@ -111,14 +123,13 @@ def make_node(self, dl, d, du, du2, ipiv, b):
111123
112124 def perform (self , node , inputs , output_storage ):
113125 gttrs = get_lapack_funcs ("gttrs" , dtype = node .outputs [0 ].type .dtype )
114- x , _ = gttrs (
115- * inputs , overwrite_b = self .overwrite_b
116- )
126+ x , _ = gttrs (* inputs , overwrite_b = self .overwrite_b )
117127 output_storage [0 ][0 ] = x
118128
119129
120130class SolveTridiagonal (Op ):
121131 """Solve a system of linear equations with a tridiagonal dense matrix."""
132+
122133 __props__ = ("b_ndim" , "overwrite_b" )
123134
124135 def __init__ (self , * , b_ndim : int , overwrite_b : bool = False ):
@@ -141,7 +152,9 @@ def make_node(self, dl, d, du, b):
141152 raise TypeError ("Diagonals must have the same dtype" )
142153
143154 if b .type .ndim != self .b_ndim :
144- raise ValueError (f"Number of dimensions of b does not match promised { self .b_ndim } " )
155+ raise ValueError (
156+ f"Number of dimensions of b does not match promised { self .b_ndim } "
157+ )
145158
146159 out_dtype = scipy .linalg .solve (
147160 np .eye ((3 ), dtype = d .type .dtype ),
@@ -156,13 +169,14 @@ def L_op(self, node, inputs, outputs, output_grads):
156169
157170 def perform (self , node , inputs , output_storage ):
158171 [dl , d , du , b ] = inputs
159- _gttrf , _gttrs = get_lapack_funcs (('gttrf' , 'gttrs' ), dtype = node .outputs [0 ].type .dtype )
172+ _gttrf , _gttrs = get_lapack_funcs (
173+ ("gttrf" , "gttrs" ), dtype = node .outputs [0 ].type .dtype
174+ )
160175
161176 dl , d , du , du2 , ipiv , _ = _gttrf (dl , d , du )
162177 x , _ = _gttrs (dl , d , du , du2 , ipiv , b , overwrite_b = self .overwrite_b )
163178 output_storage [0 ][0 ] = x
164179
165-
166180 def inplace_on_inputs (self , allowed_inplace_inputs : list [int ]) -> "Op" :
167181 if 3 not in allowed_inplace_inputs :
168182 return self
@@ -186,6 +200,7 @@ def solve_tridiagonal_from_full_A_b(a, b, b_ndim: int, transposed: bool):
186200 dl , d , du = (diagonal (a , offset = o , axis1 = - 2 , axis2 = - 1 ) for o in (- 1 , 0 , 1 ))
187201 return Blockwise (SolveTridiagonal (b_ndim = b_ndim ))(dl , d , du )
188202
203+
189204def split_solve_tridiagonal (node ):
190205 """Split a generic solve tridiagonal system into the 3 atomic steps:
191206 1. Diagonal extractions
@@ -198,11 +213,21 @@ def split_solve_tridiagonal(node):
198213 core_op = node .op .core_op
199214 assert isinstance (core_op , Solve ) and core_op .assume_a == "tridiagonal"
200215 a , b = node .inputs
201- dl , d , du , du2 , ipiv = decompose_of_solve_tridiagonal (a )
202- return Blockwise (SolveLUFactorTridiagonal (b_ndim = node .op .core_op .b_ndim ))(dl , d , du , du2 , ipiv , b )
216+ a_decomp = decompose_of_solve_tridiagonal (a )
217+ return solve_decomposed_tridiagonal (a_decomp , b , b_ndim = core_op .b_ndim )
218+
203219
204220def decompose_of_solve_tridiagonal (a ):
205221 # Return the decomposition of A implied by a solve tridiagonal
206222 dl , d , du = (diagonal (a , offset = o , axis1 = - 2 , axis2 = - 1 ) for o in (- 1 , 0 , 1 ))
207223 dl , d , du , du2 , ipiv = Blockwise (LUFactorTridiagonal ())(dl , d , du )
208224 return dl , d , du , du2 , ipiv
225+
226+
227+ def decompose_tridiagonals (dl , d , du ):
228+ return Blockwise (LUFactorTridiagonal ())(dl , d , du )
229+
230+
231+ def solve_decomposed_tridiagonal (a_diagonals , b , * , b_ndim : int ):
232+ dl , d , du , du2 , ipiv = a_diagonals
233+ return Blockwise (SolveLUFactorTridiagonal (b_ndim = b_ndim ))(dl , d , du , du2 , ipiv , b )
0 commit comments