33from scipy .linalg import get_lapack_funcs
44
55from pytensor .graph import Op , Apply
6- from pytensor .tensor import as_tensor , tensor , diagonal
6+ from pytensor .tensor .basic import as_tensor , diagonal
7+ from pytensor .tensor .type import tensor , vector
78from pytensor .tensor .blockwise import Blockwise
9+ from pytensor .tensor .slinalg import Solve
810
911
1012class LUFactorTridiagonal (Op ):
1113 """Compute LU factorization of a tridiagonal matrix (lapack gttrf)"""
1214 __props__ = ("overwrite_dl" , "overwrite_d" , "overwrite_du" ,)
13- _gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"
15+ gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"
1416
1517 def __init__ (self , overwrite_dl = False , overwrite_d = False , overwrite_du = False ):
1618 self .overwrite_dl = overwrite_dl
@@ -19,33 +21,34 @@ def __init__(self, overwrite_dl=False, overwrite_d=False, overwrite_du=False):
1921 super ().__init__ ()
2022
2123 def make_node (self , dl , d , du ):
22- dl , d , du = map (as_tensor , dl , d , du )
24+ dl , d , du = map (as_tensor , ( dl , d , du ) )
2325
24- if not all (inp .type .ndim == 1 for inp in (dl , d , du ))
26+ if not all (inp .type .ndim == 1 for inp in (dl , d , du )):
2527 raise ValueError ("Diagonals must be vectors" )
2628
2729 ndl , nd , ndu = (inp .type .shape [- 1 ] for inp in (dl , d , du ))
2830 n = (
2931 ndl + 1
3032 if ndl is not None else (
31- n if n is not None else (
32- ndu + 1 if nu is not None else None
33+ nd if nd is not None else (
34+ ndu + 1 if ndu is not None else None
3335 )
3436 )
3537 )
3638 dummy_arrays = [np .zeros ((), dtype = inp .type .dtype ) for inp in (dl , d , du )]
3739 out_dtype = get_lapack_funcs ("gttrf" , dummy_arrays ).dtype
3840 outputs = [
39- vector (shape = (shape = ( None if n is None else n - 1 ,), dtype = out_dtype ),
41+ vector (shape = (None if n is None else ( n - 1 ) ,), dtype = out_dtype ),
4042 vector (shape = (n ,), dtype = out_dtype ),
4143 vector (shape = (None if n is None else n - 1 ,), dtype = out_dtype ),
4244 vector (shape = (None if n is None else n - 2 ,), dtype = out_dtype ),
4345 vector (shape = (n ,), dtype = np .int32 ),
4446 ]
47+ return Apply (self , [dl , d , du ], outputs )
4548
4649 def perform (self , node , inputs , output_storage ):
4750 gttrf = get_lapack_funcs ("gttrf" , dtype = node .outputs [0 ].type .dtype )
48- dl , d , du , du2 , ipiv , _ = _gttrf (
51+ dl , d , du , du2 , ipiv , _ = gttrf (
4952 * inputs ,
5053 overwrite_dl = self .overwrite_dl ,
5154 overwrite_d = self .overwrite_d ,
@@ -68,26 +71,26 @@ def __init__(self, b_ndim: int, overwrite_b=False):
6871 self .b_ndim = b_ndim
6972 self .overwrite_b = overwrite_b
7073 if b_ndim == 1 :
71- _gufunc_signature = "(dl ),(d ),(dl ),(du2 ),(d ),(d )- > (d )
74+ self . gufunc_signature = "(dl),(d),(dl),(du2),(d),(d)->(d)"
7275 else :
73- _gufunc_signature = "(dl ),(d ),(dl ),(du2 ),(d ),(d ,rhs )- > (d ,rhs )
76+ self . gufunc_signature = "(dl),(d),(dl),(du2),(d),(d,rhs)->(d,rhs)"
7477
7578 def make_node (self , dl , d , du , du2 , ipiv , b ):
76- dl , d , du , du2 , ipiv , b = map (as_tensor , dl , d , du , du2 , ipiv , b )
79+ dl , d , du , du2 , ipiv , b = map (as_tensor , ( dl , d , du , du2 , ipiv , b ) )
7780
7881 if b .type .ndim != self .b_ndim :
7982 raise ValueError ("Wrang number of dimensions for input b." )
8083
81- if not all (inp .type .ndim == 1 for inp in (dl , d , du , du2 , ipiv ))
84+ if not all (inp .type .ndim == 1 for inp in (dl , d , du , du2 , ipiv )):
8285 raise ValueError ("Inputs must be vectors" )
8386
8487 ndl , nd , ndu , ndu2 , nipiv = (inp .type .shape [- 1 ] for inp in (dl , d , du , du2 , ipiv ))
8588 nb = b .type .shape [0 ]
8689 n = (
8790 ndl + 1
8891 if ndl is not None else (
89- n if n is not None else (
90- ndu + 1 if nu is not None else (
92+ nd if nd is not None else (
93+ ndu + 1 if ndu is not None else (
9194 ndu2 + 2 if ndu2 is not None else (
9295 nipiv if nipiv is not None else nb
9396 )
@@ -101,14 +104,14 @@ def make_node(self, dl, d, du, du2, ipiv, b):
101104 if self .b_ndim == 1 :
102105 output_shape = (n ,)
103106 else :
104- output_shape = (n , n .type .shape [- 1 ])
107+ output_shape = (n , b .type .shape [- 1 ])
105108
106- outputs = [vector (shape = output_shape , dtype = out_dtype )]
109+ outputs = [tensor (shape = output_shape , dtype = out_dtype )]
107110 return Apply (self , [dl , d , du , du2 , ipiv , b ], outputs )
108111
109112 def perform (self , node , inputs , output_storage ):
110113 gttrs = get_lapack_funcs ("gttrs" , dtype = node .outputs [0 ].type .dtype )
111- x , _ = _gttrs (
114+ x , _ = gttrs (
112115 * inputs , overwrite_b = self .overwrite_b
113116 )
114117 output_storage [0 ][0 ] = x
@@ -149,7 +152,7 @@ def make_node(self, dl, d, du, b):
149152 return Apply (self , [dl , d , du , b ], [out ])
150153
151154 def L_op (self , node , inputs , outputs , output_grads ):
152- # TODO
155+ pass
153156
154157 def perform (self , node , inputs , output_storage ):
155158 [dl , d , du , b ] = inputs
@@ -193,8 +196,13 @@ def split_solve_tridiagonal(node):
193196 """
194197 assert isinstance (node .op , Blockwise )
195198 core_op = node .op .core_op
196- assert isinstance (core_op , Solve ) and core . op .assume_a == "tridiagonal"
199+ assert isinstance (core_op , Solve ) and core_op .assume_a == "tridiagonal"
197200 a , b = node .inputs
201+ dl , d , du , du2 , ipiv = decompose_of_solve_tridiagonal (a )
202+ return Blockwise (SolveLUFactorTridiagonal (b_ndim = node .op .core_op .b_ndim ))(dl , d , du , du2 , ipiv , b )
203+
204+ def decompose_of_solve_tridiagonal (a ):
205+ # Return the decomposition of A implied by a solve tridiagonal
198206 dl , d , du = (diagonal (a , offset = o , axis1 = - 2 , axis2 = - 1 ) for o in (- 1 , 0 , 1 ))
199207 dl , d , du , du2 , ipiv = Blockwise (LUFactorTridiagonal ())(dl , d , du )
200- return Blockwise ( SolveLUFactorTridiagonal ( b_ndim = node . op . core . op . b_ndim ))( dl , d , du )( dl , d , du , du2 , ipiv )
208+ return dl , d , du , du2 , ipiv
0 commit comments