@@ -113,15 +113,18 @@ def __call__(self, *args, **kwargs):
113113 raise NotImplementedError
114114
115115 @abstractmethod
116- def sample_posterior (self , rng_key , params , * args , ** kwargs ):
116+ def sample_posterior (self , rng_key , params , sample_shape = () ):
117117 """
118118 Generate samples from the approximate posterior over the latent
119119 sites in the model.
120120
121- :param jax.random.PRNGKey rng_key: PRNG seed.
122- :param params: Current parameters of model and autoguide.
123- :param sample_shape: (keyword argument) shape of samples to be drawn.
124- :return: batch of samples from the approximate posterior.
121+ :param jax.random.PRNGKey rng_key: random key to be used draw samples.
122+ :param dict params: Current parameters of model and autoguide.
123+ The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
124+ method from :class:`~numpyro.infer.svi.SVI`.
125+ :param tuple sample_shape: sample shape of each latent site, defaults to ().
126+ :return: a dict containing samples drawn the this guide.
127+ :rtype: dict
125128 """
126129 raise NotImplementedError
127130
@@ -161,14 +164,41 @@ def _setup_prototype(self, *args, **kwargs):
161164 elif site ["type" ] == "plate" :
162165 self ._prototype_frame_full_sizes [name ] = site ["args" ][0 ]
163166
167+ def median (self , params ):
168+ """
169+ Returns the posterior median value of each latent variable.
170+
171+ :param dict params: A dict containing parameter values.
172+ The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
173+ method from :class:`~numpyro.infer.svi.SVI`.
174+ :return: A dict mapping sample site name to median value.
175+ :rtype: dict
176+ """
177+ raise NotImplementedError
178+
179+ def quantiles (self , params , quantiles ):
180+ """
181+ Returns posterior quantiles each latent variable. Example::
182+
183+ print(guide.quantiles(params, [0.05, 0.5, 0.95]))
184+
185+ :param dict params: A dict containing parameter values.
186+ The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
187+ method from :class:`~numpyro.infer.svi.SVI`.
188+ :param list quantiles: A list of requested quantiles between 0 and 1.
189+ :return: A dict mapping sample site name to an array of quantile values.
190+ :rtype: dict
191+ """
192+ raise NotImplementedError
193+
164194
165195class AutoNormal (AutoGuide ):
166196 """
167197 This implementation of :class:`AutoGuide` uses Normal distributions
168198 to construct a guide over the entire latent space. The guide does not
169199 depend on the model's ``*args, **kwargs``.
170200
171- This should be equivalent to :class: `AutoDiagonalNormal` , but with
201+ This should be equivalent to :class:`AutoDiagonalNormal` , but with
172202 more convenient site names and with better support for mean field ELBO.
173203
174204 Usage::
@@ -231,12 +261,6 @@ def _setup_prototype(self, *args, **kwargs):
231261 )
232262
233263 def __call__ (self , * args , ** kwargs ):
234- """
235- An automatic guide with the same ``*args, **kwargs`` as the base ``model``.
236-
237- :return: A dict mapping sample site name to sampled value.
238- :rtype: dict
239- """
240264 if self .prototype_trace is None :
241265 # run model to inspect the model structure
242266 self ._setup_prototype (* args , ** kwargs )
@@ -314,10 +338,15 @@ def median(self, params):
314338 return self ._constrain (locs )
315339
316340 def quantiles (self , params , quantiles ):
317- quantiles = jnp .array (quantiles )[..., None ]
341+ quantiles = jnp .array (quantiles )
318342 locs = {k : params ["{}_{}_loc" .format (k , self .prefix )] for k in self ._init_locs }
319343 scales = {k : params ["{}_{}_scale" .format (k , self .prefix )] for k in locs }
320- latent = {k : dist .Normal (locs [k ], scales [k ]).icdf (quantiles ) for k in locs }
344+ latent = {
345+ k : dist .Normal (locs [k ], scales [k ]).icdf (
346+ quantiles .reshape ((- 1 ,) + (1 ,) * jnp .ndim (locs [k ]))
347+ )
348+ for k in locs
349+ }
321350 return self ._constrain (latent )
322351
323352
@@ -413,12 +442,6 @@ def sample_posterior(self, rng_key, params, sample_shape=()):
413442 return latent_samples
414443
415444 def median (self , params ):
416- """
417- Returns the posterior median value of each latent variable.
418-
419- :return: A dict mapping sample site name to median tensor.
420- :rtype: dict
421- """
422445 locs = {k : params ["{}_{}_loc" .format (k , self .prefix )] for k in self ._init_locs }
423446 return locs
424447
@@ -473,12 +496,6 @@ def _sample_latent(self, *args, **kwargs):
473496 )
474497
475498 def __call__ (self , * args , ** kwargs ):
476- """
477- An automatic guide with the same ``*args, **kwargs`` as the base ``model``.
478-
479- :return: A dict mapping sample site name to sampled value.
480- :rtype: dict
481- """
482499 if self .prototype_trace is None :
483500 # run model to inspect the model structure
484501 self ._setup_prototype (* args , ** kwargs )
@@ -585,49 +602,11 @@ def get_posterior(self, params):
585602 return dist .TransformedDistribution (base_dist , transform )
586603
587604 def sample_posterior (self , rng_key , params , sample_shape = ()):
588- """
589- Get samples from the learned posterior.
590-
591- :param jax.random.PRNGKey rng_key: random key to be used draw samples.
592- :param dict params: Current parameters of model and autoguide.
593- The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
594- method from :class:`~numpyro.infer.svi.SVI`.
595- :param tuple sample_shape: batch shape of each latent sample, defaults to ().
596- :return: a dict containing samples drawn the this guide.
597- :rtype: dict
598- """
599605 latent_sample = handlers .substitute (
600606 handlers .seed (self ._sample_latent , rng_key ), params
601607 )(sample_shape = sample_shape )
602608 return self ._unpack_and_constrain (latent_sample , params )
603609
604- def median (self , params ):
605- """
606- Returns the posterior median value of each latent variable.
607-
608- :param dict params: A dict containing parameter values.
609- The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
610- method from :class:`~numpyro.infer.svi.SVI`.
611- :return: A dict mapping sample site name to median tensor.
612- :rtype: dict
613- """
614- raise NotImplementedError
615-
616- def quantiles (self , params , quantiles ):
617- """
618- Returns posterior quantiles each latent variable. Example::
619-
620- print(guide.quantiles(opt_state, [0.05, 0.5, 0.95]))
621-
622- :param dict params: A dict containing parameter values.
623- The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
624- method from :class:`~numpyro.infer.svi.SVI`.
625- :param list quantiles: A list of requested quantiles between 0 and 1.
626- :return: A dict mapping sample site name to a list of quantile values.
627- :rtype: dict
628- """
629- raise NotImplementedError
630-
631610
632611class AutoDiagonalNormal (AutoContinuous ):
633612 """
0 commit comments