Skip to content

Commit eb1d8d6

Browse files
committed
Group rvs_to_value_vars tests in single class
1 parent fec9946 commit eb1d8d6

File tree

1 file changed

+103
-104
lines changed

1 file changed

+103
-104
lines changed

pymc/tests/test_aesaraf.py

Lines changed: 103 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -305,110 +305,6 @@ def test_walk_model():
305305
assert e in res
306306

307307

308-
@pytest.mark.parametrize("symbolic_rv", (False, True))
309-
@pytest.mark.parametrize("apply_transforms", (True, False))
310-
def test_rvs_to_value_vars(symbolic_rv, apply_transforms):
311-
312-
# Interval transform between last two arguments
313-
interval = Interval(bounds_fn=lambda *args: (args[-2], args[-1]))
314-
315-
with pm.Model() as m:
316-
a = pm.Uniform("a", 0.0, 1.0)
317-
if symbolic_rv:
318-
raw_b = pm.Uniform.dist(0, a + 1.0)
319-
b = pm.Censored("b", raw_b, lower=0, upper=a + 1.0, transform=interval)
320-
# If not True, another distribution has to be used
321-
assert isinstance(b.owner.op, SymbolicRandomVariable)
322-
else:
323-
b = pm.Uniform("b", 0, a + 1.0, transform=interval)
324-
c = pm.Normal("c")
325-
d = at.log(c + b) + 2.0
326-
327-
a_value_var = m.rvs_to_values[a]
328-
assert a_value_var.tag.transform
329-
330-
b_value_var = m.rvs_to_values[b]
331-
c_value_var = m.rvs_to_values[c]
332-
333-
(res,) = rvs_to_value_vars((d,), apply_transforms=apply_transforms)
334-
335-
assert res.owner.op == at.add
336-
log_output = res.owner.inputs[0]
337-
assert log_output.owner.op == at.log
338-
log_add_output = res.owner.inputs[0].owner.inputs[0]
339-
assert log_add_output.owner.op == at.add
340-
c_output = log_add_output.owner.inputs[0]
341-
342-
# We make sure that the random variables were replaced
343-
# with their value variables
344-
assert c_output == c_value_var
345-
b_output = log_add_output.owner.inputs[1]
346-
# When transforms are applied, the input is the back-transformation of the value_var,
347-
# otherwise it is the value_var itself
348-
if apply_transforms:
349-
assert b_output != b_value_var
350-
else:
351-
assert b_output == b_value_var
352-
353-
res_ancestors = list(walk_model((res,)))
354-
res_rv_ancestors = [
355-
v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
356-
]
357-
358-
# There shouldn't be any `RandomVariable`s in the resulting graph
359-
assert len(res_rv_ancestors) == 0
360-
assert b_value_var in res_ancestors
361-
assert c_value_var in res_ancestors
362-
# When transforms are used, `d` depends on `a` through the back-transformation of
363-
# `b`, otherwise there is no direct connection between `d` and `a`
364-
if apply_transforms:
365-
assert a_value_var in res_ancestors
366-
else:
367-
assert a_value_var not in res_ancestors
368-
369-
370-
def test_rvs_to_value_vars_nested():
371-
# Test that calling rvs_to_value_vars in models with nested transformations
372-
# does not change the original rvs in place. See issue #5172
373-
with pm.Model() as m:
374-
one = pm.LogNormal("one", mu=0)
375-
two = pm.LogNormal("two", mu=at.log(one))
376-
377-
# We add potentials or deterministics that are not in topological order
378-
pm.Potential("two_pot", two)
379-
pm.Potential("one_pot", one)
380-
381-
before = aesara.clone_replace(m.free_RVs)
382-
383-
# This call would change the model free_RVs in place in #5172
384-
res = rvs_to_value_vars(m.potentials, apply_transforms=True)
385-
386-
after = aesara.clone_replace(m.free_RVs)
387-
388-
assert equal_computations(before, after)
389-
390-
391-
def test_rvs_to_value_vars_unvalued_rv():
392-
with pm.Model() as m:
393-
x = pm.Normal("x")
394-
y = pm.Normal.dist(x)
395-
z = pm.Normal("z", y)
396-
out = z + y
397-
398-
x_value = m.rvs_to_values[x]
399-
z_value = m.rvs_to_values[z]
400-
401-
(res,) = rvs_to_value_vars((out,))
402-
403-
assert res.owner.op == at.add
404-
assert res.owner.inputs[0] is z_value
405-
res_y = res.owner.inputs[1]
406-
# Graph should have be cloned, and therefore y and res_y should have different ids
407-
assert res_y is not y
408-
assert res_y.owner.op == at.random.normal
409-
assert res_y.owner.inputs[3] is x_value
410-
411-
412308
class TestCompilePyMC:
413309
def test_check_bounds_flag(self):
414310
"""Test that CheckParameterValue Ops are replaced or removed when using compile_pymc"""
@@ -633,3 +529,106 @@ def test_constant_fold_raises():
633529

