Skip to content

Commit cedd595

Browse files
ricardoV94twiecki
authored andcommitted
Update xfailed invalid broadcast test
The test now failed due to improvement in static type shapes, but the underlying issue was not addressed. The logp for a broadcasted RV must consider that broadcasted dimensions are not independent
1 parent 7f3c032 commit cedd595

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pymc/tests/logprob/test_transforms.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -778,16 +778,16 @@ def test_discrete_rv_multinary_transform_fails():
778778
joint_logprob({y_rv: y_rv.clone()})
779779

780780

781-
@pytest.mark.xfail(reason="Check not implemented yet, see #51")
781+
@pytest.mark.xfail(reason="Check not implemented yet")
782782
def test_invalid_broadcasted_transform_rv_fails():
783783
loc = at.vector("loc")
784-
y_rv = loc + at.random.normal(0, 1, size=2, name="base_rv")
784+
y_rv = loc + at.random.normal(0, 1, size=1, name="base_rv")
785785
y_rv.name = "y"
786786
y_vv = y_rv.clone()
787787

788-
logp = joint_logprob({y_rv: y_vv})
789-
logp.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]})
790-
assert False, "Should have failed before"
788+
# This logp derivation should fail or count only once the values that are broadcasted
789+
logp = joint_logprob({y_rv: y_vv}, sum=False)
790+
assert logp.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]}).shape == ()
791791

792792

793793
@pytest.mark.parametrize("numerator", (1.0, 2.0))

0 commit comments

Comments
 (0)