Skip to content

Commit d4165a4

Browse files
committed
slice
1 parent 739f27f commit d4165a4

File tree

2 files changed

+16
-33
lines changed

2 files changed

+16
-33
lines changed

benchmarks/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
def seed(scope="session"):
66
return 42
77

8+
89
@pytest.fixture
910
def max_size(scope="session"):
10-
return 2**26
11+
return 2**26

benchmarks/test_benchmark_coo.py

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ def bench():
2929
x @ y
3030

3131

32-
def elemwise_test_name(param):
32+
def id_of_test(param):
3333
side, rank = param
3434
return f"{side=}-{rank=}"
3535

3636

37-
@pytest.fixture(params=itertools.product([100, 500, 1000], [1, 2, 3, 4]), ids=elemwise_test_name)
37+
@pytest.fixture(params=itertools.product([100, 500, 1000], [1, 2, 3, 4]), ids=id_of_test)
3838
def elemwise_args(request, seed, max_size):
3939
side, rank = request.param
4040
if side**rank >= max_size:
@@ -77,19 +77,21 @@ def bench():
7777
f(x, y)
7878

7979

80-
@pytest.fixture(params=[100, 500, 1000], ids=side_ids)
80+
@pytest.fixture(params=itertools.product([100, 500, 1000], [1, 2, 3]), ids=id_of_test)
8181
def indexing_args(request, seed, max_size):
82-
side = request.param
82+
side, rank = request.param
8383
if side**3 >= max_size:
8484
pytest.skip()
8585
rng = np.random.default_rng(seed=seed)
86+
shape = (side,) * rank
87+
x = sparse.random(shape, density=DENSITY, random_state=rng)
88+
return x
8689

87-
return sparse.random((side, side, side), density=DENSITY, random_state=rng)
8890

89-
@pytest.mark.parametrize("ndim", [1, 2, 3])
90-
def test_index_scalar(benchmark, ndim, indexing_args):
91+
def test_index_scalar(benchmark, indexing_args):
9192
x = indexing_args
9293
side = x.shape[0]
94+
ndim = len(x.shape)
9395

9496
x[(side // 2,) * ndim] # Numba compilation
9597

@@ -101,41 +103,21 @@ def bench():
101103
def test_index_slice(benchmark, indexing_args):
102104
x = indexing_args
103105
side = x.shape[0]
106+
rank = len(x.shape)
104107

105-
x[: side // 2] # Numba compilation
106-
107-
@benchmark
108-
def bench():
109-
x[: side // 2]
110-
111-
112-
def test_index_slice2(benchmark, indexing_args):
113-
x = indexing_args
114-
side = x.shape[0]
115-
116-
x[: side // 2, : side // 2] # Numba compilation
117-
118-
@benchmark
119-
def bench():
120-
x[: side // 2, : side // 2]
121-
122-
123-
def test_index_slice3(benchmark, indexing_args):
124-
x = indexing_args
125-
side = x.shape[0]
126-
127-
x[: side // 2, : side // 2, : side // 2] # Numba compilation
108+
x[(slice(side // 2),) * rank] # Numba compilation
128109

129110
@benchmark
130111
def bench():
131-
x[: side // 2, : side // 2, : side // 2]
112+
x[(slice(side // 2),) * rank]
132113

133114

134115
def test_index_fancy(benchmark, indexing_args, seed):
135116
x = indexing_args
136117
side = x.shape[0]
118+
rank = len(x.shape)
137119
rng = np.random.default_rng(seed=seed)
138-
index = rng.integers(0, side, side // 2)
120+
index = rng.integers((side // 2,) * rank)
139121

140122
x[index] # Numba compilation
141123

0 commit comments

Comments
 (0)