Skip to content

Commit da3dd01

Browse files
committed
mark xfail for failing tests
1 parent ac09d78 commit da3dd01

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

test/test_handlers.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -649,25 +649,50 @@ def guide():
649649
svi.update(svi_state)
650650

651651

652+
@pytest.mark.xfail(reason="missing pattern in Funsor")
653+
def test_collapse_diag_normal_plate_normal():
654+
d = 3
655+
data = np.ones((5, d))
656+
657+
def model():
658+
x = numpyro.sample("x", dist.Normal(0, 1))
659+
with handlers.collapse():
660+
with handlers.plate("data", len(data)):
661+
y = numpyro.sample("y", dist.Normal(x, 1.).expand([d]).to_event(1))
662+
numpyro.sample("z", dist.Normal(y, 1.).to_event(1), obs=data)
663+
664+
def guide():
665+
loc = numpyro.param("loc", 0.)
666+
scale = numpyro.param("scale", 1., constraint=constraints.positive)
667+
numpyro.sample("x", dist.Normal(loc, scale))
668+
669+
svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
670+
svi_state = svi.init(random.PRNGKey(0))
671+
svi.update(svi_state)
672+
673+
674+
@pytest.mark.xfail(reason="missing pattern in Funsor")
652675
def test_collapse_normal_mvn_mvn():
653676
T, d, S = 5, 2, 3
654677
data = jnp.ones((T, S))
655678

656679
def model():
657-
x = numpyro.sample("x", dist.Exponential(1))
680+
x = numpyro.sample("x", dist.Normal(0, 1))
658681
with handlers.collapse():
659682
with numpyro.plate("d", d, dim=-1):
660-
beta0 = numpyro.sample("beta0", dist.Normal(0., 1.).expand([d, S]).to_event(1))
683+
beta0 = numpyro.sample("beta0", dist.Normal(x, 1.).expand([d, S]).to_event(1))
661684
beta = numpyro.sample(
662685
"beta", dist.MultivariateNormal(beta0, scale_tril=jnp.eye(S)))
663686

687+
# this fails because beta shape is (3,) while it should be (2, 3)
664688
mean = jnp.ones((T, d)) @ beta
665689
with numpyro.plate("data", T, dim=-1):
666690
numpyro.sample("obs", dist.MultivariateNormal(mean, scale_tril=jnp.eye(S)), obs=data)
667691

668692
def guide():
669-
rate = numpyro.param("rate", 1., constraint=constraints.positive)
670-
numpyro.sample("x", dist.Exponential(rate))
693+
loc = numpyro.param("loc", 0.)
694+
scale = numpyro.param("scale", 1., constraint=constraints.positive)
695+
numpyro.sample("x", dist.Normal(loc, scale))
671696

672697
svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
673698
svi_state = svi.init(random.PRNGKey(0))

0 commit comments

Comments
 (0)