@@ -317,21 +317,18 @@ def test_2d(self):
317
317
mb = pm .Minibatch (self .data , [(10 , 42 ), (4 , 42 )])
318
318
assert mb .eval ().shape == (10 , 4 , 40 , 10 , 50 )
319
319
320
- def test_special1 (self ):
321
- mb = pm .Minibatch (self .data , [(10 , 42 ), None , (4 , 42 )])
322
- assert mb .eval ().shape == (10 , 10 , 4 , 10 , 50 )
323
-
324
- def test_special2 (self ):
325
- mb = pm .Minibatch (self .data , [(10 , 42 ), Ellipsis , (4 , 42 )])
326
- assert mb .eval ().shape == (10 , 10 , 40 , 10 , 4 )
327
-
328
- def test_special3 (self ):
329
- mb = pm .Minibatch (self .data , [(10 , 42 ), None , Ellipsis , (4 , 42 )])
330
- assert mb .eval ().shape == (10 , 10 , 40 , 10 , 4 )
331
-
332
- def test_special4 (self ):
333
- mb = pm .Minibatch (self .data , [10 , None , Ellipsis , (4 , 42 )])
334
- assert mb .eval ().shape == (10 , 10 , 40 , 10 , 4 )
320
+ @pytest .mark .parametrize (
321
+ "batch_size, expected" ,
322
+ [
323
+ ([(10 , 42 ), None , (4 , 42 )], (10 , 10 , 4 , 10 , 50 )),
324
+ ([(10 , 42 ), Ellipsis , (4 , 42 )], (10 , 10 , 40 , 10 , 4 )),
325
+ ([(10 , 42 ), None , Ellipsis , (4 , 42 )], (10 , 10 , 40 , 10 , 4 )),
326
+ ([10 , None , Ellipsis , (4 , 42 )], (10 , 10 , 40 , 10 , 4 )),
327
+ ],
328
+ )
329
+ def test_special_batch_size (self , batch_size , expected ):
330
+ mb = pm .Minibatch (self .data , batch_size )
331
+ assert mb .eval ().shape == expected
335
332
336
333
def test_cloning_available (self ):
337
334
gop = pm .Minibatch (np .arange (100 ), 1 )
0 commit comments