1414import numpy as np
1515import scipy .sparse
1616from numpy .lib .stride_tricks import as_strided
17+ from scipy .sparse import issparse , spmatrix
1718
1819import pytensor
1920from pytensor import _as_symbolic , as_symbolic
7071)
7172
7273
73- sparse_formats = ["csc" , "csr" ]
74-
75- """
76- Types of sparse matrices to use for testing.
77-
78- """
79- _mtypes = [scipy .sparse .csc_matrix , scipy .sparse .csr_matrix ]
80- # _mtypes = [sparse.csc_matrix, sparse.csr_matrix, sparse.dok_matrix,
81- # sparse.lil_matrix, sparse.coo_matrix]
82- # * new class ``dia_matrix`` : the sparse DIAgonal format
83- # * new class ``bsr_matrix`` : the Block CSR format
84- _mtype_to_str = {scipy .sparse .csc_matrix : "csc" , scipy .sparse .csr_matrix : "csr" }
85-
86-
8774def _is_sparse_variable (x ):
8875 """
8976
@@ -134,7 +121,7 @@ def _is_dense(x):
134121 L{numpy.ndarray}).
135122
136123 """
137- if not isinstance (x , scipy . sparse . spmatrix | np .ndarray ):
124+ if not isinstance (x , spmatrix | np .ndarray ):
138125 raise NotImplementedError (
139126 "this function should only be called on "
140127 "sparse.scipy.sparse.spmatrix or "
@@ -144,7 +131,7 @@ def _is_dense(x):
144131 return isinstance (x , np .ndarray )
145132
146133
147- @_as_symbolic .register (scipy . sparse . spmatrix )
134+ @_as_symbolic .register (spmatrix )
148135def as_symbolic_sparse (x , ** kwargs ):
149136 return as_sparse_variable (x , ** kwargs )
150137
@@ -198,7 +185,7 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):
198185
199186
200187def constant (x , name = None ):
201- if not isinstance (x , scipy . sparse . spmatrix ):
188+ if not isinstance (x , spmatrix ):
202189 raise TypeError ("sparse.constant must be called on a scipy.sparse.spmatrix" )
203190 try :
204191 return SparseConstant (
@@ -3337,7 +3324,7 @@ def perform(self, node, inp, out_):
33373324 x , y = inp
33383325 (out ,) = out_
33393326 rval = x .dot (y )
3340- if not scipy . sparse . issparse (rval ):
3327+ if not issparse (rval ):
33413328 rval = getattr (scipy .sparse , x .format + "_matrix" )(rval )
33423329 # x.dot call tocsr() that will "upcast" to ['int8', 'uint8', 'short',
33433330 # 'ushort', 'intc', 'uintc', 'longlong', 'ulonglong', 'single',
@@ -3604,7 +3591,7 @@ def perform(self, node, inputs, outputs):
36043591 # the following dot product can result in a scalar or
36053592 # a (1, 1) sparse matrix.
36063593 dot_val = np .dot (g_ab [i ], b [j ].T )
3607- if isinstance (dot_val , scipy . sparse . spmatrix ):
3594+ if isinstance (dot_val , spmatrix ):
36083595 dot_val = dot_val [0 , 0 ]
36093596 g_a_data [i_idx ] = dot_val
36103597 out [0 ] = g_a_data
@@ -3738,7 +3725,7 @@ def perform(self, node, inputs, outputs):
37383725 # the following dot product can result in a scalar or
37393726 # a (1, 1) sparse matrix.
37403727 dot_val = np .dot (g_ab [i ], b [j ].T )
3741- if isinstance (dot_val , scipy . sparse . spmatrix ):
3728+ if isinstance (dot_val , spmatrix ):
37423729 dot_val = dot_val [0 , 0 ]
37433730 g_a_data [j_idx ] = dot_val
37443731 out [0 ] = g_a_data
@@ -3955,9 +3942,9 @@ def make_node(self, x, y):
39553942 # Sparse dot product should have at least one sparse variable
39563943 # as input. If the other one is not sparse, it has to be converted
39573944 # into a tensor.
3958- if isinstance (x , scipy . sparse . spmatrix ):
3945+ if isinstance (x , spmatrix ):
39593946 x = as_sparse_variable (x )
3960- if isinstance (y , scipy . sparse . spmatrix ):
3947+ if isinstance (y , spmatrix ):
39613948 y = as_sparse_variable (y )
39623949
39633950 x_is_sparse_var = _is_sparse_variable (x )
@@ -4147,7 +4134,7 @@ def perform(self, node, inputs, outputs):
41474134 raise TypeError (x )
41484135
41494136 rval = x * y
4150- if isinstance (rval , scipy . sparse . spmatrix ):
4137+ if isinstance (rval , spmatrix ):
41514138 rval = rval .toarray ()
41524139 if rval .dtype == alpha .dtype :
41534140 rval *= alpha # Faster because operation is inplace
0 commit comments