Skip to content

Commit 739f27f

Browse files
committed
first try
1 parent 29f9632 commit 739f27f

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

benchmarks/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,7 @@
44
@pytest.fixture
55
def seed(scope="session"):
66
return 42
7+
8+
@pytest.fixture
9+
def max_size(scope="session"):
10+
return 2**26

benchmarks/test_benchmark_coo.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def side_ids(side):
1515

1616

1717
@pytest.mark.parametrize("side", [100, 500, 1000], ids=side_ids)
18-
def test_matmul(benchmark, side, seed):
19-
if side**2 >= 2**26:
18+
def test_matmul(benchmark, side, seed, max_size):
19+
if side**2 >= max_size:
2020
pytest.skip()
2121
rng = np.random.default_rng(seed=seed)
2222
x = sparse.random((side, side), density=DENSITY, random_state=rng)
@@ -35,9 +35,9 @@ def elemwise_test_name(param):
3535

3636

3737
@pytest.fixture(params=itertools.product([100, 500, 1000], [1, 2, 3, 4]), ids=elemwise_test_name)
38-
def elemwise_args(request, seed):
38+
def elemwise_args(request, seed, max_size):
3939
side, rank = request.param
40-
if side**rank >= 2**26:
40+
if side**rank >= max_size:
4141
pytest.skip()
4242
rng = np.random.default_rng(seed=seed)
4343
shape = (side,) * rank
@@ -57,9 +57,9 @@ def bench():
5757

5858

5959
@pytest.fixture(params=[100, 500, 1000], ids=side_ids)
60-
def elemwise_broadcast_args(request, seed):
60+
def elemwise_broadcast_args(request, seed, max_size):
6161
side = request.param
62-
if side**2 >= 2**26:
62+
if side**2 >= max_size:
6363
pytest.skip()
6464
rng = np.random.default_rng(seed=seed)
6565
x = sparse.random((side, 1, side), density=DENSITY, random_state=rng)
@@ -78,24 +78,24 @@ def bench():
7878

7979

8080
@pytest.fixture(params=[100, 500, 1000], ids=side_ids)
81-
def indexing_args(request, seed):
81+
def indexing_args(request, seed, max_size):
8282
side = request.param
83-
if side**3 >= 2**26:
83+
if side**3 >= max_size:
8484
pytest.skip()
8585
rng = np.random.default_rng(seed=seed)
8686

8787
return sparse.random((side, side, side), density=DENSITY, random_state=rng)
8888

89-
90-
def test_index_scalar(benchmark, indexing_args):
89+
@pytest.mark.parametrize("ndim", [1, 2, 3])
90+
def test_index_scalar(benchmark, ndim, indexing_args):
9191
x = indexing_args
9292
side = x.shape[0]
9393

94-
x[side // 2, side // 2, side // 2] # Numba compilation
94+
x[(side // 2,) * ndim] # Numba compilation
9595

9696
@benchmark
9797
def bench():
98-
x[side // 2, side // 2, side // 2]
98+
x[(side // 2,) * ndim]
9999

100100

101101
def test_index_slice(benchmark, indexing_args):

0 commit comments

Comments
 (0)