Skip to content

Commit 847023e

Browse files
authored
Merge pull request #743 from DeaMariaLeon/bench3
test: Modified index and slice tests in order to vary ranks
2 parents 45cdc71 + ec4dc36 commit 847023e

File tree

2 files changed

+26
-40
lines changed

2 files changed

+26
-40
lines changed

benchmarks/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,8 @@
44
@pytest.fixture
55
def seed(scope="session"):
66
return 42
7+
8+
9+
@pytest.fixture
10+
def max_size(scope="session"):
11+
return 2**26

benchmarks/test_benchmark_coo.py

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def side_ids(side):
1515

1616

1717
@pytest.mark.parametrize("side", [100, 500, 1000], ids=side_ids)
18-
def test_matmul(benchmark, side, seed):
19-
if side**2 >= 2**26:
18+
def test_matmul(benchmark, side, seed, max_size):
19+
if side**2 >= max_size:
2020
pytest.skip()
2121
rng = np.random.default_rng(seed=seed)
2222
x = sparse.random((side, side), density=DENSITY, random_state=rng)
@@ -29,15 +29,15 @@ def bench():
2929
x @ y
3030

3131

32-
def elemwise_test_name(param):
32+
def get_test_id(param):
3333
side, rank = param
3434
return f"{side=}-{rank=}"
3535

3636

37-
@pytest.fixture(params=itertools.product([100, 500, 1000], [1, 2, 3, 4]), ids=elemwise_test_name)
38-
def elemwise_args(request, seed):
37+
@pytest.fixture(params=itertools.product([100, 500, 1000], [1, 2, 3, 4]), ids=get_test_id)
38+
def elemwise_args(request, seed, max_size):
3939
side, rank = request.param
40-
if side**rank >= 2**26:
40+
if side**rank >= max_size:
4141
pytest.skip()
4242
rng = np.random.default_rng(seed=seed)
4343
shape = (side,) * rank
@@ -57,9 +57,9 @@ def bench():
5757

5858

5959
@pytest.fixture(params=[100, 500, 1000], ids=side_ids)
60-
def elemwise_broadcast_args(request, seed):
60+
def elemwise_broadcast_args(request, seed, max_size):
6161
side = request.param
62-
if side**2 >= 2**26:
62+
if side**2 >= max_size:
6363
pytest.skip()
6464
rng = np.random.default_rng(seed=seed)
6565
x = sparse.random((side, 1, side), density=DENSITY, random_state=rng)
@@ -77,65 +77,46 @@ def bench():
7777
f(x, y)
7878

7979

80-
@pytest.fixture(params=[100, 500, 1000], ids=side_ids)
81-
def indexing_args(request, seed):
82-
side = request.param
83-
if side**3 >= 2**26:
80+
@pytest.fixture(params=itertools.product([100, 500, 1000], [1, 2, 3]), ids=get_test_id)
81+
def indexing_args(request, seed, max_size):
82+
side, rank = request.param
83+
if side**rank >= max_size:
8484
pytest.skip()
8585
rng = np.random.default_rng(seed=seed)
86+
shape = (side,) * rank
8687

87-
return sparse.random((side, side, side), density=DENSITY, random_state=rng)
88+
return sparse.random(shape, density=DENSITY, random_state=rng)
8889

8990

9091
def test_index_scalar(benchmark, indexing_args):
9192
x = indexing_args
9293
side = x.shape[0]
94+
rank = x.ndim
9395

94-
x[side // 2, side // 2, side // 2] # Numba compilation
96+
x[(side // 2,) * rank] # Numba compilation
9597

9698
@benchmark
9799
def bench():
98-
x[side // 2, side // 2, side // 2]
100+
x[(side // 2,) * rank]
99101

100102

101103
def test_index_slice(benchmark, indexing_args):
102104
x = indexing_args
103105
side = x.shape[0]
106+
rank = x.ndim
104107

105-
x[: side // 2] # Numba compilation
106-
107-
@benchmark
108-
def bench():
109-
x[: side // 2]
110-
111-
112-
def test_index_slice2(benchmark, indexing_args):
113-
x = indexing_args
114-
side = x.shape[0]
115-
116-
x[: side // 2, : side // 2] # Numba compilation
117-
118-
@benchmark
119-
def bench():
120-
x[: side // 2, : side // 2]
121-
122-
123-
def test_index_slice3(benchmark, indexing_args):
124-
x = indexing_args
125-
side = x.shape[0]
126-
127-
x[: side // 2, : side // 2, : side // 2] # Numba compilation
108+
x[(slice(side // 2),) * rank] # Numba compilation
128109

129110
@benchmark
130111
def bench():
131-
x[: side // 2, : side // 2, : side // 2]
112+
x[(slice(side // 2),) * rank]
132113

133114

134115
def test_index_fancy(benchmark, indexing_args, seed):
135116
x = indexing_args
136117
side = x.shape[0]
137118
rng = np.random.default_rng(seed=seed)
138-
index = rng.integers(0, side, side // 2)
119+
index = rng.integers(0, side, size=(side // 2,))
139120

140121
x[index] # Numba compilation
141122

0 commit comments

Comments
 (0)