@@ -339,6 +339,28 @@ def quantiles(self, params, quantiles):
339339 return result
340340
341341
342+ def _maybe_constrain_dist_for_site (
343+ site : dict , base_distribution : dist .Distribution
344+ ) -> dist .Distribution :
345+ support = site ["fn" ].support
346+
347+ # Short-circuit if the support is real and return the base distribution with the
348+ # correct number of event dimensions.
349+ base_support = support
350+ while isinstance (base_support , constraints .independent ):
351+ base_support = base_support .base_constraint
352+ if base_support is constraints .real :
353+ if support .event_dim :
354+ return base_distribution .to_event (support .event_dim )
355+ else :
356+ return base_distribution
357+
358+ # Transform the distribution to the support of the site.
359+ with helpful_support_errors (site ):
360+ transform = biject_to (support )
361+ return dist .TransformedDistribution (base_distribution , transform )
362+
363+
342364class AutoNormal (AutoGuide ):
343365 """
344366 This implementation of :class:`AutoGuide` uses Normal distributions
@@ -431,18 +453,11 @@ def __call__(self, *args, **kwargs):
431453 constraint = self .scale_constraint ,
432454 event_dim = event_dim ,
433455 )
434-
435- site_fn = dist .Normal (site_loc , site_scale ).to_event (event_dim )
436- if site ["fn" ].support is constraints .real or (
437- isinstance (site ["fn" ].support , constraints .independent )
438- and site ["fn" ].support .base_constraint is constraints .real
439- ):
440- result [name ] = numpyro .sample (name , site_fn )
441- else :
442- with helpful_support_errors (site ):
443- transform = biject_to (site ["fn" ].support )
444- guide_dist = dist .TransformedDistribution (site_fn , transform )
445- result [name ] = numpyro .sample (name , guide_dist )
456+ unconstrained_dist = dist .Normal (site_loc , site_scale )
457+ constrained_dist = _maybe_constrain_dist_for_site (
458+ site , unconstrained_dist
459+ )
460+ result [name ] = numpyro .sample (name , constrained_dist )
446461
447462 return result
448463
@@ -528,12 +543,6 @@ def __init__(
528543
529544 def _setup_prototype (self , * args , ** kwargs ):
530545 super ()._setup_prototype (* args , ** kwargs )
531- with numpyro .handlers .block ():
532- self ._init_locs = {
533- k : v
534- for k , v in self ._postprocess_fn (self ._init_locs ).items ()
535- if k in self ._init_locs
536- }
537546 for name , site in self .prototype_trace .items ():
538547 if site ["type" ] != "sample" or site ["is_observed" ]:
539548 continue
@@ -561,26 +570,22 @@ def __call__(self, *args, **kwargs):
561570 if site ["type" ] != "sample" or site ["is_observed" ]:
562571 continue
563572
564- event_dim = self ._event_dims [name ]
565573 init_loc = self ._init_locs [name ]
566574 with ExitStack () as stack :
567575 for frame in site ["cond_indep_stack" ]:
568576 stack .enter_context (plates [frame .name ])
569577
570- site_loc = numpyro .param (
571- "{}_{}_loc" .format (name , self .prefix ),
572- init_loc ,
573- constraint = site ["fn" ].support ,
574- event_dim = event_dim ,
578+ site_loc = numpyro .param (f"{ name } _{ self .prefix } _loc" , init_loc )
579+ unconstrained_dist = dist .Delta (site_loc )
580+ constrained_dist = _maybe_constrain_dist_for_site (
581+ site , unconstrained_dist
575582 )
576-
577- site_fn = dist .Delta (site_loc ).to_event (event_dim )
578- result [name ] = numpyro .sample (name , site_fn )
583+ result [name ] = numpyro .sample (name , constrained_dist )
579584
580585 return result
581586
582587 def sample_posterior (self , rng_key , params , * args , sample_shape = (), ** kwargs ):
583- locs = { k : params [ "{}_{}_loc" . format ( k , self .prefix )] for k in self . _init_locs }
588+ locs = self .median ( params )
584589 latent_samples = {
585590 k : jnp .broadcast_to (v , sample_shape + jnp .shape (v )) for k , v in locs .items ()
586591 }
@@ -600,7 +605,11 @@ def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
600605 return {** latent_samples , ** deterministic_samples }
601606
602607 def median (self , params ):
603- locs = {k : params ["{}_{}_loc" .format (k , self .prefix )] for k in self ._init_locs }
608+ locs = {}
609+ for name in self ._init_locs :
610+ unconstrained = params [f"{ name } _{ self .prefix } _loc" ]
611+ transform = biject_to (self .prototype_trace [name ]["fn" ].support )
612+ locs [name ] = transform (unconstrained )
604613 return locs
605614
606615
@@ -708,26 +717,11 @@ def __call__(self, *args, **kwargs):
708717
709718 # unpack continuous latent samples
710719 result = {}
711-
712720 for name , unconstrained_value in self ._unpack_latent (latent ).items ():
713721 site = self .prototype_trace [name ]
714- with helpful_support_errors (site ):
715- transform = biject_to (site ["fn" ].support )
716- value = transform (unconstrained_value )
717- event_ndim = site ["fn" ].event_dim
718- if numpyro .get_mask () is False :
719- log_density = 0.0
720- else :
721- log_density = - transform .log_abs_det_jacobian (
722- unconstrained_value , value
723- )
724- log_density = sum_rightmost (
725- log_density , jnp .ndim (log_density ) - jnp .ndim (value ) + event_ndim
726- )
727- delta_dist = dist .Delta (
728- value , log_density = log_density , event_dim = event_ndim
729- )
730- result [name ] = numpyro .sample (name , delta_dist )
722+ unconstrained_dist = dist .Delta (unconstrained_value )
723+ constrained_dist = _maybe_constrain_dist_for_site (site , unconstrained_dist )
724+ result [name ] = numpyro .sample (name , constrained_dist )
731725
732726 return result
733727
0 commit comments