@@ -15,8 +15,8 @@ def side_ids(side):
1515
1616
1717@pytest .mark .parametrize ("side" , [100 , 500 , 1000 ], ids = side_ids )
18- def test_matmul (benchmark , side , seed ):
19- if side ** 2 >= 2 ** 26 :
18+ def test_matmul (benchmark , side , seed , max_size ):
19+ if side ** 2 >= max_size :
2020 pytest .skip ()
2121 rng = np .random .default_rng (seed = seed )
2222 x = sparse .random ((side , side ), density = DENSITY , random_state = rng )
@@ -29,15 +29,15 @@ def bench():
2929 x @ y
3030
3131
32- def elemwise_test_name (param ):
32+ def get_test_id (param ):
3333 side , rank = param
3434 return f"{ side = } -{ rank = } "
3535
3636
37- @pytest .fixture (params = itertools .product ([100 , 500 , 1000 ], [1 , 2 , 3 , 4 ]), ids = elemwise_test_name )
38- def elemwise_args (request , seed ):
37+ @pytest .fixture (params = itertools .product ([100 , 500 , 1000 ], [1 , 2 , 3 , 4 ]), ids = get_test_id )
38+ def elemwise_args (request , seed , max_size ):
3939 side , rank = request .param
40- if side ** rank >= 2 ** 26 :
40+ if side ** rank >= max_size :
4141 pytest .skip ()
4242 rng = np .random .default_rng (seed = seed )
4343 shape = (side ,) * rank
@@ -57,9 +57,9 @@ def bench():
5757
5858
5959@pytest .fixture (params = [100 , 500 , 1000 ], ids = side_ids )
60- def elemwise_broadcast_args (request , seed ):
60+ def elemwise_broadcast_args (request , seed , max_size ):
6161 side = request .param
62- if side ** 2 >= 2 ** 26 :
62+ if side ** 2 >= max_size :
6363 pytest .skip ()
6464 rng = np .random .default_rng (seed = seed )
6565 x = sparse .random ((side , 1 , side ), density = DENSITY , random_state = rng )
@@ -77,65 +77,46 @@ def bench():
7777 f (x , y )
7878
7979
80- @pytest .fixture (params = [100 , 500 , 1000 ], ids = side_ids )
81- def indexing_args (request , seed ):
82- side = request .param
83- if side ** 3 >= 2 ** 26 :
80+ @pytest .fixture (params = itertools . product ( [100 , 500 , 1000 ], [ 1 , 2 , 3 ]), ids = get_test_id )
81+ def indexing_args (request , seed , max_size ):
82+ side , rank = request .param
83+ if side ** rank >= max_size :
8484 pytest .skip ()
8585 rng = np .random .default_rng (seed = seed )
86+ shape = (side ,) * rank
8687
87- return sparse .random (( side , side , side ) , density = DENSITY , random_state = rng )
88+ return sparse .random (shape , density = DENSITY , random_state = rng )
8889
8990
9091def test_index_scalar (benchmark , indexing_args ):
9192 x = indexing_args
9293 side = x .shape [0 ]
94+ rank = x .ndim
9395
94- x [side // 2 , side // 2 , side // 2 ] # Numba compilation
96+ x [( side // 2 ,) * rank ] # Numba compilation
9597
9698 @benchmark
9799 def bench ():
98- x [side // 2 , side // 2 , side // 2 ]
100+ x [( side // 2 ,) * rank ]
99101
100102
101103def test_index_slice (benchmark , indexing_args ):
102104 x = indexing_args
103105 side = x .shape [0 ]
106+ rank = x .ndim
104107
105- x [: side // 2 ] # Numba compilation
106-
107- @benchmark
108- def bench ():
109- x [: side // 2 ]
110-
111-
112- def test_index_slice2 (benchmark , indexing_args ):
113- x = indexing_args
114- side = x .shape [0 ]
115-
116- x [: side // 2 , : side // 2 ] # Numba compilation
117-
118- @benchmark
119- def bench ():
120- x [: side // 2 , : side // 2 ]
121-
122-
123- def test_index_slice3 (benchmark , indexing_args ):
124- x = indexing_args
125- side = x .shape [0 ]
126-
127- x [: side // 2 , : side // 2 , : side // 2 ] # Numba compilation
108+ x [(slice (side // 2 ),) * rank ] # Numba compilation
128109
129110 @benchmark
130111 def bench ():
131- x [: side // 2 , : side // 2 , : side // 2 ]
112+ x [( slice ( side // 2 ),) * rank ]
132113
133114
134115def test_index_fancy (benchmark , indexing_args , seed ):
135116 x = indexing_args
136117 side = x .shape [0 ]
137118 rng = np .random .default_rng (seed = seed )
138- index = rng .integers (0 , side , side // 2 )
119+ index = rng .integers (0 , side , size = ( side // 2 ,) )
139120
140121 x [index ] # Numba compilation
141122
0 commit comments