Skip to content

Commit 7906105

Browse files
committed
Apply review comments
1 parent 42486d2 commit 7906105

File tree

8 files changed

+46
-58
lines changed

8 files changed

+46
-58
lines changed

ci/test_examples.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
for example in $(find ./examples/ -iname *.py); do
2-
CI_MODE=True python $example
2+
CI_MODE=1 python $example
33
done

ci/test_notebooks.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
CI_MODE=True pytest -n 4 --nbmake --nbmake-timeout=600 ./examples/*.ipynb
1+
CI_MODE=1 pytest -n 4 --nbmake --nbmake-timeout=600 ./examples/*.ipynb

examples/elemwise_example.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@
3434
s2 = sparse.asarray(s2_sps.asformat("csc"), format="csc")
3535

3636
func = getattr(sparse, func_name)
37-
# Compile
38-
result_finch = func(s1, s2)
39-
# Benchmark
40-
benchmark(func, args=[s1, s2], info="Finch", iters=ITERS)
37+
38+
# Compile & Benchmark
39+
result_finch = benchmark(func, args=[s1, s2], info="Finch", iters=ITERS)
4140

4241
# ======= Numba =======
4342
os.environ[sparse._ENV_VAR_NAME] = "Numba"
@@ -47,10 +46,9 @@
4746
s2 = sparse.asarray(s2_sps)
4847

4948
func = getattr(sparse, func_name)
50-
# Compile
51-
result_numba = func(s1, s2)
52-
# Benchmark
53-
benchmark(func, args=[s1, s2], info="Numba", iters=ITERS)
49+
50+
# Compile & Benchmark
51+
result_numba = benchmark(func, args=[s1, s2], info="Numba", iters=ITERS)
5452

5553
# ======= SciPy =======
5654
s1 = s1_sps
@@ -63,9 +61,8 @@
6361
elif func_name == "greater_equal":
6462
func, args = operator.ge, [s1, s2]
6563

66-
result_scipy = func(*args)
67-
# Benchmark
68-
benchmark(func, args=args, info="SciPy", iters=ITERS)
64+
# Compile & Benchmark
65+
result_scipy = benchmark(func, args=args, info="SciPy", iters=ITERS)
6966

7067
np.testing.assert_allclose(result_numba.todense(), result_scipy.toarray())
7168
np.testing.assert_allclose(result_finch.todense(), result_numba.todense())

examples/matmul_example.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,8 @@
3333
def sddmm_finch(a, b):
3434
return a @ b
3535

36-
# Compile
37-
result_finch = sddmm_finch(a, b)
38-
# Benchmark
39-
benchmark(sddmm_finch, args=[a, b], info="Finch", iters=ITERS)
36+
# Compile & Benchmark
37+
result_finch = benchmark(sddmm_finch, args=[a, b], info="Finch", iters=ITERS)
4038

4139
# ======= Numba =======
4240
os.environ[sparse._ENV_VAR_NAME] = "Numba"
@@ -48,10 +46,8 @@ def sddmm_finch(a, b):
4846
def sddmm_numba(a, b):
4947
return a @ b
5048

51-
# Compile
52-
result_numba = sddmm_numba(a, b)
53-
# Benchmark
54-
benchmark(sddmm_numba, args=[a, b], info="Numba", iters=ITERS)
49+
# Compile & Benchmark
50+
result_numba = benchmark(sddmm_numba, args=[a, b], info="Numba", iters=ITERS)
5551

5652
# ======= SciPy =======
5753
def sddmm_scipy(a, b):
@@ -60,9 +56,8 @@ def sddmm_scipy(a, b):
6056
a = a_sps
6157
b = b_sps
6258

63-
result_scipy = sddmm_scipy(a, b)
64-
# Benchmark
65-
benchmark(sddmm_scipy, args=[a, b], info="SciPy", iters=ITERS)
59+
# Compile & Benchmark
60+
result_scipy = benchmark(sddmm_scipy, args=[a, b], info="SciPy", iters=ITERS)
6661

6762
# np.testing.assert_allclose(result_numba.todense(), result_scipy.toarray())
6863
# np.testing.assert_allclose(result_finch.todense(), result_numba.todense())

examples/mttkrp_example.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@
3535
def mttkrp_finch(B, D, C):
3636
return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))
3737

38-
# Compile
39-
result_finch = mttkrp_finch(B, D, C)
40-
# Benchmark
41-
benchmark(mttkrp_finch, args=[B, D, C], info="Finch", iters=ITERS)
38+
# Compile & Benchmark
39+
result_finch = benchmark(mttkrp_finch, args=[B, D, C], info="Finch", iters=ITERS)
4240

4341
# ======= Numba =======
4442
os.environ[sparse._ENV_VAR_NAME] = "Numba"
@@ -51,9 +49,7 @@ def mttkrp_finch(B, D, C):
5149
def mttkrp_numba(B, D, C):
5250
return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))
5351

54-
# Compile
55-
result_numba = mttkrp_numba(B, D, C)
56-
# Benchmark
57-
benchmark(mttkrp_numba, args=[B, D, C], info="Numba", iters=ITERS)
52+
# Compile & Benchmark
53+
result_numba = benchmark(mttkrp_numba, args=[B, D, C], info="Numba", iters=ITERS)
5854

5955
np.testing.assert_allclose(result_finch.todense(), result_numba.todense())

examples/sddmm_example.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,8 @@ def sddmm_finch(s, a, b):
3737
axis=-1,
3838
)
3939

40-
# Compile
41-
result_finch = sddmm_finch(s, a, b)
42-
# Benchmark
43-
benchmark(sddmm_finch, args=[s, a, b], info="Finch", iters=ITERS)
40+
# Compile & Benchmark
41+
result_finch = benchmark(sddmm_finch, args=[s, a, b], info="Finch", iters=ITERS)
4442

4543
# ======= Numba =======
4644
os.environ[sparse._ENV_VAR_NAME] = "Numba"
@@ -53,10 +51,8 @@ def sddmm_finch(s, a, b):
5351
def sddmm_numba(s, a, b):
5452
return s * (a @ b)
5553

56-
# Compile
57-
result_numba = sddmm_numba(s, a, b)
58-
# Benchmark
59-
benchmark(sddmm_numba, args=[s, a, b], info="Numba", iters=ITERS)
54+
# Compile & Benchmark
55+
result_numba = benchmark(sddmm_numba, args=[s, a, b], info="Numba", iters=ITERS)
6056

6157
# ======= SciPy =======
6258
def sddmm_scipy(s, a, b):
@@ -66,9 +62,8 @@ def sddmm_scipy(s, a, b):
6662
a = a_sps
6763
b = b_sps
6864

69-
result_scipy = sddmm_scipy(s, a, b)
70-
# Benchmark
71-
benchmark(sddmm_scipy, args=[s, a, b], info="SciPy", iters=ITERS)
65+
# Compile & Benchmark
66+
result_scipy = benchmark(sddmm_scipy, args=[s, a, b], info="SciPy", iters=ITERS)
7267

7368
np.testing.assert_allclose(result_numba.todense(), result_scipy.toarray())
7469
np.testing.assert_allclose(result_finch.todense(), result_numba.todense())

examples/spmv_add_example.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,8 @@
3333
def spmv_finch(A, x, y):
3434
return sparse.sum(A[:, None, :] * sparse.permute_dims(x, (1, 0))[None, :, :], axis=-1) + y
3535

36-
# Compile
37-
result_finch = spmv_finch(A, x, y)
38-
# Benchmark
39-
benchmark(spmv_finch, args=[A, x, y], info="Finch", iters=ITERS)
36+
# Compile & Benchmark
37+
result_finch = benchmark(spmv_finch, args=[A, x, y], info="Finch", iters=ITERS)
4038

4139
# ======= Numba =======
4240
os.environ[sparse._ENV_VAR_NAME] = "Numba"
@@ -49,11 +47,8 @@ def spmv_finch(A, x, y):
4947
def spmv_numba(A, x, y):
5048
return A @ x + y
5149

52-
# Compile
53-
result_numba = spmv_numba(A, x, y)
54-
assert sparse.nonzero(result_numba)[0].size > 5
55-
# Benchmark
56-
benchmark(spmv_numba, args=[A, x, y], info="Numba", iters=ITERS)
50+
# Compile & Benchmark
51+
result_numba = benchmark(spmv_numba, args=[A, x, y], info="Numba", iters=ITERS)
5752

5853
# ======= SciPy =======
5954
def spmv_scipy(A, x, y):
@@ -63,9 +58,8 @@ def spmv_scipy(A, x, y):
6358
x = x_sps
6459
y = y_sps
6560

66-
result_scipy = spmv_scipy(A, x, y)
67-
# Benchmark
68-
benchmark(spmv_scipy, args=[A, x, y], info="SciPy", iters=ITERS)
61+
# Compile & Benchmark
62+
result_scipy = benchmark(spmv_scipy, args=[A, x, y], info="SciPy", iters=ITERS)
6963

7064
np.testing.assert_allclose(result_numba, result_scipy)
7165
np.testing.assert_allclose(result_finch.todense(), result_numba)

examples/utils.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,28 @@
33
from collections.abc import Callable, Iterable
44
from typing import Any
55

6-
CI_MODE = os.getenv("CI_MODE", default=False)
6+
CI_MODE = bool(int(os.getenv("CI_MODE", default="0")))
77

88

9-
def benchmark(func: Callable, args: Iterable[Any], info: str, iters: int):
9+
def benchmark(
10+
func: Callable,
11+
args: Iterable[Any],
12+
info: str,
13+
iters: int,
14+
) -> object:
15+
# Compile
16+
result = func(*args)
17+
1018
if CI_MODE:
1119
print("CI mode - skipping benchmark")
12-
return
20+
return result
1321

22+
# Benchmark
1423
print(info)
1524
start = time.time()
1625
for _ in range(iters):
1726
func(*args)
1827
elapsed = time.time() - start
1928
print(f"Took {elapsed / iters} s.\n")
29+
30+
return result

0 commit comments

Comments
 (0)