27
27
28
28
import pymc as pm
29
29
30
+ from pymc .data import is_minibatch
30
31
from pymc .pytensorf import GeneratorOp , floatX
31
32
from pymc .tests .helpers import SeededTest , select_by_precision
32
33
@@ -696,15 +697,10 @@ def test_common_errors(self):
696
697
697
698
def test_mixed1(self):
    """Minibatched observed data with a scalar total_size produces a usable logp graph."""
    with pm.Model():
        raw = np.random.rand(10, 20)
        batched = pm.Minibatch(raw, batch_size=5)
        rv = pm.Normal("n", observed=batched, total_size=10)
        # logp must be constructible even though the observed value is a minibatch slice.
        assert pm.logp(rv, 1) is not None, "Check index is allowed in graph"
708
704
709
705
def test_free_rv (self ):
710
706
with pm .Model () as model4 :
@@ -719,51 +715,28 @@ def test_free_rv(self):
719
715
720
716
@pytest.mark.usefixtures("strict_float32")
class TestMinibatch:
    """Tests for pm.Minibatch slicing and input validation."""

    # Shared dataset; shape (30, 10) so batch_size=20 slices the leading axis.
    data = np.random.rand(30, 10)

    def test_1d(self):
        """A plain ndarray is sliced along axis 0 to the requested batch size."""
        mb = pm.Minibatch(self.data, batch_size=20)
        assert is_minibatch(mb)
        assert mb.eval().shape == (20, 10)

    def test_allowed(self):
        """A constant tensor (here cast to int) is an accepted Minibatch input."""
        mb = pm.Minibatch(at.as_tensor(self.data).astype(int), batch_size=20)
        assert is_minibatch(mb)

    def test_not_allowed(self):
        """An arbitrary symbolic expression is rejected as a Minibatch input."""
        with pytest.raises(ValueError, match="not valid for Minibatch"):
            # No binding needed: constructing the Minibatch itself must raise.
            pm.Minibatch(at.as_tensor(self.data) * 2, batch_size=20)

    def test_not_allowed2(self):
        """A symbolic expression mixed in with valid inputs is rejected too."""
        with pytest.raises(ValueError, match="not valid for Minibatch"):
            pm.Minibatch(self.data, at.as_tensor(self.data) * 2, batch_size=20)

    def test_assert(self):
        """Inputs with mismatched leading dimensions fail the shared-shape assertion."""
        with pytest.raises(
            AssertionError, match=r"All variables shape\[0\] in Minibatch should be equal"
        ):
            d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
            # The assertion fires lazily, at evaluation time.
            d1.eval()
0 commit comments