Skip to content

Commit 29f9632

Browse files
authored
test: Adapt benchmarks to use codspeed (#741)
1 parent 0612373 commit 29f9632

File tree

3 files changed

+119
-4
lines changed

3 files changed

+119
-4
lines changed

benchmarks/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import pytest
2+
3+
4+
@pytest.fixture
5+
def seed(scope="session"):
6+
return 42

benchmarks/test_benchmark_coo.py

Lines changed: 110 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,38 @@
88
import numpy as np
99

1010
DENSITY = 0.01
11-
SEED = 42
11+
12+
13+
def side_ids(side):
14+
return f"{side=}"
15+
16+
17+
@pytest.mark.parametrize("side", [100, 500, 1000], ids=side_ids)
18+
def test_matmul(benchmark, side, seed):
19+
if side**2 >= 2**26:
20+
pytest.skip()
21+
rng = np.random.default_rng(seed=seed)
22+
x = sparse.random((side, side), density=DENSITY, random_state=rng)
23+
y = sparse.random((side, side), density=DENSITY, random_state=rng)
24+
25+
x @ y # Numba compilation
26+
27+
@benchmark
28+
def bench():
29+
x @ y
1230

1331

1432
def elemwise_test_name(param):
1533
side, rank = param
1634
return f"{side=}-{rank=}"
1735

1836

19-
@pytest.fixture(scope="module", params=itertools.product([100, 500, 1000], [1, 2, 3, 4]), ids=elemwise_test_name)
20-
def elemwise_args(request):
37+
@pytest.fixture(params=itertools.product([100, 500, 1000], [1, 2, 3, 4]), ids=elemwise_test_name)
38+
def elemwise_args(request, seed):
2139
side, rank = request.param
2240
if side**rank >= 2**26:
2341
pytest.skip()
24-
rng = np.random.default_rng(seed=SEED)
42+
rng = np.random.default_rng(seed=seed)
2543
shape = (side,) * rank
2644
x = sparse.random(shape, density=DENSITY, random_state=rng)
2745
y = sparse.random(shape, density=DENSITY, random_state=rng)
@@ -36,3 +54,91 @@ def test_elemwise(benchmark, f, elemwise_args):
3654
@benchmark
3755
def bench():
3856
f(x, y)
57+
58+
59+
@pytest.fixture(params=[100, 500, 1000], ids=side_ids)
60+
def elemwise_broadcast_args(request, seed):
61+
side = request.param
62+
if side**2 >= 2**26:
63+
pytest.skip()
64+
rng = np.random.default_rng(seed=seed)
65+
x = sparse.random((side, 1, side), density=DENSITY, random_state=rng)
66+
y = sparse.random((side, side), density=DENSITY, random_state=rng)
67+
return x, y
68+
69+
70+
@pytest.mark.parametrize("f", [operator.add, operator.mul])
71+
def test_elemwise_broadcast(benchmark, f, elemwise_broadcast_args):
72+
x, y = elemwise_broadcast_args
73+
f(x, y)
74+
75+
@benchmark
76+
def bench():
77+
f(x, y)
78+
79+
80+
@pytest.fixture(params=[100, 500, 1000], ids=side_ids)
81+
def indexing_args(request, seed):
82+
side = request.param
83+
if side**3 >= 2**26:
84+
pytest.skip()
85+
rng = np.random.default_rng(seed=seed)
86+
87+
return sparse.random((side, side, side), density=DENSITY, random_state=rng)
88+
89+
90+
def test_index_scalar(benchmark, indexing_args):
91+
x = indexing_args
92+
side = x.shape[0]
93+
94+
x[side // 2, side // 2, side // 2] # Numba compilation
95+
96+
@benchmark
97+
def bench():
98+
x[side // 2, side // 2, side // 2]
99+
100+
101+
def test_index_slice(benchmark, indexing_args):
102+
x = indexing_args
103+
side = x.shape[0]
104+
105+
x[: side // 2] # Numba compilation
106+
107+
@benchmark
108+
def bench():
109+
x[: side // 2]
110+
111+
112+
def test_index_slice2(benchmark, indexing_args):
113+
x = indexing_args
114+
side = x.shape[0]
115+
116+
x[: side // 2, : side // 2] # Numba compilation
117+
118+
@benchmark
119+
def bench():
120+
x[: side // 2, : side // 2]
121+
122+
123+
def test_index_slice3(benchmark, indexing_args):
124+
x = indexing_args
125+
side = x.shape[0]
126+
127+
x[: side // 2, : side // 2, : side // 2] # Numba compilation
128+
129+
@benchmark
130+
def bench():
131+
x[: side // 2, : side // 2, : side // 2]
132+
133+
134+
def test_index_fancy(benchmark, indexing_args, seed):
135+
x = indexing_args
136+
side = x.shape[0]
137+
rng = np.random.default_rng(seed=seed)
138+
index = rng.integers(0, side, side // 2)
139+
140+
x[index] # Numba compilation
141+
142+
@benchmark
143+
def bench():
144+
x[index]

benchmarks/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import os
2+
3+
CI_MODE = bool(int(os.getenv("CI_MODE", default="0")))

0 commit comments

Comments
 (0)