Skip to content

Commit da224b3

Browse files
Numba sparse: Implement SparseDenseMultiply
Co-authored-by: Jesse Grabowski <[email protected]>
1 parent eacff38 commit da224b3

File tree

4 files changed

+130
-20
lines changed

4 files changed

+130
-20
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from pytensor.link.numba.dispatch.sparse import basic, variable
1+
from pytensor.link.numba.dispatch.sparse import basic, math, variable
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from pytensor.link.numba.dispatch import basic as numba_basic
2+
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
3+
from pytensor.sparse import SparseDenseMultiply, SparseDenseVectorMultiply
4+
5+
6+
@register_funcify_default_op_cache_key(SparseDenseMultiply)
7+
@register_funcify_default_op_cache_key(SparseDenseVectorMultiply)
8+
def numba_funcify_SparseDenseMultiply(op, node, **kwargs):
9+
x, y = node.inputs
10+
[z] = node.outputs
11+
out_dtype = z.type.dtype
12+
format = z.type.format
13+
same_dtype = x.type.dtype == out_dtype
14+
15+
if y.ndim == 0:
16+
17+
@numba_basic.numba_njit
18+
def sparse_multiply_scalar(x, y):
19+
if same_dtype:
20+
z = x.copy()
21+
else:
22+
z = x.astype(out_dtype)
23+
# Numba doesn't know how to handle in-place mutation / assignment of fields
24+
# z.data *= y
25+
z_data = z.data
26+
z_data *= y
27+
return z
28+
29+
return sparse_multiply_scalar
30+
31+
elif y.ndim == 1:
32+
33+
@numba_basic.numba_njit
34+
def sparse_dense_multiply(x, y):
35+
assert x.shape[1] == y.shape[0]
36+
if same_dtype:
37+
z = x.copy()
38+
else:
39+
z = x.astype(out_dtype)
40+
41+
M, N = x.shape
42+
indices = x.indices
43+
indptr = x.indptr
44+
z_data = z.data
45+
if format == "csc":
46+
for j in range(0, N):
47+
for i_idx in range(indptr[j], indptr[j + 1]):
48+
z_data[i_idx] *= y[j]
49+
return z
50+
51+
else:
52+
for i in range(0, M):
53+
for j_idx in range(indptr[i], indptr[i + 1]):
54+
j = indices[j_idx]
55+
z_data[j_idx] *= y[j]
56+
57+
return z
58+
59+
return sparse_dense_multiply
60+
61+
else: # y.ndim == 2
62+
63+
@numba_basic.numba_njit
64+
def sparse_dense_multiply(x, y):
65+
assert x.shape == y.shape
66+
if same_dtype:
67+
z = x.copy()
68+
else:
69+
z = x.astype(out_dtype)
70+
71+
M, N = x.shape
72+
indices = x.indices
73+
indptr = x.indptr
74+
z_data = z.data
75+
if format == "csc":
76+
for j in range(0, N):
77+
for i_idx in range(indptr[j], indptr[j + 1]):
78+
i = indices[i_idx]
79+
z_data[i_idx] *= y[i, j]
80+
return z
81+
82+
else:
83+
for i in range(0, M):
84+
for j_idx in range(indptr[i], indptr[i + 1]):
85+
j = indices[j_idx]
86+
z_data[j_idx] *= y[i, j]
87+
88+
return z
89+
90+
return sparse_dense_multiply

tests/link/numba/sparse/test_basic.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -194,16 +194,12 @@ def test_sparse_constant(format, cache):
194194

195195
y_test = np.array([np.pi, np.e, np.euler_gamma])
196196
with config.change_flags(numba__cache=cache):
197-
with pytest.warns(
198-
UserWarning,
199-
match=r"Numba will use object mode to run SparseDenseVectorMultiply's perform method",
200-
):
201-
compare_numba_and_py_sparse(
202-
[y],
203-
[out],
204-
[y_test],
205-
eval_obj_mode=False,
206-
)
197+
compare_numba_and_py_sparse(
198+
[y],
199+
[out],
200+
[y_test],
201+
eval_obj_mode=False,
202+
)
207203

208204

209205
@pytest.mark.parametrize("format", ["csc", "csr"])
@@ -248,15 +244,11 @@ def test_simple_graph(format):
248244
x_test = sp.sparse.random(3, 3, density=0.5, format=format, random_state=rng)
249245
y_test = rng.normal(size=(3,))
250246

251-
with pytest.warns(
252-
UserWarning,
253-
match=r"Numba will use object mode to run SparseDenseVectorMultiply's perform method",
254-
):
255-
compare_numba_and_py_sparse(
256-
[x, y],
257-
z,
258-
[x_test, y_test],
259-
)
247+
compare_numba_and_py_sparse(
248+
[x, y],
249+
z,
250+
[x_test, y_test],
251+
)
260252

261253

262254
@pytest.mark.parametrize("format", ("csr", "csc"))
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import numpy as np
2+
import pytest
3+
import scipy
4+
5+
import pytensor.sparse as ps
6+
import pytensor.tensor as pt
7+
from tests.link.numba.sparse.test_basic import compare_numba_and_py_sparse
8+
9+
10+
pytestmark = pytest.mark.filterwarnings("error")
11+
12+
13+
@pytest.mark.parametrize("format", ["csr", "csc"])
14+
@pytest.mark.parametrize("y_ndim", [0, 1, 2])
15+
def test_sparse_dense_multiply(y_ndim, format):
16+
x = ps.matrix(format, name="x", shape=(3, 3))
17+
y = pt.tensor("y", shape=(3,) * y_ndim)
18+
z = x * y
19+
20+
rng = np.random.default_rng((155, y_ndim, format == "csr"))
21+
x_test = scipy.sparse.random(3, 3, density=0.5, format=format, random_state=rng)
22+
y_test = rng.normal(size=(3,) * y_ndim)
23+
24+
compare_numba_and_py_sparse(
25+
[x, y],
26+
z,
27+
[x_test, y_test],
28+
)

0 commit comments

Comments
 (0)