55from typing import Literal , cast
66
77import numpy as np
8- import scipy .linalg
8+ import scipy .linalg as scipy_linalg
99
1010import pytensor
1111import pytensor .tensor as pt
@@ -58,7 +58,7 @@ def make_node(self, x):
5858 f"Cholesky only allowed on matrix (2-D) inputs, got { x .type .ndim } -D input"
5959 )
6060 # Call scipy to find output dtype
61- dtype = scipy . linalg .cholesky (np .eye (1 , dtype = x .type .dtype )).dtype
61+ dtype = scipy_linalg .cholesky (np .eye (1 , dtype = x .type .dtype )).dtype
6262 return Apply (self , [x ], [tensor (shape = x .type .shape , dtype = dtype )])
6363
6464 def perform (self , node , inputs , outputs ):
@@ -68,21 +68,21 @@ def perform(self, node, inputs, outputs):
6868 # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
6969 # If we have a `C_CONTIGUOUS` array we transpose to benefit from it
7070 if self .overwrite_a and x .flags ["C_CONTIGUOUS" ]:
71- out [0 ] = scipy . linalg .cholesky (
71+ out [0 ] = scipy_linalg .cholesky (
7272 x .T ,
7373 lower = not self .lower ,
7474 check_finite = self .check_finite ,
7575 overwrite_a = True ,
7676 ).T
7777 else :
78- out [0 ] = scipy . linalg .cholesky (
78+ out [0 ] = scipy_linalg .cholesky (
7979 x ,
8080 lower = self .lower ,
8181 check_finite = self .check_finite ,
8282 overwrite_a = self .overwrite_a ,
8383 )
8484
85- except scipy . linalg .LinAlgError :
85+ except scipy_linalg .LinAlgError :
8686 if self .on_error == "raise" :
8787 raise
8888 else :
@@ -334,7 +334,7 @@ def __init__(self, **kwargs):
334334
335335 def perform (self , node , inputs , output_storage ):
336336 C , b = inputs
337- rval = scipy . linalg .cho_solve (
337+ rval = scipy_linalg .cho_solve (
338338 (C , self .lower ),
339339 b ,
340340 check_finite = self .check_finite ,
@@ -401,7 +401,7 @@ def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
401401
402402 def perform (self , node , inputs , outputs ):
403403 A , b = inputs
404- outputs [0 ][0 ] = scipy . linalg .solve_triangular (
404+ outputs [0 ][0 ] = scipy_linalg .solve_triangular (
405405 A ,
406406 b ,
407407 lower = self .lower ,
@@ -502,7 +502,7 @@ def __init__(self, *, assume_a="gen", **kwargs):
502502
503503 def perform (self , node , inputs , outputs ):
504504 a , b = inputs
505- outputs [0 ][0 ] = scipy . linalg .solve (
505+ outputs [0 ][0 ] = scipy_linalg .solve (
506506 a = a ,
507507 b = b ,
508508 lower = self .lower ,
@@ -619,9 +619,9 @@ def make_node(self, a, b):
619619 def perform (self , node , inputs , outputs ):
620620 (w ,) = outputs
621621 if len (inputs ) == 2 :
622- w [0 ] = scipy . linalg .eigvalsh (a = inputs [0 ], b = inputs [1 ], lower = self .lower )
622+ w [0 ] = scipy_linalg .eigvalsh (a = inputs [0 ], b = inputs [1 ], lower = self .lower )
623623 else :
624- w [0 ] = scipy . linalg .eigvalsh (a = inputs [0 ], b = None , lower = self .lower )
624+ w [0 ] = scipy_linalg .eigvalsh (a = inputs [0 ], b = None , lower = self .lower )
625625
626626 def grad (self , inputs , g_outputs ):
627627 a , b = inputs
@@ -675,7 +675,7 @@ def make_node(self, a, b, gw):
675675
676676 def perform (self , node , inputs , outputs ):
677677 (a , b , gw ) = inputs
678- w , v = scipy . linalg .eigh (a , b , lower = self .lower )
678+ w , v = scipy_linalg .eigh (a , b , lower = self .lower )
679679 gA = v .dot (np .diag (gw ).dot (v .T ))
680680 gB = - v .dot (np .diag (gw * w ).dot (v .T ))
681681
@@ -718,7 +718,7 @@ def make_node(self, A):
718718 def perform (self , node , inputs , outputs ):
719719 (A ,) = inputs
720720 (expm ,) = outputs
721- expm [0 ] = scipy . linalg .expm (A )
721+ expm [0 ] = scipy_linalg .expm (A )
722722
723723 def grad (self , inputs , outputs ):
724724 (A ,) = inputs
@@ -758,8 +758,8 @@ def perform(self, node, inputs, outputs):
758758 # this expression.
759759 (A , gA ) = inputs
760760 (out ,) = outputs
761- w , V = scipy . linalg .eig (A , right = True )
762- U = scipy . linalg .inv (V ).T
761+ w , V = scipy_linalg .eig (A , right = True )
762+ U = scipy_linalg .inv (V ).T
763763
764764 exp_w = np .exp (w )
765765 X = np .subtract .outer (exp_w , exp_w ) / np .subtract .outer (w , w )
@@ -800,7 +800,7 @@ def perform(self, node, inputs, output_storage):
800800 X = output_storage [0 ]
801801
802802 out_dtype = node .outputs [0 ].type .dtype
803- X [0 ] = scipy . linalg .solve_continuous_lyapunov (A , B ).astype (out_dtype )
803+ X [0 ] = scipy_linalg .solve_continuous_lyapunov (A , B ).astype (out_dtype )
804804
805805 def infer_shape (self , fgraph , node , shapes ):
806806 return [shapes [0 ]]
@@ -870,7 +870,7 @@ def perform(self, node, inputs, output_storage):
870870 X = output_storage [0 ]
871871
872872 out_dtype = node .outputs [0 ].type .dtype
873- X [0 ] = scipy . linalg .solve_discrete_lyapunov (A , B , method = "bilinear" ).astype (
873+ X [0 ] = scipy_linalg .solve_discrete_lyapunov (A , B , method = "bilinear" ).astype (
874874 out_dtype
875875 )
876876
@@ -992,7 +992,7 @@ def perform(self, node, inputs, output_storage):
992992 Q = 0.5 * (Q + Q .T )
993993
994994 out_dtype = node .outputs [0 ].type .dtype
995- X [0 ] = scipy . linalg .solve_discrete_are (A , B , Q , R ).astype (out_dtype )
995+ X [0 ] = scipy_linalg .solve_discrete_are (A , B , Q , R ).astype (out_dtype )
996996
997997 def infer_shape (self , fgraph , node , shapes ):
998998 return [shapes [0 ]]
@@ -1118,7 +1118,7 @@ def make_node(self, *matrices):
11181118
11191119 def perform (self , node , inputs , output_storage , params = None ):
11201120 dtype = node .outputs [0 ].type .dtype
1121- output_storage [0 ][0 ] = scipy . linalg .block_diag (* inputs ).astype (dtype )
1121+ output_storage [0 ][0 ] = scipy_linalg .block_diag (* inputs ).astype (dtype )
11221122
11231123
11241124def block_diag (* matrices : TensorVariable ):
0 commit comments