634530
res = constant_fold((y, y.shape), raise_not_constant=False)
635531
assert tuple(res[1].eval()) == (5,)
532+
533+
534+
class TestReplaceRVsByValues:
535+
@pytest.mark.parametrize("symbolic_rv", (False, True))
536+
@pytest.mark.parametrize("apply_transforms", (True, False))
537+
def test_basic(self, symbolic_rv, apply_transforms):
538+
539+
# Interval transform between last two arguments
540+
interval = Interval(bounds_fn=lambda *args: (args[-2], args[-1]))
541+
542+
with pm.Model() as m:
543+
a = pm.Uniform("a", 0.0, 1.0)
544+
if symbolic_rv:
545+
raw_b = pm.Uniform.dist(0, a + 1.0)
546+
b = pm.Censored("b", raw_b, lower=0, upper=a + 1.0, transform=interval)
547+
# If not True, another distribution has to be used
548+
assert isinstance(b.owner.op, SymbolicRandomVariable)
549+
else:
550+
b = pm.Uniform("b", 0, a + 1.0, transform=interval)
551+
c = pm.Normal("c")
552+
d = at.log(c + b) + 2.0
553+
554+
a_value_var = m.rvs_to_values[a]
555+
assert a_value_var.tag.transform
556+
557+
b_value_var = m.rvs_to_values[b]
558+
c_value_var = m.rvs_to_values[c]
559+
560+
(res,) = rvs_to_value_vars((d,), apply_transforms=apply_transforms)
561+
562+
assert res.owner.op == at.add
563+
log_output = res.owner.inputs[0]
564+
assert log_output.owner.op == at.log
565+
log_add_output = res.owner.inputs[0].owner.inputs[0]
566+
assert log_add_output.owner.op == at.add
567+
c_output = log_add_output.owner.inputs[0]
568+
569+
# We make sure that the random variables were replaced
570+
# with their value variables
571+
assert c_output == c_value_var
572+
b_output = log_add_output.owner.inputs[1]
573+
# When transforms are applied, the input is the back-transformation of the value_var,
574+
# otherwise it is the value_var itself
575+
if apply_transforms:
576+
assert b_output != b_value_var
577+
else:
578+
assert b_output == b_value_var
579+
580+
res_ancestors = list(walk_model((res,)))
581+
res_rv_ancestors = [
582+
v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
583+
]
584+
585+
# There shouldn't be any `RandomVariable`s in the resulting graph
586+
assert len(res_rv_ancestors) == 0
587+
assert b_value_var in res_ancestors
588+
assert c_value_var in res_ancestors
589+
# When transforms are used, `d` depends on `a` through the back-transformation of
590+
# `b`, otherwise there is no direct connection between `d` and `a`
591+
if apply_transforms:
592+
assert a_value_var in res_ancestors
593+
else:
594+
assert a_value_var not in res_ancestors
595+
596+
def test_unvalued_rv(self):
597+
with pm.Model() as m:
598+
x = pm.Normal("x")
599+
y = pm.Normal.dist(x)
600+
z = pm.Normal("z", y)
601+
out = z + y
602+
603+
x_value = m.rvs_to_values[x]
604+
z_value = m.rvs_to_values[z]
605+
606+
(res,) = rvs_to_value_vars((out,))
607+
608+
assert res.owner.op == at.add
609+
assert res.owner.inputs[0] is z_value
610+
res_y = res.owner.inputs[1]
611+
# Graph should have be cloned, and therefore y and res_y should have different ids
612+
assert res_y is not y
613+
assert res_y.owner.op == at.random.normal
614+
assert res_y.owner.inputs[3] is x_value
615+
616+
def test_no_change_inplace(self):
617+
# Test that calling rvs_to_value_vars in models with nested transformations
618+
# does not change the original rvs in place. See issue #5172
619+
with pm.Model() as m:
620+
one = pm.LogNormal("one", mu=0)
621+
two = pm.LogNormal("two", mu=at.log(one))
622+
623+
# We add potentials or deterministics that are not in topological order
624+
pm.Potential("two_pot", two)
625+
pm.Potential("one_pot", one)
626+
627+
before = aesara.clone_replace(m.free_RVs)
628+
629+
# This call would change the model free_RVs in place in #5172
630+
res = rvs_to_value_vars(m.potentials, apply_transforms=True)
631+
632+
after = aesara.clone_replace(m.free_RVs)
633+
634+
assert equal_computations(before, after)

0 commit comments

Comments
 (0)