Skip to content

Commit 05d37d6

Browse files
authored
Check plate consistent in auto guide (#1049)
* check plate consistent and add regression test * make format
1 parent 4876f87 commit 05d37d6

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

numpyro/infer/autoguide.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,12 @@ def _setup_prototype(self, *args, **kwargs):
148148
for name, site in self.prototype_trace.items():
149149
if site["type"] == "sample":
150150
for frame in site["cond_indep_stack"]:
151-
self._prototype_frames[frame.name] = frame
151+
if frame.name in self._prototype_frames:
152+
assert (
153+
frame == self._prototype_frames[frame.name]
154+
), f"The plate {frame.name} has inconsistent dim or size. Please check your model again."
155+
else:
156+
self._prototype_frames[frame.name] = frame
152157
elif site["type"] == "plate":
153158
self._prototype_frame_full_sizes[name] = site["args"][0]
154159

test/infer/test_autoguide.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,3 +482,17 @@ def model(y=None):
482482
predictive_samples["z"],
483483
atol=0.05,
484484
)
485+
486+
487+
@pytest.mark.parametrize("size,dim", [(10, -2), (5, -1)])
488+
def test_plate_inconsistent(size, dim):
489+
def model():
490+
with numpyro.plate("a", 10, dim=-1):
491+
numpyro.sample("x", dist.Normal(0, 1))
492+
with numpyro.plate("a", size, dim=dim):
493+
numpyro.sample("y", dist.Normal(0, 1))
494+
495+
guide = AutoDelta(model)
496+
svi = SVI(model, guide, numpyro.optim.Adam(step_size=0.1), Trace_ELBO())
497+
with pytest.raises(AssertionError, match="has inconsistent dim or size"):
498+
svi.run(random.PRNGKey(0), 10)

0 commit comments

Comments
 (0)