18
18
from numpyro .handlers import replay , seed , substitute , trace
19
19
from numpyro .infer .util import (
20
20
_without_rsample_stop_gradient ,
21
+ compute_log_probs ,
21
22
get_importance_trace ,
22
23
is_identically_one ,
23
- log_density ,
24
24
)
25
25
from numpyro .ops .provenance import eval_provenance
26
26
from numpyro .util import _validate_model , check_model_guide_match , find_stack_level
@@ -148,12 +148,19 @@ class Trace_ELBO(ELBO):
148
148
strategy, for example `jax.pmap`.
149
149
:param multi_sample_guide: Whether to make an assumption that the guide proposes
150
150
multiple samples.
151
+ :param sum_sites: Whether to sum the ELBO contributions from all sites or return the
152
+ contributions as a dictionary keyed by site.
151
153
"""
152
154
153
155
def __init__ (
154
- self , num_particles = 1 , vectorize_particles = True , multi_sample_guide = False
156
+ self ,
157
+ num_particles : int = 1 ,
158
+ vectorize_particles : bool = True ,
159
+ multi_sample_guide : bool = False ,
160
+ sum_sites : bool = True ,
155
161
):
156
162
self .multi_sample_guide = multi_sample_guide
163
+ self .sum_sites = sum_sites
157
164
super ().__init__ (
158
165
num_particles = num_particles , vectorize_particles = vectorize_particles
159
166
)
@@ -171,7 +178,7 @@ def single_particle_elbo(rng_key):
171
178
params = param_map .copy ()
172
179
model_seed , guide_seed = random .split (rng_key )
173
180
seeded_guide = seed (guide , guide_seed )
174
- guide_log_density , guide_trace = log_density (
181
+ guide_log_probs , guide_trace = compute_log_probs (
175
182
seeded_guide , args , kwargs , param_map
176
183
)
177
184
mutable_params = {
@@ -187,13 +194,13 @@ def single_particle_elbo(rng_key):
187
194
if site ["type" ] == "plate"
188
195
}
189
196
190
- def get_model_density (key , latent ):
197
+ def compute_model_log_probs (key , latent ):
191
198
with seed (rng_seed = key ), substitute (data = {** latent , ** plates }):
192
- model_log_density , model_trace = log_density (
199
+ model_log_probs , model_trace = compute_log_probs (
193
200
model , args , kwargs , params
194
201
)
195
202
_validate_model (model_trace , plate_warning = "loose" )
196
- return model_log_density
203
+ return model_log_probs
197
204
198
205
num_guide_samples = None
199
206
for site in guide_trace .values ():
@@ -209,15 +216,14 @@ def get_model_density(key, latent):
209
216
if (site ["type" ] == "sample" and site ["value" ].size > 0 )
210
217
or (site ["type" ] == "deterministic" )
211
218
}
212
- model_log_density = vmap (get_model_density )(seeds , latents )
213
- assert model_log_density .ndim == 1
214
- model_log_density = model_log_density .sum (0 )
215
- # log p(z) - log q(z)
216
- elbo_particle = (model_log_density - guide_log_density ) / seeds .shape [0 ]
219
+ model_log_probs = vmap (compute_model_log_probs )(seeds , latents )
220
+ model_log_probs = jax .tree .map (
221
+ lambda x : jnp .sum (x , axis = 0 ), model_log_probs
222
+ )
217
223
else :
218
224
seeded_model = seed (model , model_seed )
219
225
replay_model = replay (seeded_model , guide_trace )
220
- model_log_density , model_trace = log_density (
226
+ model_log_probs , model_trace = compute_log_probs (
221
227
replay_model , args , kwargs , params
222
228
)
223
229
check_model_guide_match (model_trace , guide_trace )
@@ -229,31 +235,43 @@ def get_model_density(key, latent):
229
235
if site ["type" ] == "mutable"
230
236
}
231
237
)
232
- # log p(z) - log q(z)
233
- elbo_particle = model_log_density - guide_log_density
238
+
239
+ # log p(z) - log q(z). We cannot use jax.tree.map(jnp.subtract, ...) because
240
+ # there may be observed sites in `model_log_probs` that are not in
241
+ # `guide_log_probs` and vice versa.
242
+ union = set (model_log_probs ).union (guide_log_probs )
243
+ elbo_particle = {
244
+ name : model_log_probs .get (name , 0.0 ) - guide_log_probs .get (name , 0.0 )
245
+ for name in union
246
+ }
247
+ if self .sum_sites :
248
+ elbo_particle = sum (elbo_particle .values (), start = 0.0 )
234
249
235
250
if mutable_params :
236
251
if self .num_particles == 1 :
237
252
return elbo_particle , mutable_params
238
- else :
239
- warnings .warn (
240
- "mutable state is currently ignored when num_particles > 1."
241
- )
242
- return elbo_particle , None
243
- else :
244
- return elbo_particle , None
253
+ warnings .warn (
254
+ "mutable state is currently ignored when num_particles > 1."
255
+ )
256
+ return elbo_particle , None
245
257
246
258
# Return (-elbo) since by convention we do gradient descent on a loss and
247
259
# the ELBO is a lower bound that needs to be maximized.
248
260
if self .num_particles == 1 :
249
261
elbo , mutable_state = single_particle_elbo (rng_key )
250
- return {"loss" : - elbo , "mutable_state" : mutable_state }
262
+ return {
263
+ "loss" : jax .tree .map (jnp .negative , elbo ),
264
+ "mutable_state" : mutable_state ,
265
+ }
251
266
else :
252
267
rng_keys = random .split (rng_key , self .num_particles )
253
268
elbos , mutable_state = self .vectorize_particles_fn (
254
269
single_particle_elbo , rng_keys
255
270
)
256
- return {"loss" : - jnp .mean (elbos ), "mutable_state" : mutable_state }
271
+ return {
272
+ "loss" : jax .tree .map (lambda x : - jnp .mean (x ), elbos ),
273
+ "mutable_state" : mutable_state ,
274
+ }
257
275
258
276
259
277
def _get_log_prob_sum (site ):
@@ -282,17 +300,15 @@ def _check_mean_field_requirement(model_trace, guide_trace):
282
300
]
283
301
assert set (model_sites ) == set (guide_sites )
284
302
if model_sites != guide_sites :
285
- (
286
- warnings .warn (
287
- "Failed to verify mean field restriction on the guide. "
288
- "To eliminate this warning, ensure model and guide sites "
289
- "occur in the same order.\n "
290
- + "Model sites:\n "
291
- + "\n " .join (model_sites )
292
- + "Guide sites:\n "
293
- + "\n " .join (guide_sites ),
294
- stacklevel = find_stack_level (),
295
- ),
303
+ warnings .warn (
304
+ "Failed to verify mean field restriction on the guide. "
305
+ "To eliminate this warning, ensure model and guide sites "
306
+ "occur in the same order.\n "
307
+ + "Model sites:\n "
308
+ + "\n " .join (model_sites )
309
+ + "\n Guide sites:\n "
310
+ + "\n " .join (guide_sites ),
311
+ stacklevel = find_stack_level (),
296
312
)
297
313
298
314
@@ -302,6 +318,15 @@ class TraceMeanField_ELBO(ELBO):
302
318
ELBO estimator in NumPyro that uses analytic KL divergences when those
303
319
are available.
304
320
321
+ :param num_particles: The number of particles/samples used to form the ELBO
322
+ (gradient) estimators.
323
+ :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
324
+ num_particles-many particles in parallel. If False use `jax.lax.map`.
325
+ Defaults to True. You can also pass a callable to specify a custom vectorization
326
+ strategy, for example `jax.pmap`.
327
+ :param sum_sites: Whether to sum the ELBO contributions from all sites or return the
328
+ contributions as a dictionary keyed by site.
329
+
305
330
.. warning:: This estimator may give incorrect results if the mean-field
306
331
condition is not satisfied.
307
332
The mean field condition is a sufficient but not necessary condition for
@@ -314,6 +339,15 @@ class TraceMeanField_ELBO(ELBO):
314
339
dependency structures.
315
340
"""
316
341
342
+ def __init__ (
343
+ self ,
344
+ num_particles : int = 1 ,
345
+ vectorize_particles : bool = True ,
346
+ sum_sites : bool = True ,
347
+ ) -> None :
348
+ self .sum_sites = sum_sites
349
+ super ().__init__ (num_particles , vectorize_particles )
350
+
317
351
def loss_with_mutable_state (
318
352
self , rng_key , param_map , model , guide , * args , ** kwargs
319
353
):
@@ -343,50 +377,54 @@ def single_particle_elbo(rng_key):
343
377
_validate_model (model_trace , plate_warning = "loose" )
344
378
_check_mean_field_requirement (model_trace , guide_trace )
345
379
346
- elbo_particle = 0
380
+ elbo_particle = {}
347
381
for name , model_site in model_trace .items ():
348
382
if model_site ["type" ] == "sample" :
349
383
if model_site ["is_observed" ]:
350
- elbo_particle = elbo_particle + _get_log_prob_sum (model_site )
384
+ elbo_particle [ name ] = _get_log_prob_sum (model_site )
351
385
else :
352
386
guide_site = guide_trace [name ]
353
387
try :
354
388
kl_qp = kl_divergence (guide_site ["fn" ], model_site ["fn" ])
355
389
kl_qp = scale_and_mask (kl_qp , scale = guide_site ["scale" ])
356
- elbo_particle = elbo_particle - jnp .sum (kl_qp )
390
+ elbo_particle [ name ] = - jnp .sum (kl_qp )
357
391
except NotImplementedError :
358
- elbo_particle = (
359
- elbo_particle
360
- + _get_log_prob_sum (model_site )
361
- - _get_log_prob_sum (guide_site )
362
- )
392
+ elbo_particle [name ] = _get_log_prob_sum (
393
+ model_site
394
+ ) - _get_log_prob_sum (guide_site )
363
395
364
396
# handle auxiliary sites in the guide
365
397
for name , site in guide_trace .items ():
366
398
if site ["type" ] == "sample" and name not in model_trace :
367
399
assert site ["infer" ].get ("is_auxiliary" ) or site ["is_observed" ]
368
- elbo_particle = elbo_particle - _get_log_prob_sum (site )
400
+ elbo_particle [name ] = - _get_log_prob_sum (site )
401
+
402
+ if self .sum_sites :
403
+ elbo_particle = sum (elbo_particle .values (), start = 0.0 )
369
404
370
405
if mutable_params :
371
406
if self .num_particles == 1 :
372
407
return elbo_particle , mutable_params
373
- else :
374
- warnings .warn (
375
- "mutable state is currently ignored when num_particles > 1."
376
- )
377
- return elbo_particle , None
378
- else :
379
- return elbo_particle , None
408
+ warnings .warn (
409
+ "mutable state is currently ignored when num_particles > 1."
410
+ )
411
+ return elbo_particle , None
380
412
381
413
if self .num_particles == 1 :
382
414
elbo , mutable_state = single_particle_elbo (rng_key )
383
- return {"loss" : - elbo , "mutable_state" : mutable_state }
415
+ return {
416
+ "loss" : jax .tree .map (jnp .negative , elbo ),
417
+ "mutable_state" : mutable_state ,
418
+ }
384
419
else :
385
420
rng_keys = random .split (rng_key , self .num_particles )
386
421
elbos , mutable_state = self .vectorize_particles_fn (
387
422
single_particle_elbo , rng_keys
388
423
)
389
- return {"loss" : - jnp .mean (elbos ), "mutable_state" : mutable_state }
424
+ return {
425
+ "loss" : jax .tree .map (lambda x : - jnp .mean (x ), elbos ),
426
+ "mutable_state" : mutable_state ,
427
+ }
390
428
391
429
392
430
class RenyiELBO (ELBO ):
0 commit comments