Skip to content

Commit 2d753b3

Browse files
committed
allowing for minibatch of pytensor operations
1 parent 6b410f3 commit 2d753b3

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/test_data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -602,11 +602,11 @@ def test_allowed(self):
602602
mb = pm.Minibatch(pt.as_tensor(self.data).astype(int), batch_size=20)
603603
assert isinstance(mb.owner.op, MinibatchOp)
604604

605-
with pytest.raises(ValueError, match="not valid for Minibatch"):
606-
pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
605+
mb = pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
606+
assert isinstance(mb.owner.op, MinibatchOp)
607607

608-
with pytest.raises(ValueError, match="not valid for Minibatch"):
609-
pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)
608+
for mb in pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20):
609+
assert isinstance(mb.owner.op, MinibatchOp)
610610

611611
def test_assert(self):
612612
d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)

0 commit comments

Comments
 (0)