Skip to content

Commit 252f24c

Browse files
authored
Merge branch 'main' into spmv-example
2 parents 2f72240 + eb78737 commit 252f24c

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

examples/mttkrp_example.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import time
2+
3+
import sparse
4+
5+
import numpy as np
6+
7+
I_ = 1000
8+
J_ = 25
9+
K_ = 1000
10+
L_ = 100
11+
DENSITY = 0.0001
12+
ITERS = 3
13+
rng = np.random.default_rng(0)
14+
15+
16+
def benchmark(func, info, args):
17+
print(info)
18+
start = time.time()
19+
for _ in range(ITERS):
20+
func(*args)
21+
elapsed = time.time() - start
22+
print(f"Took {elapsed / ITERS} s.\n")
23+
24+
25+
if __name__ == "__main__":
26+
print("MTTKRP Example:\n")
27+
28+
B_sps = sparse.random((I_, K_, L_), density=DENSITY, random_state=rng) * 10
29+
D_sps = rng.random((L_, J_)) * 10
30+
C_sps = rng.random((K_, J_)) * 10
31+
32+
# Finch
33+
with sparse.Backend(backend=sparse.BackendType.Finch):
34+
B = sparse.asarray(B_sps.todense(), format="csf")
35+
D = sparse.asarray(np.array(D_sps, order="F"))
36+
C = sparse.asarray(np.array(C_sps, order="F"))
37+
38+
@sparse.compiled
39+
def mttkrp_finch(B, D, C):
40+
return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))
41+
42+
# Compile
43+
result_finch = mttkrp_finch(B, D, C)
44+
assert sparse.nonzero(result_finch)[0].size > 5
45+
# Benchmark
46+
benchmark(mttkrp_finch, info="Finch", args=[B, D, C])
47+
48+
# Numba
49+
with sparse.Backend(backend=sparse.BackendType.Numba):
50+
B = sparse.asarray(B_sps, format="gcxs")
51+
D = D_sps
52+
C = C_sps
53+
54+
def mttkrp_numba(B, D, C):
55+
return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))
56+
57+
# Compile
58+
result_numba = mttkrp_numba(B, D, C)
59+
# Benchmark
60+
benchmark(mttkrp_numba, info="Numba", args=[B, D, C])
61+
62+
np.testing.assert_allclose(result_finch.todense(), result_numba.todense())

examples/sddmm_example.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def benchmark(func, info, args):
2121

2222

2323
if __name__ == "__main__":
24+
print("SDDMM Example:\n")
25+
2426
a_sps = rng.random((LEN, LEN - 10)) * 10
2527
b_sps = rng.random((LEN - 10, LEN)) * 10
2628
s_sps = sps.random(LEN, LEN, format="coo", density=DENSITY, random_state=rng) * 10

0 commit comments

Comments
 (0)