Skip to content

Commit a9ef0ea

Browse files
authored
Merge pull request #677 from pydata/spmv-example
Add `SpMV` example
2 parents eb78737 + 26118f7 commit a9ef0ea

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

examples/spmv_add_example.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import time
2+
3+
import sparse
4+
5+
import numpy as np
6+
import scipy.sparse as sps
7+
8+
LEN = 100000
9+
DENSITY = 0.000001
10+
ITERS = 3
11+
rng = np.random.default_rng(0)
12+
13+
14+
def benchmark(func, info, args):
15+
print(info)
16+
start = time.time()
17+
for _ in range(ITERS):
18+
func(*args)
19+
elapsed = time.time() - start
20+
print(f"Took {elapsed / ITERS} s.\n")
21+
22+
23+
if __name__ == "__main__":
24+
print("SpMv_add Example:\n")
25+
26+
A_sps = sps.random(LEN - 10, LEN, format="csc", density=DENSITY, random_state=rng) * 10
27+
x_sps = rng.random((LEN, 1)) * 10
28+
y_sps = rng.random((LEN - 10, 1)) * 10
29+
30+
# Finch
31+
with sparse.Backend(backend=sparse.BackendType.Finch):
32+
A = sparse.asarray(A_sps)
33+
x = sparse.asarray(np.array(x_sps, order="C"))
34+
y = sparse.asarray(np.array(y_sps, order="C"))
35+
36+
@sparse.compiled
37+
def spmv_finch(A, x, y):
38+
return sparse.sum(A[:, None, :] * sparse.permute_dims(x, (1, 0))[None, :, :], axis=-1) + y
39+
40+
# Compile
41+
result_finch = spmv_finch(A, x, y)
42+
assert sparse.nonzero(result_finch)[0].size > 5
43+
# Benchmark
44+
benchmark(spmv_finch, info="Finch", args=[A, x, y])
45+
46+
# Numba
47+
with sparse.Backend(backend=sparse.BackendType.Numba):
48+
A = sparse.asarray(A_sps, format="csc")
49+
x = x_sps
50+
y = y_sps
51+
52+
def spmv_numba(A, x, y):
53+
return A @ x + y
54+
55+
# Compile
56+
result_numba = spmv_numba(A, x, y)
57+
assert sparse.nonzero(result_numba)[0].size > 5
58+
# Benchmark
59+
benchmark(spmv_numba, info="Numba", args=[A, x, y])
60+
61+
# SciPy
62+
def spmv_scipy(A, x, y):
63+
return A @ x + y
64+
65+
A = A_sps
66+
x = x_sps
67+
y = y_sps
68+
69+
result_scipy = spmv_scipy(A, x, y)
70+
# Benchmark
71+
benchmark(spmv_scipy, info="SciPy", args=[A, x, y])
72+
73+
np.testing.assert_allclose(result_numba, result_scipy)
74+
np.testing.assert_allclose(result_finch.todense(), result_numba)
75+
np.testing.assert_allclose(result_finch.todense(), result_scipy)

0 commit comments

Comments
 (0)