Skip to content

Commit 4b7d233

Browse files
Add sum_sites option to sum loss over sites or return as dict. (#1995)
* Add `sum_sites` option to sum loss over sites or return as `dict`.
* Add missing newline in warning.
1 parent 2a46030 commit 4b7d233

File tree

2 files changed

+130
-52
lines changed

2 files changed

+130
-52
lines changed

numpyro/infer/elbo.py

Lines changed: 90 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,9 @@
1818
from numpyro.handlers import replay, seed, substitute, trace
1919
from numpyro.infer.util import (
2020
_without_rsample_stop_gradient,
21+
compute_log_probs,
2122
get_importance_trace,
2223
is_identically_one,
23-
log_density,
2424
)
2525
from numpyro.ops.provenance import eval_provenance
2626
from numpyro.util import _validate_model, check_model_guide_match, find_stack_level
@@ -148,12 +148,19 @@ class Trace_ELBO(ELBO):
148148
strategy, for example `jax.pmap`.
149149
:param multi_sample_guide: Whether to make an assumption that the guide proposes
150150
multiple samples.
151+
:param sum_sites: Whether to sum the ELBO contributions from all sites or return the
152+
contributions as a dictionary keyed by site.
151153
"""
152154

153155
def __init__(
154-
self, num_particles=1, vectorize_particles=True, multi_sample_guide=False
156+
self,
157+
num_particles: int = 1,
158+
vectorize_particles: bool = True,
159+
multi_sample_guide: bool = False,
160+
sum_sites: bool = True,
155161
):
156162
self.multi_sample_guide = multi_sample_guide
163+
self.sum_sites = sum_sites
157164
super().__init__(
158165
num_particles=num_particles, vectorize_particles=vectorize_particles
159166
)
@@ -171,7 +178,7 @@ def single_particle_elbo(rng_key):
171178
params = param_map.copy()
172179
model_seed, guide_seed = random.split(rng_key)
173180
seeded_guide = seed(guide, guide_seed)
174-
guide_log_density, guide_trace = log_density(
181+
guide_log_probs, guide_trace = compute_log_probs(
175182
seeded_guide, args, kwargs, param_map
176183
)
177184
mutable_params = {
@@ -187,13 +194,13 @@ def single_particle_elbo(rng_key):
187194
if site["type"] == "plate"
188195
}
189196

190-
def get_model_density(key, latent):
197+
def compute_model_log_probs(key, latent):
191198
with seed(rng_seed=key), substitute(data={**latent, **plates}):
192-
model_log_density, model_trace = log_density(
199+
model_log_probs, model_trace = compute_log_probs(
193200
model, args, kwargs, params
194201
)
195202
_validate_model(model_trace, plate_warning="loose")
196-
return model_log_density
203+
return model_log_probs
197204

198205
num_guide_samples = None
199206
for site in guide_trace.values():
@@ -209,15 +216,14 @@ def get_model_density(key, latent):
209216
if (site["type"] == "sample" and site["value"].size > 0)
210217
or (site["type"] == "deterministic")
211218
}
212-
model_log_density = vmap(get_model_density)(seeds, latents)
213-
assert model_log_density.ndim == 1
214-
model_log_density = model_log_density.sum(0)
215-
# log p(z) - log q(z)
216-
elbo_particle = (model_log_density - guide_log_density) / seeds.shape[0]
219+
model_log_probs = vmap(compute_model_log_probs)(seeds, latents)
220+
model_log_probs = jax.tree.map(
221+
lambda x: jnp.sum(x, axis=0), model_log_probs
222+
)
217223
else:
218224
seeded_model = seed(model, model_seed)
219225
replay_model = replay(seeded_model, guide_trace)
220-
model_log_density, model_trace = log_density(
226+
model_log_probs, model_trace = compute_log_probs(
221227
replay_model, args, kwargs, params
222228
)
223229
check_model_guide_match(model_trace, guide_trace)
@@ -229,31 +235,43 @@ def get_model_density(key, latent):
229235
if site["type"] == "mutable"
230236
}
231237
)
232-
# log p(z) - log q(z)
233-
elbo_particle = model_log_density - guide_log_density
238+
239+
# log p(z) - log q(z). We cannot use jax.tree.map(jnp.subtract, ...) because
240+
# there may be observed sites in `model_log_probs` that are not in
241+
# `guide_log_probs` and vice versa.
242+
union = set(model_log_probs).union(guide_log_probs)
243+
elbo_particle = {
244+
name: model_log_probs.get(name, 0.0) - guide_log_probs.get(name, 0.0)
245+
for name in union
246+
}
247+
if self.sum_sites:
248+
elbo_particle = sum(elbo_particle.values(), start=0.0)
234249

235250
if mutable_params:
236251
if self.num_particles == 1:
237252
return elbo_particle, mutable_params
238-
else:
239-
warnings.warn(
240-
"mutable state is currently ignored when num_particles > 1."
241-
)
242-
return elbo_particle, None
243-
else:
244-
return elbo_particle, None
253+
warnings.warn(
254+
"mutable state is currently ignored when num_particles > 1."
255+
)
256+
return elbo_particle, None
245257

246258
# Return (-elbo) since by convention we do gradient descent on a loss and
247259
# the ELBO is a lower bound that needs to be maximized.
248260
if self.num_particles == 1:
249261
elbo, mutable_state = single_particle_elbo(rng_key)
250-
return {"loss": -elbo, "mutable_state": mutable_state}
262+
return {
263+
"loss": jax.tree.map(jnp.negative, elbo),
264+
"mutable_state": mutable_state,
265+
}
251266
else:
252267
rng_keys = random.split(rng_key, self.num_particles)
253268
elbos, mutable_state = self.vectorize_particles_fn(
254269
single_particle_elbo, rng_keys
255270
)
256-
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}
271+
return {
272+
"loss": jax.tree.map(lambda x: -jnp.mean(x), elbos),
273+
"mutable_state": mutable_state,
274+
}
257275

258276

259277
def _get_log_prob_sum(site):
@@ -282,17 +300,15 @@ def _check_mean_field_requirement(model_trace, guide_trace):
282300
]
283301
assert set(model_sites) == set(guide_sites)
284302
if model_sites != guide_sites:
285-
(
286-
warnings.warn(
287-
"Failed to verify mean field restriction on the guide. "
288-
"To eliminate this warning, ensure model and guide sites "
289-
"occur in the same order.\n"
290-
+ "Model sites:\n "
291-
+ "\n ".join(model_sites)
292-
+ "Guide sites:\n "
293-
+ "\n ".join(guide_sites),
294-
stacklevel=find_stack_level(),
295-
),
303+
warnings.warn(
304+
"Failed to verify mean field restriction on the guide. "
305+
"To eliminate this warning, ensure model and guide sites "
306+
"occur in the same order.\n"
307+
+ "Model sites:\n "
308+
+ "\n ".join(model_sites)
309+
+ "\nGuide sites:\n "
310+
+ "\n ".join(guide_sites),
311+
stacklevel=find_stack_level(),
296312
)
297313

298314

@@ -302,6 +318,15 @@ class TraceMeanField_ELBO(ELBO):
302318
ELBO estimator in NumPyro that uses analytic KL divergences when those
303319
are available.
304320
321+
:param num_particles: The number of particles/samples used to form the ELBO
322+
(gradient) estimators.
323+
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
324+
num_particles-many particles in parallel. If False use `jax.lax.map`.
325+
Defaults to True. You can also pass a callable to specify a custom vectorization
326+
strategy, for example `jax.pmap`.
327+
:param sum_sites: Whether to sum the ELBO contributions from all sites or return the
328+
contributions as a dictionary keyed by site.
329+
305330
.. warning:: This estimator may give incorrect results if the mean-field
306331
condition is not satisfied.
307332
The mean field condition is a sufficient but not necessary condition for
@@ -314,6 +339,15 @@ class TraceMeanField_ELBO(ELBO):
314339
dependency structures.
315340
"""
316341

342+
def __init__(
343+
self,
344+
num_particles: int = 1,
345+
vectorize_particles: bool = True,
346+
sum_sites: bool = True,
347+
) -> None:
348+
self.sum_sites = sum_sites
349+
super().__init__(num_particles, vectorize_particles)
350+
317351
def loss_with_mutable_state(
318352
self, rng_key, param_map, model, guide, *args, **kwargs
319353
):
@@ -343,50 +377,54 @@ def single_particle_elbo(rng_key):
343377
_validate_model(model_trace, plate_warning="loose")
344378
_check_mean_field_requirement(model_trace, guide_trace)
345379

346-
elbo_particle = 0
380+
elbo_particle = {}
347381
for name, model_site in model_trace.items():
348382
if model_site["type"] == "sample":
349383
if model_site["is_observed"]:
350-
elbo_particle = elbo_particle + _get_log_prob_sum(model_site)
384+
elbo_particle[name] = _get_log_prob_sum(model_site)
351385
else:
352386
guide_site = guide_trace[name]
353387
try:
354388
kl_qp = kl_divergence(guide_site["fn"], model_site["fn"])
355389
kl_qp = scale_and_mask(kl_qp, scale=guide_site["scale"])
356-
elbo_particle = elbo_particle - jnp.sum(kl_qp)
390+
elbo_particle[name] = -jnp.sum(kl_qp)
357391
except NotImplementedError:
358-
elbo_particle = (
359-
elbo_particle
360-
+ _get_log_prob_sum(model_site)
361-
- _get_log_prob_sum(guide_site)
362-
)
392+
elbo_particle[name] = _get_log_prob_sum(
393+
model_site
394+
) - _get_log_prob_sum(guide_site)
363395

364396
# handle auxiliary sites in the guide
365397
for name, site in guide_trace.items():
366398
if site["type"] == "sample" and name not in model_trace:
367399
assert site["infer"].get("is_auxiliary") or site["is_observed"]
368-
elbo_particle = elbo_particle - _get_log_prob_sum(site)
400+
elbo_particle[name] = -_get_log_prob_sum(site)
401+
402+
if self.sum_sites:
403+
elbo_particle = sum(elbo_particle.values(), start=0.0)
369404

370405
if mutable_params:
371406
if self.num_particles == 1:
372407
return elbo_particle, mutable_params
373-
else:
374-
warnings.warn(
375-
"mutable state is currently ignored when num_particles > 1."
376-
)
377-
return elbo_particle, None
378-
else:
379-
return elbo_particle, None
408+
warnings.warn(
409+
"mutable state is currently ignored when num_particles > 1."
410+
)
411+
return elbo_particle, None
380412

381413
if self.num_particles == 1:
382414
elbo, mutable_state = single_particle_elbo(rng_key)
383-
return {"loss": -elbo, "mutable_state": mutable_state}
415+
return {
416+
"loss": jax.tree.map(jnp.negative, elbo),
417+
"mutable_state": mutable_state,
418+
}
384419
else:
385420
rng_keys = random.split(rng_key, self.num_particles)
386421
elbos, mutable_state = self.vectorize_particles_fn(
387422
single_particle_elbo, rng_keys
388423
)
389-
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}
424+
return {
425+
"loss": jax.tree.map(lambda x: -jnp.mean(x), elbos),
426+
"mutable_state": mutable_state,
427+
}
390428

