Skip to content

Commit a35e925

Browse files
authored
Add a warning for missing plate statements in the model (#1245)
* Fix missing plate statement in the model * Point users to format_shapes utility * Fix failing tests * Fix remaining failing tests * Fix further tests * Fix a bug at AutoNormal that makes TraceMeanField test fail
1 parent fa75d6d commit a35e925

File tree

7 files changed

+103
-46
lines changed

7 files changed

+103
-46
lines changed

notebooks/source/bad_posterior_geometry.ipynb

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,13 @@
141141
"Instead of \n",
142142
"\n",
143143
"$$ \\beta \\sim {\\rm Normal}(0, \\lambda \\tau) $$\n",
144+
"\n",
144145
"we write\n",
146+
"\n",
145147
"$$ \\beta^\\prime \\sim {\\rm Normal}(0, 1) $$\n",
148+
"\n",
146149
"and\n",
150+
"\n",
147151
"$$ \\beta \\equiv \\lambda \\tau \\beta^\\prime $$\n",
148152
"\n",
149153
"where $\\beta$ is now defined *deterministically* in terms of $\\lambda$, $\\tau$,\n",

numpyro/infer/autoguide.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def __call__(self, *args, **kwargs):
306306
site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
307307
if site["fn"].support is constraints.real or (
308308
isinstance(site["fn"].support, constraints.independent)
309-
and site["fn"].support is constraints.real
309+
and site["fn"].support.base_constraint is constraints.real
310310
):
311311
result[name] = numpyro.sample(name, site_fn)
312312
else:

numpyro/util.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -510,10 +510,8 @@ def model(*args, **kwargs):
510510

511511
def check_model_guide_match(model_trace, guide_trace):
512512
"""
513-
:param dict model_trace: The model trace to check.
514-
:param dict guide_trace: The guide trace to check.
515-
:raises: RuntimeWarning, ValueError
516513
Checks the following assumptions:
514+
517515
1. Each sample site in the model also appears in the guide and is not
518516
marked auxiliary.
519517
2. Each sample site in the guide either appears in the model or is marked,
@@ -522,6 +520,10 @@ def check_model_guide_match(model_trace, guide_trace):
522520
appears in the model.
523521
4. At each sample site that appears in both the model and guide, the model
524522
and guide agree on sample shape.
523+
524+
:param dict model_trace: The model trace to check.
525+
:param dict guide_trace: The guide trace to check.
526+
:raises: RuntimeWarning, ValueError
525527
"""
526528
# Check ordinary sample sites.
527529
guide_vars = set(
@@ -606,6 +608,28 @@ def check_model_guide_match(model_trace, guide_trace):
606608
)
607609
)
608610

611+
# Check if plate is missing in the model.
612+
for name, site in model_trace.items():
613+
if site["type"] == "sample":
614+
value_ndim = jnp.ndim(site["value"])
615+
batch_shape = lax.broadcast_shapes(
616+
site["fn"].batch_shape,
617+
jnp.shape(site["value"])[: value_ndim - len(site["fn"].event_shape)],
618+
)
619+
plate_dims = set(f.dim for f in site["cond_indep_stack"])
620+
batch_ndim = len(batch_shape)
621+
for i in range(batch_ndim):
622+
dim = -i - 1
623+
if batch_shape[dim] > 1 and (dim not in plate_dims):
624+
# Skip checking if it is the `scan` dimension.
625+
if dim == -batch_ndim and site.get("_control_flow_done", False):
626+
continue
627+
warnings.warn(
628+
f"Missing a plate statement for batch dimension {dim}"
629+
f" at site '{name}'. You can use `numpyro.util.format_shapes`"
630+
" utility to check shapes at all sites of your model."
631+
)
632+
609633

