2222
2323from pymc import Normal , draw
2424from pymc .data import Minibatch
25- from pymc .testing import select_by_precision
2625from pymc .variational .minibatch_rv import create_minibatch_rv
27- from tests .test_data import gen1 , gen2
2826
2927
3028class TestMinibatchRandomVariable :
@@ -42,50 +40,6 @@ def test_density_scaling(self):
4240 p2 = pytensor .function ([], model2 .logp ())
4341 assert p1 () * 2 == p2 ()
4442
45- def test_density_scaling_with_generator (self ):
46- # We have different size generators
47-
48- def true_dens ():
49- g = gen1 ()
50- for i , point in enumerate (g ):
51- yield st .norm .logpdf (point ).sum () * 10
52-
53- t = true_dens ()
54- # We have same size models
55- with pm .Model () as model1 :
56- pm .Normal ("n" , observed = gen1 (), total_size = 100 )
57- p1 = pytensor .function ([], model1 .logp ())
58-
59- with pm .Model () as model2 :
60- gen_var = pm .generator (gen2 ())
61- pm .Normal ("n" , observed = gen_var , total_size = 100 )
62- p2 = pytensor .function ([], model2 .logp ())
63-
64- for i in range (10 ):
65- _1 , _2 , _t = p1 (), p2 (), next (t )
66- decimals = select_by_precision (float64 = 7 , float32 = 1 )
67- np .testing .assert_almost_equal (_1 , _t , decimal = decimals ) # Value O(-50,000)
68- np .testing .assert_almost_equal (_1 , _2 )
69- # Done
70-
71- def test_gradient_with_scaling (self ):
72- with pm .Model () as model1 :
73- genvar = pm .generator (gen1 ())
74- m = pm .Normal ("m" )
75- pm .Normal ("n" , observed = genvar , total_size = 1000 )
76- grad1 = model1 .compile_fn (model1 .dlogp (vars = m ), point_fn = False )
77- with pm .Model () as model2 :
78- m = pm .Normal ("m" )
79- shavar = pytensor .shared (np .ones ((1000 , 100 )))
80- pm .Normal ("n" , observed = shavar )
81- grad2 = model2 .compile_fn (model2 .dlogp (vars = m ), point_fn = False )
82-
83- for i in range (10 ):
84- shavar .set_value (np .ones ((100 , 100 )) * i )
85- g1 = grad1 (1 )
86- g2 = grad2 (1 )
87- np .testing .assert_almost_equal (g1 , g2 )
88-
8943 def test_multidim_scaling (self ):
9044 with pm .Model () as model0 :
9145 pm .Normal ("n" , observed = [[1 , 1 ], [1 , 1 ]], total_size = [])
0 commit comments