Skip to content

Commit e23962a

Browse files
committed
remove the generator related tests
1 parent be837bf commit e23962a

File tree

1 file changed

+0
-46
lines changed

1 file changed

+0
-46
lines changed

tests/variational/test_minibatch_rv.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222

2323
from pymc import Normal, draw
2424
from pymc.data import Minibatch
25-
from pymc.testing import select_by_precision
2625
from pymc.variational.minibatch_rv import create_minibatch_rv
27-
from tests.test_data import gen1, gen2
2826

2927

3028
class TestMinibatchRandomVariable:
@@ -42,50 +40,6 @@ def test_density_scaling(self):
4240
p2 = pytensor.function([], model2.logp())
4341
assert p1() * 2 == p2()
4442

45-
def test_density_scaling_with_generator(self):
46-
# We have different size generators
47-
48-
def true_dens():
49-
g = gen1()
50-
for i, point in enumerate(g):
51-
yield st.norm.logpdf(point).sum() * 10
52-
53-
t = true_dens()
54-
# We have same size models
55-
with pm.Model() as model1:
56-
pm.Normal("n", observed=gen1(), total_size=100)
57-
p1 = pytensor.function([], model1.logp())
58-
59-
with pm.Model() as model2:
60-
gen_var = pm.generator(gen2())
61-
pm.Normal("n", observed=gen_var, total_size=100)
62-
p2 = pytensor.function([], model2.logp())
63-
64-
for i in range(10):
65-
_1, _2, _t = p1(), p2(), next(t)
66-
decimals = select_by_precision(float64=7, float32=1)
67-
np.testing.assert_almost_equal(_1, _t, decimal=decimals) # Value O(-50,000)
68-
np.testing.assert_almost_equal(_1, _2)
69-
# Done
70-
71-
def test_gradient_with_scaling(self):
72-
with pm.Model() as model1:
73-
genvar = pm.generator(gen1())
74-
m = pm.Normal("m")
75-
pm.Normal("n", observed=genvar, total_size=1000)
76-
grad1 = model1.compile_fn(model1.dlogp(vars=m), point_fn=False)
77-
with pm.Model() as model2:
78-
m = pm.Normal("m")
79-
shavar = pytensor.shared(np.ones((1000, 100)))
80-
pm.Normal("n", observed=shavar)
81-
grad2 = model2.compile_fn(model2.dlogp(vars=m), point_fn=False)
82-
83-
for i in range(10):
84-
shavar.set_value(np.ones((100, 100)) * i)
85-
g1 = grad1(1)
86-
g2 = grad2(1)
87-
np.testing.assert_almost_equal(g1, g2)
88-
8943
def test_multidim_scaling(self):
9044
with pm.Model() as model0:
9145
pm.Normal("n", observed=[[1, 1], [1, 1]], total_size=[])

0 commit comments

Comments
 (0)