Skip to content

Commit 8f8719f

Browse files
committed
address expanded distribution
1 parent 831db77 commit 8f8719f

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

numpyro/handlers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
import jax.numpy as jnp
8484

8585
import numpyro
86-
from numpyro.distributions.distribution import COERCIONS
86+
from numpyro.distributions.distribution import COERCIONS, ExpandedDistribution
8787
from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack, plate
8888
from numpyro.util import not_jax_tracer
8989

@@ -268,6 +268,8 @@ def process_message(self, msg):
268268
if msg["type"] == "sample":
269269
if msg["value"] is None:
270270
msg["value"] = msg["name"]
271+
if isinstance(msg["fn"], ExpandedDistribution):
272+
msg["fn"] = msg["fn"].base_dist
271273

272274
if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)):
273275
msg["stop"] = True

test/test_handlers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,6 @@ def model():
636636
x = numpyro.sample("x", dist.Normal(0, 1))
637637
with handlers.collapse():
638638
with handlers.plate("data", len(data)):
639-
# TODO: address expanded distribution
640639
y = numpyro.sample("y", dist.Normal(x, 1.))
641640
numpyro.sample("z", dist.Normal(y, 1.), obs=data)
642641

0 commit comments

Comments
 (0)