@@ -15,8 +15,8 @@ def side_ids(side):
15
15
16
16
17
17
@pytest .mark .parametrize ("side" , [100 , 500 , 1000 ], ids = side_ids )
18
- def test_matmul (benchmark , side , seed ):
19
- if side ** 2 >= 2 ** 26 :
18
+ def test_matmul (benchmark , side , seed , max_size ):
19
+ if side ** 2 >= max_size :
20
20
pytest .skip ()
21
21
rng = np .random .default_rng (seed = seed )
22
22
x = sparse .random ((side , side ), density = DENSITY , random_state = rng )
@@ -35,9 +35,9 @@ def elemwise_test_name(param):
35
35
36
36
37
37
@pytest .fixture (params = itertools .product ([100 , 500 , 1000 ], [1 , 2 , 3 , 4 ]), ids = elemwise_test_name )
38
- def elemwise_args (request , seed ):
38
+ def elemwise_args (request , seed , max_size ):
39
39
side , rank = request .param
40
- if side ** rank >= 2 ** 26 :
40
+ if side ** rank >= max_size :
41
41
pytest .skip ()
42
42
rng = np .random .default_rng (seed = seed )
43
43
shape = (side ,) * rank
@@ -57,9 +57,9 @@ def bench():
57
57
58
58
59
59
@pytest .fixture (params = [100 , 500 , 1000 ], ids = side_ids )
60
- def elemwise_broadcast_args (request , seed ):
60
+ def elemwise_broadcast_args (request , seed , max_size ):
61
61
side = request .param
62
- if side ** 2 >= 2 ** 26 :
62
+ if side ** 2 >= max_size :
63
63
pytest .skip ()
64
64
rng = np .random .default_rng (seed = seed )
65
65
x = sparse .random ((side , 1 , side ), density = DENSITY , random_state = rng )
@@ -78,24 +78,24 @@ def bench():
78
78
79
79
80
80
@pytest .fixture (params = [100 , 500 , 1000 ], ids = side_ids )
81
- def indexing_args (request , seed ):
81
+ def indexing_args (request , seed , max_size ):
82
82
side = request .param
83
- if side ** 3 >= 2 ** 26 :
83
+ if side ** 3 >= max_size :
84
84
pytest .skip ()
85
85
rng = np .random .default_rng (seed = seed )
86
86
87
87
return sparse .random ((side , side , side ), density = DENSITY , random_state = rng )
88
88
89
-
90
- def test_index_scalar (benchmark , indexing_args ):
89
+ @ pytest . mark . parametrize ( "ndim" , [ 1 , 2 , 3 ])
90
+ def test_index_scalar (benchmark , ndim , indexing_args ):
91
91
x = indexing_args
92
92
side = x .shape [0 ]
93
93
94
- x [side // 2 , side // 2 , side // 2 ] # Numba compilation
94
+ x [( side // 2 ,) * ndim ] # Numba compilation
95
95
96
96
@benchmark
97
97
def bench ():
98
- x [side // 2 , side // 2 , side // 2 ]
98
+ x [( side // 2 ,) * ndim ]
99
99
100
100
101
101
def test_index_slice (benchmark , indexing_args ):
0 commit comments