Skip to content

Commit 948b079

Browse files
committed
Use more specialized scipy sparse imports
1 parent 065d5c2 commit 948b079

File tree

7 files changed

+77
-94
lines changed

7 files changed

+77
-94
lines changed

pytensor/sparse/basic.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import numpy as np
1515
import scipy.sparse
1616
from numpy.lib.stride_tricks import as_strided
17+
from scipy.sparse import issparse, spmatrix
1718

1819
import pytensor
1920
from pytensor import _as_symbolic, as_symbolic
@@ -70,20 +71,6 @@
7071
)
7172

7273

73-
sparse_formats = ["csc", "csr"]
74-
75-
"""
76-
Types of sparse matrices to use for testing.
77-
78-
"""
79-
_mtypes = [scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]
80-
# _mtypes = [sparse.csc_matrix, sparse.csr_matrix, sparse.dok_matrix,
81-
# sparse.lil_matrix, sparse.coo_matrix]
82-
# * new class ``dia_matrix`` : the sparse DIAgonal format
83-
# * new class ``bsr_matrix`` : the Block CSR format
84-
_mtype_to_str = {scipy.sparse.csc_matrix: "csc", scipy.sparse.csr_matrix: "csr"}
85-
86-
8774
def _is_sparse_variable(x):
8875
"""
8976
@@ -134,7 +121,7 @@ def _is_dense(x):
134121
L{numpy.ndarray}).
135122
136123
"""
137-
if not isinstance(x, scipy.sparse.spmatrix | np.ndarray):
124+
if not isinstance(x, spmatrix | np.ndarray):
138125
raise NotImplementedError(
139126
"this function should only be called on "
140127
"sparse.scipy.sparse.spmatrix or "
@@ -144,7 +131,7 @@ def _is_dense(x):
144131
return isinstance(x, np.ndarray)
145132

146133

147-
@_as_symbolic.register(scipy.sparse.spmatrix)
134+
@_as_symbolic.register(spmatrix)
148135
def as_symbolic_sparse(x, **kwargs):
149136
return as_sparse_variable(x, **kwargs)
150137

@@ -198,7 +185,7 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):
198185

199186

