@@ -937,15 +937,24 @@ def __init__(self, *, assume_a="gen", **kwargs):
937937
938938 def perform (self , node , inputs , outputs ):
939939 a , b = inputs
940- outputs [0 ][0 ] = scipy_linalg .solve (
941- a = a ,
942- b = b ,
943- lower = self .lower ,
944- check_finite = self .check_finite ,
945- assume_a = self .assume_a ,
946- overwrite_a = self .overwrite_a ,
947- overwrite_b = self .overwrite_b ,
948- )
940+ if self .assume_a == "tridiagonal" :
941+ [dl , d , du ] = (a .diagonal (offset = o ) for o in (- 1 , 0 , 1 ))
942+ _gttrf , _gttrs = scipy_linalg .get_lapack_funcs (
943+ ("gttrf" , "gttrs" ), dtype = node .outputs [0 ].type .dtype
944+ )
945+ dl , d , du , du2 , ipiv , _ = _gttrf (dl , d , du )
946+ x , _ = _gttrs (dl , d , du , du2 , ipiv , b , overwrite_b = self .overwrite_b )
947+ outputs [0 ][0 ] = x
948+ else :
949+ outputs [0 ][0 ] = scipy_linalg .solve (
950+ a = a ,
951+ b = b ,
952+ lower = self .lower ,
953+ check_finite = self .check_finite ,
954+ assume_a = self .assume_a ,
955+ overwrite_a = self .overwrite_a ,
956+ overwrite_b = self .overwrite_b ,
957+ )
949958
950959 def inplace_on_inputs (self , allowed_inplace_inputs : list [int ]) -> "Op" :
951960 if not allowed_inplace_inputs :
0 commit comments