@@ -29,12 +29,12 @@ def bench():
29
29
x @ y
30
30
31
31
32
- def elemwise_test_name (param ):
32
+ def id_of_test (param ):
33
33
side , rank = param
34
34
return f"{ side = } -{ rank = } "
35
35
36
36
37
- @pytest .fixture (params = itertools .product ([100 , 500 , 1000 ], [1 , 2 , 3 , 4 ]), ids = elemwise_test_name )
37
+ @pytest .fixture (params = itertools .product ([100 , 500 , 1000 ], [1 , 2 , 3 , 4 ]), ids = id_of_test )
38
38
def elemwise_args (request , seed , max_size ):
39
39
side , rank = request .param
40
40
if side ** rank >= max_size :
@@ -77,19 +77,21 @@ def bench():
77
77
f (x , y )
78
78
79
79
80
- @pytest .fixture (params = [100 , 500 , 1000 ], ids = side_ids )
80
+ @pytest .fixture (params = itertools . product ( [100 , 500 , 1000 ], [ 1 , 2 , 3 ]), ids = id_of_test )
81
81
def indexing_args (request , seed , max_size ):
82
- side = request .param
82
+ side , rank = request .param
83
83
if side ** 3 >= max_size :
84
84
pytest .skip ()
85
85
rng = np .random .default_rng (seed = seed )
86
+ shape = (side ,) * rank
87
+ x = sparse .random (shape , density = DENSITY , random_state = rng )
88
+ return x
86
89
87
- return sparse .random ((side , side , side ), density = DENSITY , random_state = rng )
88
90
89
- @pytest .mark .parametrize ("ndim" , [1 , 2 , 3 ])
90
- def test_index_scalar (benchmark , ndim , indexing_args ):
91
+ def test_index_scalar (benchmark , indexing_args ):
91
92
x = indexing_args
92
93
side = x .shape [0 ]
94
+ ndim = len (x .shape )
93
95
94
96
x [(side // 2 ,) * ndim ] # Numba compilation
95
97
@@ -101,41 +103,21 @@ def bench():
101
103
def test_index_slice (benchmark , indexing_args ):
102
104
x = indexing_args
103
105
side = x .shape [0 ]
106
+ rank = len (x .shape )
104
107
105
- x [: side // 2 ] # Numba compilation
106
-
107
- @benchmark
108
- def bench ():
109
- x [: side // 2 ]
110
-
111
-
112
- def test_index_slice2 (benchmark , indexing_args ):
113
- x = indexing_args
114
- side = x .shape [0 ]
115
-
116
- x [: side // 2 , : side // 2 ] # Numba compilation
117
-
118
- @benchmark
119
- def bench ():
120
- x [: side // 2 , : side // 2 ]
121
-
122
-
123
- def test_index_slice3 (benchmark , indexing_args ):
124
- x = indexing_args
125
- side = x .shape [0 ]
126
-
127
- x [: side // 2 , : side // 2 , : side // 2 ] # Numba compilation
108
+ x [(slice (side // 2 ),) * rank ] # Numba compilation
128
109
129
110
@benchmark
130
111
def bench ():
131
- x [: side // 2 , : side // 2 , : side // 2 ]
112
+ x [( slice ( side // 2 ),) * rank ]
132
113
133
114
134
115
def test_index_fancy (benchmark , indexing_args , seed ):
135
116
x = indexing_args
136
117
side = x .shape [0 ]
118
+ rank = len (x .shape )
137
119
rng = np .random .default_rng (seed = seed )
138
- index = rng .integers (0 , side , side // 2 )
120
+ index = rng .integers (( side // 2 ,) * rank )
139
121
140
122
x [index ] # Numba compilation
141
123
0 commit comments