Skip to content

Commit 5588858

Browse files
danhphantwiecki
authored andcommitted
simplify serveral tests in TestMinibatch
1 parent cbf55ab commit 5588858

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

pymc/tests/test_minibatches.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -317,21 +317,18 @@ def test_2d(self):
317317
mb = pm.Minibatch(self.data, [(10, 42), (4, 42)])
318318
assert mb.eval().shape == (10, 4, 40, 10, 50)
319319

320-
def test_special1(self):
321-
mb = pm.Minibatch(self.data, [(10, 42), None, (4, 42)])
322-
assert mb.eval().shape == (10, 10, 4, 10, 50)
323-
324-
def test_special2(self):
325-
mb = pm.Minibatch(self.data, [(10, 42), Ellipsis, (4, 42)])
326-
assert mb.eval().shape == (10, 10, 40, 10, 4)
327-
328-
def test_special3(self):
329-
mb = pm.Minibatch(self.data, [(10, 42), None, Ellipsis, (4, 42)])
330-
assert mb.eval().shape == (10, 10, 40, 10, 4)
331-
332-
def test_special4(self):
333-
mb = pm.Minibatch(self.data, [10, None, Ellipsis, (4, 42)])
334-
assert mb.eval().shape == (10, 10, 40, 10, 4)
320+
@pytest.mark.parametrize(
321+
"batch_size, expected",
322+
[
323+
([(10, 42), None, (4, 42)], (10, 10, 4, 10, 50)),
324+
([(10, 42), Ellipsis, (4, 42)], (10, 10, 40, 10, 4)),
325+
([(10, 42), None, Ellipsis, (4, 42)], (10, 10, 40, 10, 4)),
326+
([10, None, Ellipsis, (4, 42)], (10, 10, 40, 10, 4)),
327+
],
328+
)
329+
def test_special_batch_size(self, batch_size, expected):
330+
mb = pm.Minibatch(self.data, batch_size)
331+
assert mb.eval().shape == expected
335332

336333
def test_cloning_available(self):
337334
gop = pm.Minibatch(np.arange(100), 1)

0 commit comments

Comments
 (0)