610634
def _format_table(rows):
611635
"""

test/contrib/test_optim.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def test_beta_bernoulli(elbo):
6969

7070
def model(data):
7171
f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
72-
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
72+
with numpyro.plate("N", len(data)):
73+
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
7374

7475
def guide(data):
7576
alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
@@ -100,7 +101,8 @@ def test_jitted_update_fn():
100101

101102
def model(data):
102103
f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
103-
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
104+
with numpyro.plate("N", len(data)):
105+
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
104106

105107
def guide(data):
106108
alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)

test/infer/test_autoguide.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,12 @@
6969
)
7070
def test_beta_bernoulli(auto_class):
7171
data = jnp.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T
72+
N = len(data)
7273

7374
def model(data):
74-
f = numpyro.sample("beta", dist.Beta(jnp.ones(2), jnp.ones(2)))
75-
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
75+
f = numpyro.sample("beta", dist.Beta(jnp.ones(2), jnp.ones(2)).to_event())
76+
with numpyro.plate("N", N):
77+
numpyro.sample("obs", dist.Bernoulli(f).to_event(1), obs=data)
7678

7779
adam = optim.Adam(0.01)
7880
if auto_class == AutoDAIS:
@@ -104,12 +106,12 @@ def body_fn(i, val):
104106
# Predictive can be instantiated from posterior samples...
105107
predictive = Predictive(model, posterior_samples=posterior_samples)
106108
predictive_samples = predictive(random.PRNGKey(1), None)
107-
assert predictive_samples["obs"].shape == (1000, 2)
109+
assert predictive_samples["obs"].shape == (1000, N, 2)
108110

109111
# ... or from the guide + params
110112
predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
111113
predictive_samples = predictive(random.PRNGKey(1), None)
112-
assert predictive_samples["obs"].shape == (1000, 2)
114+
assert predictive_samples["obs"].shape == (1000, N, 2)
113115

114116

115117
@pytest.mark.parametrize(
@@ -135,9 +137,10 @@ def test_logistic_regression(auto_class, Elbo):
135137
labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))
136138

137139
def model(data, labels):
138-
coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
140+
coefs = numpyro.sample("coefs", dist.Normal(0, 1).expand([dim]).to_event())
139141
logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
140-
return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
142+
with numpyro.plate("N", len(data)):
143+
return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
141144

142145
adam = optim.Adam(0.01)
143146
rng_key_init = random.PRNGKey(1)
@@ -242,7 +245,8 @@ def model(data):
242245
dist.Uniform(0, 1), transforms.AffineTransform(0, alpha)
243246
),
244247
)
245-
numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
248+
with numpyro.plate("N", len(data)):
249+
numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
246250

247251
adam = optim.Adam(0.01)
248252
rng_key_init = random.PRNGKey(1)
@@ -317,12 +321,14 @@ def actual_model(data):
317321
dist.Uniform(0, 1), transforms.AffineTransform(0, alpha)
318322
),
319323
)
320-
numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
324+
with numpyro.plate("N", len(data)):
325+
numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
321326

322327
def expected_model(data):
323328
alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
324329
loc = numpyro.sample("loc", dist.Uniform(0, 1)) * alpha
325-
numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
330+
with numpyro.plate("N", len(data)):
331+
numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
326332

327333
adam = optim.Adam(0.01)
328334
rng_key_init = random.PRNGKey(1)
@@ -355,9 +361,10 @@ def expected_model(data):
355361
def test_laplace_approximation_warning():
356362
def model(x, y):
357363
a = numpyro.sample("a", dist.Normal(0, 10))
358-
b = numpyro.sample("b", dist.Normal(0, 10), sample_shape=(3,))
364+
b = numpyro.sample("b", dist.Normal(0, 10).expand([3]).to_event())
359365
mu = a + b[0] * x + b[1] * x ** 2 + b[2] * x ** 3
360-
numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)
366+
with numpyro.plate("N", len(x)):
367+
numpyro.sample("y", dist.Normal(mu, 0.001), obs=y)
361368

362369
x = random.normal(random.PRNGKey(0), (3,))
363370
y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
@@ -375,7 +382,8 @@ def model(x, y):
375382
a = numpyro.sample("a", dist.Normal(0, 10))
376383
b = numpyro.sample("b", dist.Normal(0, 10))
377384
mu = a + b * x
378-
numpyro.sample("y", dist.Normal(mu, 1), obs=y)
385+
with numpyro.plate("N", len(x)):
386+
numpyro.sample("y", dist.Normal(mu, 1), obs=y)
379387

380388
x = random.normal(random.PRNGKey(0), (100,))
381389
y = 1 + 2 * x
@@ -401,7 +409,8 @@ def model(y):
401409
"sigma", dist.ImproperUniform(dist.constraints.positive, (), ())
402410
)
403411
mu = numpyro.deterministic("mu", lambda1 + lambda2)
404-
numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
412+
with numpyro.plate("N", len(y)):
413+
numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
405414

406415
guide = AutoDiagonalNormal(model)
407416
svi = SVI(model, guide, optim.Adam(0.003), Trace_ELBO(), y=y)
@@ -417,7 +426,8 @@ def model(x, y):
417426
nn = numpyro.module("nn", Dense(1), (10,))
418427
mu = nn(x).squeeze(-1)
419428
sigma = numpyro.sample("sigma", dist.HalfNormal(1))
420-
numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
429+
with numpyro.plate("N", len(y)):
430+
numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
421431

422432
guide = AutoDiagonalNormal(model)
423433
svi = SVI(model, guide, optim.Adam(0.003), Trace_ELBO(), x=x, y=y)
@@ -497,7 +507,8 @@ def model(y=None):
497507
mu = numpyro.sample("mu", dist.Normal(0, 5))
498508
sigma = numpyro.param("sigma", 1, constraint=constraints.positive)
499509

500-
y = numpyro.sample("y", dist.Normal(mu, sigma).expand((n,)), obs=y)
510+
with numpyro.plate("N", len(y)):
511+
y = numpyro.sample("y", dist.Normal(mu, sigma).expand((n,)), obs=y)
501512
numpyro.deterministic("z", (y - mu) / sigma)
502513

503514
mu, sigma = 2, 3

test/infer/test_svi.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def test_beta_bernoulli(elbo, optimizer):
6363

6464
def model(data):
6565
f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
66-
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
66+
with numpyro.plate("N", len(data)):
67+
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
6768

6869
def guide(data):
6970
alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
@@ -94,7 +95,8 @@ def test_run(progress_bar):
9495

9596
def model(data):
9697
f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
97-
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
98+
with numpyro.plate("N", len(data)):
99+
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
98100

99101
def guide(data):
100102
alpha_q = numpyro.param(
@@ -124,7 +126,8 @@ def test_jitted_update_fn():
124126

125127
def model(data):
126128
f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
127-
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
129+
with numpyro.plate("N", len(data)):
130+
numpyro.sample("obs", dist.Bernoulli(f), obs=data)
128131

129132
def guide(data):
130133
alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)

test/test_util.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -170,28 +170,27 @@ def model_test():
170170
)
171171

172172

173-
def test_check_model_guide_match():
174-
def _run_svi(model, guide):
173+
def _run_svi_check_warnings(model, guide, expected_string):
174+
with pytest.warns(UserWarning, match=expected_string) as ws:
175175
adam = numpyro.optim.Adam(1e-3)
176176
svi = numpyro.infer.SVI(model, guide, adam, numpyro.infer.Trace_ELBO())
177-
svi.run(random.PRNGKey(42), num_steps=50)
178-
179-
def _run_svi_check_warnings(model, guide, expected_string):
180-
with pytest.warns(UserWarning, match=expected_string) as ws:
181-
_run_svi(model, guide)
182-
assert len(ws) == 1
183-
assert expected_string in str(ws[0].message)
184-
185-
def _create_traces_check_error_string(model, guide, expected_string):
186-
model_trace = numpyro.handlers.trace(
187-
numpyro.handlers.seed(model, rng_seed=42)
188-
).get_trace()
189-
guide_trace = numpyro.handlers.trace(
190-
numpyro.handlers.seed(guide, rng_seed=42)
191-
).get_trace()
192-
with pytest.raises(ValueError, match=expected_string):
193-
check_model_guide_match(model_trace, guide_trace)
177+
svi.run(random.PRNGKey(42), num_steps=5)
178+
assert len(ws) == 1
179+
assert expected_string in str(ws[0].message)
180+
181+
182+
def _create_traces_check_error_string(model, guide, expected_string):
183+
model_trace = numpyro.handlers.trace(
184+
numpyro.handlers.seed(model, rng_seed=42)
185+
).get_trace()
186+
guide_trace = numpyro.handlers.trace(
187+
numpyro.handlers.seed(guide, rng_seed=42)
188+
).get_trace()
189+
with pytest.raises(ValueError, match=expected_string):
190+
check_model_guide_match(model_trace, guide_trace)
194191

192+
193+
def test_check_model_guide_match():
195194
# 1. Auxiliary vars in the model
196195
def model():
197196
numpyro.sample("x", dist.Normal())
@@ -236,7 +235,9 @@ def guide():
236235

237236
# 5. Check shapes agree
238237
def model():
239-
numpyro.sample("x", dist.Normal().expand((3, 2)))
238+
with numpyro.plate("a", 3, dim=-2):
239+
with numpyro.plate("b", 2, dim=-1):
240+
numpyro.sample("x", dist.Normal().expand((3, 2)))
240241

241242
def guide():
242243
numpyro.sample("x", dist.Normal().expand((3, 5)))
@@ -245,12 +246,24 @@ def guide():
245246

246247
# 6. Check subsample sites introduced by plate
247248
def model():
248-
numpyro.sample("x", dist.Normal().expand((10,)))
249+
with numpyro.plate("a", 10):
250+
numpyro.sample("x", dist.Normal().expand((10,)))
249251

250252
def guide():
251-
with numpyro.handlers.plate("data", 100, subsample_size=10):
253+
with numpyro.plate("data", 100, subsample_size=10):
252254
numpyro.sample("x", dist.Normal())
253255

254256
_run_svi_check_warnings(
255257
model, guide, "Found plate statements in guide but not model"
256258
)
259+
260+
261+
def test_missing_plate_in_model():
262+
def model():
263+
x = numpyro.sample("x", dist.Normal(0, 1))
264+
numpyro.sample("obs", dist.Normal(x, 1), obs=jnp.ones(10))
265+
266+
def guide():
267+
numpyro.sample("x", dist.Normal(0, 1))
268+
269+
_run_svi_check_warnings(model, guide, "Missing a plate statement")

0 commit comments

Comments
 (0)