391429

392430
class RenyiELBO(ELBO):

test/infer/test_svi.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,46 @@ def guide():
476476
svi.run(random.PRNGKey(0), 10)
477477

478478

479+
@pytest.mark.parametrize("loss_cls", [Trace_ELBO, TraceMeanField_ELBO])
480+
@pytest.mark.parametrize("sum_sites", [False, True])
481+
@pytest.mark.parametrize("num_particles", [1, 3])
482+
@pytest.mark.parametrize(
483+
"with_mutable", [False, True], ids=["with_mutable", "without_mutable"]
484+
)
485+
def test_elbo_by_site(loss_cls, sum_sites, num_particles, with_mutable):
486+
if num_particles > 1 and with_mutable:
487+
pytest.skip("Mutable state is currently ignored when num_particles > 1.")
488+
489+
def model():
490+
x = numpyro.sample("x", dist.Normal(-1, 1))
491+
numpyro.sample("y", dist.Gamma(3))
492+
493+
if with_mutable:
494+
numpyro_mutable("x1p", x + 1)
495+
496+
numpyro.sample("z", dist.Normal(x, 2), obs=5)
497+
498+
def guide():
499+
x = numpyro.sample("x", dist.Normal(2, 2))
500+
numpyro.sample("y", dist.LogNormal(0.1, 0.4))
501+
502+
if with_mutable:
503+
p = numpyro_mutable("x1p", {"value": None})
504+
p["value"] = x + 2
505+
506+
loss = loss_cls(num_particles=num_particles, sum_sites=sum_sites)
507+
key = random.key(9)
508+
value = loss.loss(key, {}, model, guide)
509+
if sum_sites:
510+
assert value.ndim == 0
511+
else:
512+
assert isinstance(value, dict) and set(value) == {"x", "y", "z"}
513+
total = sum(value.values())
514+
assert_allclose(
515+
total, loss_cls(num_particles).loss(key, {}, model, guide), rtol=1e-6
516+
)
517+
518+
479519
@pytest.mark.parametrize("stable_update", [True, False])
480520
@pytest.mark.parametrize("num_particles", [1, 10])
481521
@pytest.mark.parametrize("elbo", [Trace_ELBO, TraceMeanField_ELBO])

0 commit comments

Comments
 (0)