200187
def constant(x, name=None):
201-
if not isinstance(x, scipy.sparse.spmatrix):
188+
if not isinstance(x, spmatrix):
202189
raise TypeError("sparse.constant must be called on a scipy.sparse.spmatrix")
203190
try:
204191
return SparseConstant(
@@ -3337,7 +3324,7 @@ def perform(self, node, inp, out_):
33373324
x, y = inp
33383325
(out,) = out_
33393326
rval = x.dot(y)
3340-
if not scipy.sparse.issparse(rval):
3327+
if not issparse(rval):
33413328
rval = getattr(scipy.sparse, x.format + "_matrix")(rval)
33423329
# x.dot call tocsr() that will "upcast" to ['int8', 'uint8', 'short',
33433330
# 'ushort', 'intc', 'uintc', 'longlong', 'ulonglong', 'single',
@@ -3604,7 +3591,7 @@ def perform(self, node, inputs, outputs):
36043591
# the following dot product can result in a scalar or
36053592
# a (1, 1) sparse matrix.
36063593
dot_val = np.dot(g_ab[i], b[j].T)
3607-
if isinstance(dot_val, scipy.sparse.spmatrix):
3594+
if isinstance(dot_val, spmatrix):
36083595
dot_val = dot_val[0, 0]
36093596
g_a_data[i_idx] = dot_val
36103597
out[0] = g_a_data
@@ -3738,7 +3725,7 @@ def perform(self, node, inputs, outputs):
37383725
# the following dot product can result in a scalar or
37393726
# a (1, 1) sparse matrix.
37403727
dot_val = np.dot(g_ab[i], b[j].T)
3741-
if isinstance(dot_val, scipy.sparse.spmatrix):
3728+
if isinstance(dot_val, spmatrix):
37423729
dot_val = dot_val[0, 0]
37433730
g_a_data[j_idx] = dot_val
37443731
out[0] = g_a_data
@@ -3955,9 +3942,9 @@ def make_node(self, x, y):
39553942
# Sparse dot product should have at least one sparse variable
39563943
# as input. If the other one is not sparse, it has to be converted
39573944
# into a tensor.
3958-
if isinstance(x, scipy.sparse.spmatrix):
3945+
if isinstance(x, spmatrix):
39593946
x = as_sparse_variable(x)
3960-
if isinstance(y, scipy.sparse.spmatrix):
3947+
if isinstance(y, spmatrix):
39613948
y = as_sparse_variable(y)
39623949

39633950
x_is_sparse_var = _is_sparse_variable(x)
@@ -4147,7 +4134,7 @@ def perform(self, node, inputs, outputs):
41474134
raise TypeError(x)
41484135

41494136
rval = x * y
4150-
if isinstance(rval, scipy.sparse.spmatrix):
4137+
if isinstance(rval, spmatrix):
41514138
rval = rval.toarray()
41524139
if rval.dtype == alpha.dtype:
41534140
rval *= alpha # Faster because operation is inplace

pytensor/sparse/rewriting.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
import scipy
2+
from scipy.sparse import csc_matrix, csr_matrix
33

44
import pytensor
55
import pytensor.scalar as ps
@@ -279,9 +279,7 @@ def make_node(self, a_val, a_ind, a_ptr, a_nrows, b):
279279
def perform(self, node, inputs, outputs):
280280
(a_val, a_ind, a_ptr, a_nrows, b) = inputs
281281
(out,) = outputs
282-
a = scipy.sparse.csc_matrix(
283-
(a_val, a_ind, a_ptr), (a_nrows, b.shape[0]), copy=False
284-
)
282+
a = csc_matrix((a_val, a_ind, a_ptr), (a_nrows, b.shape[0]), copy=False)
285283
# out[0] = a.dot(b)
286284
out[0] = np.asarray(a * b, dtype=node.outputs[0].type.dtype)
287285
assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense
@@ -478,7 +476,7 @@ def make_node(self, a_val, a_ind, a_ptr, b):
478476
def perform(self, node, inputs, outputs):
479477
(a_val, a_ind, a_ptr, b) = inputs
480478
(out,) = outputs
481-
a = scipy.sparse.csr_matrix(
479+
a = csr_matrix(
482480
(a_val, a_ind, a_ptr), (len(a_ptr) - 1, b.shape[0]), copy=True
483481
) # use view_map before setting this to False
484482
# out[0] = a.dot(b)

pytensor/sparse/sharedvar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22

3-
import scipy.sparse
3+
from scipy.sparse import spmatrix
44

55
from pytensor.compile import shared_constructor
66
from pytensor.sparse.basic import SparseTensorType, SparseVariable
@@ -13,7 +13,7 @@ def format(self):
1313
return self.type.format
1414

1515

16-
@shared_constructor.register(scipy.sparse.spmatrix)
16+
@shared_constructor.register(spmatrix)
1717
def sparse_constructor(
1818
value, name=None, strict=False, allow_downcast=None, borrow=False, format=None
1919
):

pytensor/sparse/type.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Literal
33

44
import numpy as np
5-
import scipy.sparse
5+
from scipy.sparse import bsr_matrix, csc_matrix, csr_matrix, issparse, spmatrix
66

77
import pytensor
88
from pytensor import scalar as ps
@@ -23,14 +23,14 @@ def _is_sparse(x):
2323
True iff x is a L{scipy.sparse.spmatrix} (and not a L{numpy.ndarray}).
2424
2525
"""
26-
if not isinstance(x, scipy.sparse.spmatrix | np.ndarray | tuple | list):
26+
if not isinstance(x, spmatrix | np.ndarray | tuple | list):
2727
raise NotImplementedError(
2828
"this function should only be called on "
2929
"sparse.scipy.sparse.spmatrix or "
3030
"numpy.ndarray, not,",
3131
x,
3232
)
33-
return isinstance(x, scipy.sparse.spmatrix)
33+
return isinstance(x, spmatrix)
3434

3535

3636
class SparseTensorType(TensorType, HasDataType):
@@ -44,9 +44,9 @@ class SparseTensorType(TensorType, HasDataType):
4444

4545
__props__ = ("dtype", "format", "shape")
4646
format_cls = {
47-
"csr": scipy.sparse.csr_matrix,
48-
"csc": scipy.sparse.csc_matrix,
49-
"bsr": scipy.sparse.bsr_matrix,
47+
"csr": csr_matrix,
48+
"csc": csc_matrix,
49+
"bsr": bsr_matrix,
5050
}
5151
dtype_specs_map = {
5252
"float32": (float, "npy_float32", "NPY_FLOAT32"),
@@ -187,7 +187,7 @@ def values_eq_approx(self, a, b, eps=1e-6):
187187
# WARNING: equality comparison of sparse matrices is not fast or easy
188188
# we definitely do not want to be doing this un-necessarily during
189189
# a FAST_RUN computation..
190-
if not (scipy.sparse.issparse(a) and scipy.sparse.issparse(b)):
190+
if not (issparse(a) and issparse(b)):
191191
return False
192192
diff = abs(a - b)
193193
if diff.nnz == 0:
@@ -203,14 +203,10 @@ def values_eq(self, a, b):
203203
# WARNING: equality comparison of sparse matrices is not fast or easy
204204
# we definitely do not want to be doing this un-necessarily during
205205
# a FAST_RUN computation..
206-
return (
207-
scipy.sparse.issparse(a)
208-
and scipy.sparse.issparse(b)
209-
and abs(a - b).sum() == 0.0
210-
)
206+
return issparse(a) and issparse(b) and abs(a - b).sum() == 0.0
211207

212208
def is_valid_value(self, a):
213-
return scipy.sparse.issparse(a) and (a.format == self.format)
209+
return issparse(a) and (a.format == self.format)
214210

215211
def get_shape_info(self, obj):
216212
obj = self.filter(obj)

0 commit comments

Comments
 (0)