Skip to content

Commit c0a3778

Browse files
authored
Add elemwise examples (#697)
1 parent 4195756 commit c0a3778

File tree

6 files changed

+98
-42
lines changed

6 files changed

+98
-42
lines changed

examples/__init__.py

Whitespace-only changes.

examples/elemwise_example.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import importlib
2+
import operator
3+
import os
4+
5+
import sparse
6+
7+
from utils import benchmark
8+
9+
import numpy as np
10+
import scipy.sparse as sps
11+
12+
LEN = 10000
13+
DENSITY = 0.001
14+
ITERS = 3
15+
rng = np.random.default_rng(0)
16+
17+
18+
if __name__ == "__main__":
19+
print("Elementwise Example:\n")
20+
21+
for func_name in ["multiply", "add", "greater_equal"]:
22+
print(f"{func_name} benchmark:\n")
23+
24+
s1_sps = sps.random(LEN, LEN, format="csr", density=DENSITY, random_state=rng) * 10
25+
s1_sps.sum_duplicates()
26+
s2_sps = sps.random(LEN, LEN, format="csr", density=DENSITY, random_state=rng) * 10
27+
s2_sps.sum_duplicates()
28+
29+
# ======= Finch =======
30+
os.environ[sparse._ENV_VAR_NAME] = "Finch"
31+
importlib.reload(sparse)
32+
33+
s1 = sparse.asarray(s1_sps.asformat("csc"), format="csc")
34+
s2 = sparse.asarray(s2_sps.asformat("csc"), format="csc")
35+
36+
func = getattr(sparse, func_name)
37+
# Compile
38+
result_finch = func(s1, s2)
39+
# Benchmark
40+
benchmark(func, args=[s1, s2], info="Finch", iters=ITERS)
41+
42+
# ======= Numba =======
43+
os.environ[sparse._ENV_VAR_NAME] = "Numba"
44+
importlib.reload(sparse)
45+
46+
s1 = sparse.asarray(s1_sps)
47+
s2 = sparse.asarray(s2_sps)
48+
49+
func = getattr(sparse, func_name)
50+
# Compile
51+
result_numba = func(s1, s2)
52+
# Benchmark
53+
benchmark(func, args=[s1, s2], info="Numba", iters=ITERS)
54+
55+
# ======= SciPy =======
56+
s1 = s1_sps
57+
s2 = s2_sps
58+
59+
if func_name == "multiply":
60+
func, args = s1.multiply, [s2]
61+
elif func_name == "add":
62+
func, args = operator.add, [s1, s2]
63+
elif func_name == "greater_equal":
64+
func, args = operator.ge, [s1, s2]
65+
66+
result_scipy = func(*args)
67+
# Benchmark
68+
benchmark(func, args=args, info="SciPy", iters=ITERS)
69+
70+
np.testing.assert_allclose(result_numba.todense(), result_scipy.toarray())
71+
np.testing.assert_allclose(result_finch.todense(), result_numba.todense())
72+
np.testing.assert_allclose(result_finch.todense(), result_scipy.toarray())

examples/mttkrp_example.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import importlib
22
import os
3-
import time
43

54
import sparse
65

6+
from utils import benchmark
7+
78
import numpy as np
89

910
I_ = 1000
@@ -15,15 +16,6 @@
1516
rng = np.random.default_rng(0)
1617

1718

18-
def benchmark(func, info, args):
19-
print(info)
20-
start = time.time()
21-
for _ in range(ITERS):
22-
func(*args)
23-
elapsed = time.time() - start
24-
print(f"Took {elapsed / ITERS} s.\n")
25-
26-
2719
if __name__ == "__main__":
2820
print("MTTKRP Example:\n")
2921

@@ -45,9 +37,8 @@ def mttkrp_finch(B, D, C):
4537

4638
# Compile
4739
result_finch = mttkrp_finch(B, D, C)
48-
assert sparse.nonzero(result_finch)[0].size > 5
4940
# Benchmark
50-
benchmark(mttkrp_finch, info="Finch", args=[B, D, C])
41+
benchmark(mttkrp_finch, args=[B, D, C], info="Finch", iters=ITERS)
5142

5243
# ======= Numba =======
5344
os.environ[sparse._ENV_VAR_NAME] = "Numba"
@@ -63,6 +54,6 @@ def mttkrp_numba(B, D, C):
6354
# Compile
6455
result_numba = mttkrp_numba(B, D, C)
6556
# Benchmark
66-
benchmark(mttkrp_numba, info="Numba", args=[B, D, C])
57+
benchmark(mttkrp_numba, args=[B, D, C], info="Numba", iters=ITERS)
6758

6859
np.testing.assert_allclose(result_finch.todense(), result_numba.todense())

examples/sddmm_example.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import importlib
22
import os
3-
import time
43

54
import sparse
65

6+
from utils import benchmark
7+
78
import numpy as np
89
import scipy.sparse as sps
910

@@ -13,15 +14,6 @@
1314
rng = np.random.default_rng(0)
1415

1516

16-
def benchmark(func, info, args):
17-
print(info)
18-
start = time.time()
19-
for _ in range(ITERS):
20-
func(*args)
21-
elapsed = time.time() - start
22-
print(f"Took {elapsed / ITERS} s.\n")
23-
24-
2517
if __name__ == "__main__":
2618
print("SDDMM Example:\n")
2719

@@ -47,9 +39,8 @@ def sddmm_finch(s, a, b):
4739

4840
# Compile
4941
result_finch = sddmm_finch(s, a, b)
50-
assert sparse.nonzero(result_finch)[0].size > 5
5142
# Benchmark
52-
benchmark(sddmm_finch, info="Finch", args=[s, a, b])
43+
benchmark(sddmm_finch, args=[s, a, b], info="Finch", iters=ITERS)
5344

5445
# ======= Numba =======
5546
os.environ[sparse._ENV_VAR_NAME] = "Numba"
@@ -64,9 +55,8 @@ def sddmm_numba(s, a, b):
6455

6556
# Compile
6657
result_numba = sddmm_numba(s, a, b)
67-
assert sparse.nonzero(result_numba)[0].size > 5
6858
# Benchmark
69-
benchmark(sddmm_numba, info="Numba", args=[s, a, b])
59+
benchmark(sddmm_numba, args=[s, a, b], info="Numba", iters=ITERS)
7060

7161
# ======= SciPy =======
7262
def sddmm_scipy(s, a, b):
@@ -78,7 +68,7 @@ def sddmm_scipy(s, a, b):
7868

7969
result_scipy = sddmm_scipy(s, a, b)
8070
# Benchmark
81-
benchmark(sddmm_scipy, info="SciPy", args=[s, a, b])
71+
benchmark(sddmm_scipy, args=[s, a, b], info="SciPy", iters=ITERS)
8272

8373
np.testing.assert_allclose(result_numba.todense(), result_scipy.toarray())
8474
np.testing.assert_allclose(result_finch.todense(), result_numba.todense())

examples/spmv_add_example.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import importlib
22
import os
3-
import time
43

54
import sparse
65

6+
from utils import benchmark
7+
78
import numpy as np
89
import scipy.sparse as sps
910

@@ -13,15 +14,6 @@
1314
rng = np.random.default_rng(0)
1415

1516

16-
def benchmark(func, info, args):
17-
print(info)
18-
start = time.time()
19-
for _ in range(ITERS):
20-
func(*args)
21-
elapsed = time.time() - start
22-
print(f"Took {elapsed / ITERS} s.\n")
23-
24-
2517
if __name__ == "__main__":
2618
print("SpMv_add Example:\n")
2719

@@ -43,9 +35,8 @@ def spmv_finch(A, x, y):
4335

4436
# Compile
4537
result_finch = spmv_finch(A, x, y)
46-
assert sparse.nonzero(result_finch)[0].size > 5
4738
# Benchmark
48-
benchmark(spmv_finch, info="Finch", args=[A, x, y])
39+
benchmark(spmv_finch, args=[A, x, y], info="Finch", iters=ITERS)
4940

5041
# ======= Numba =======
5142
os.environ[sparse._ENV_VAR_NAME] = "Numba"
@@ -62,7 +53,7 @@ def spmv_numba(A, x, y):
6253
result_numba = spmv_numba(A, x, y)
6354
assert sparse.nonzero(result_numba)[0].size > 5
6455
# Benchmark
65-
benchmark(spmv_numba, info="Numba", args=[A, x, y])
56+
benchmark(spmv_numba, args=[A, x, y], info="Numba", iters=ITERS)
6657

6758
# ======= SciPy =======
6859
def spmv_scipy(A, x, y):
@@ -74,7 +65,7 @@ def spmv_scipy(A, x, y):
7465

7566
result_scipy = spmv_scipy(A, x, y)
7667
# Benchmark
77-
benchmark(spmv_scipy, info="SciPy", args=[A, x, y])
68+
benchmark(spmv_scipy, args=[A, x, y], info="SciPy", iters=ITERS)
7869

7970
np.testing.assert_allclose(result_numba, result_scipy)
8071
np.testing.assert_allclose(result_finch.todense(), result_numba)

examples/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import time
2+
from collections.abc import Callable, Iterable
3+
from typing import Any
4+
5+
6+
def benchmark(func: Callable, args: Iterable[Any], info: str, iters: int):
7+
print(info)
8+
start = time.time()
9+
for _ in range(iters):
10+
func(*args)
11+
elapsed = time.time() - start
12+
print(f"Took {elapsed / iters} s.\n")

0 commit comments

Comments
 (0)