Skip to content

Commit 77be4d8

Browse files
authored
Fix AutoNormal.quantiles (#1066)
* fix AutoNormal.quantiles and add docs * force tfp release less than 0.14.0.dev * rearange docs * fix docs
1 parent 0d103ba commit 77be4d8

File tree

3 files changed

+56
-65
lines changed

3 files changed

+56
-65
lines changed

docs/source/autoguide.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@ Automatic Guide Generation
33

44
.. automodule:: numpyro.infer.autoguide
55

6+
AutoGuide
7+
---------
8+
.. autoclass:: numpyro.infer.autoguide.AutoGuide
9+
:members:
10+
:undoc-members:
11+
:show-inheritance:
12+
:member-order: bysource
13+
614
AutoContinuous
715
--------------
816
.. autoclass:: numpyro.infer.autoguide.AutoContinuous
@@ -73,4 +81,4 @@ AutoDelta
7381
:members:
7482
:undoc-members:
7583
:show-inheritance:
76-
:member-order: bysource
84+
:member-order: bysource

numpyro/infer/autoguide.py

Lines changed: 43 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,18 @@ def __call__(self, *args, **kwargs):
113113
raise NotImplementedError
114114

115115
@abstractmethod
116-
def sample_posterior(self, rng_key, params, *args, **kwargs):
116+
def sample_posterior(self, rng_key, params, sample_shape=()):
117117
"""
118118
Generate samples from the approximate posterior over the latent
119119
sites in the model.
120120
121-
:param jax.random.PRNGKey rng_key: PRNG seed.
122-
:param params: Current parameters of model and autoguide.
123-
:param sample_shape: (keyword argument) shape of samples to be drawn.
124-
:return: batch of samples from the approximate posterior.
121+
:param jax.random.PRNGKey rng_key: random key to be used draw samples.
122+
:param dict params: Current parameters of model and autoguide.
123+
The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
124+
method from :class:`~numpyro.infer.svi.SVI`.
125+
:param tuple sample_shape: sample shape of each latent site, defaults to ().
126+
:return: a dict containing samples drawn the this guide.
127+
:rtype: dict
125128
"""
126129
raise NotImplementedError
127130

@@ -161,14 +164,41 @@ def _setup_prototype(self, *args, **kwargs):
161164
elif site["type"] == "plate":
162165
self._prototype_frame_full_sizes[name] = site["args"][0]
163166

167+
def median(self, params):
168+
"""
169+
Returns the posterior median value of each latent variable.
170+
171+
:param dict params: A dict containing parameter values.
172+
The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
173+
method from :class:`~numpyro.infer.svi.SVI`.
174+
:return: A dict mapping sample site name to median value.
175+
:rtype: dict
176+
"""
177+
raise NotImplementedError
178+
179+
def quantiles(self, params, quantiles):
180+
"""
181+
Returns posterior quantiles each latent variable. Example::
182+
183+
print(guide.quantiles(params, [0.05, 0.5, 0.95]))
184+
185+
:param dict params: A dict containing parameter values.
186+
The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
187+
method from :class:`~numpyro.infer.svi.SVI`.
188+
:param list quantiles: A list of requested quantiles between 0 and 1.
189+
:return: A dict mapping sample site name to an array of quantile values.
190+
:rtype: dict
191+
"""
192+
raise NotImplementedError
193+
164194

165195
class AutoNormal(AutoGuide):
166196
"""
167197
This implementation of :class:`AutoGuide` uses Normal distributions
168198
to construct a guide over the entire latent space. The guide does not
169199
depend on the model's ``*args, **kwargs``.
170200
171-
This should be equivalent to :class: `AutoDiagonalNormal` , but with
201+
This should be equivalent to :class:`AutoDiagonalNormal` , but with
172202
more convenient site names and with better support for mean field ELBO.
173203
174204
Usage::
@@ -231,12 +261,6 @@ def _setup_prototype(self, *args, **kwargs):
231261
)
232262

233263
def __call__(self, *args, **kwargs):
234-
"""
235-
An automatic guide with the same ``*args, **kwargs`` as the base ``model``.
236-
237-
:return: A dict mapping sample site name to sampled value.
238-
:rtype: dict
239-
"""
240264
if self.prototype_trace is None:
241265
# run model to inspect the model structure
242266
self._setup_prototype(*args, **kwargs)
@@ -314,10 +338,15 @@ def median(self, params):
314338
return self._constrain(locs)
315339

316340
def quantiles(self, params, quantiles):
317-
quantiles = jnp.array(quantiles)[..., None]
341+
quantiles = jnp.array(quantiles)
318342
locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs}
319343
scales = {k: params["{}_{}_scale".format(k, self.prefix)] for k in locs}
320-
latent = {k: dist.Normal(locs[k], scales[k]).icdf(quantiles) for k in locs}
344+
latent = {
345+
k: dist.Normal(locs[k], scales[k]).icdf(
346+
quantiles.reshape((-1,) + (1,) * jnp.ndim(locs[k]))
347+
)
348+
for k in locs
349+
}
321350
return self._constrain(latent)
322351

323352

@@ -413,12 +442,6 @@ def sample_posterior(self, rng_key, params, sample_shape=()):
413442
return latent_samples
414443

415444
def median(self, params):
416-
"""
417-
Returns the posterior median value of each latent variable.
418-
419-
:return: A dict mapping sample site name to median tensor.
420-
:rtype: dict
421-
"""
422445
locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs}
423446
return locs
424447

@@ -473,12 +496,6 @@ def _sample_latent(self, *args, **kwargs):
473496
)
474497

475498
def __call__(self, *args, **kwargs):
476-
"""
477-
An automatic guide with the same ``*args, **kwargs`` as the base ``model``.
478-
479-
:return: A dict mapping sample site name to sampled value.
480-
:rtype: dict
481-
"""
482499
if self.prototype_trace is None:
483500
# run model to inspect the model structure
484501
self._setup_prototype(*args, **kwargs)
@@ -585,49 +602,11 @@ def get_posterior(self, params):
585602
return dist.TransformedDistribution(base_dist, transform)
586603

587604
def sample_posterior(self, rng_key, params, sample_shape=()):
588-
"""
589-
Get samples from the learned posterior.
590-
591-
:param jax.random.PRNGKey rng_key: random key to be used draw samples.
592-
:param dict params: Current parameters of model and autoguide.
593-
The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
594-
method from :class:`~numpyro.infer.svi.SVI`.
595-
:param tuple sample_shape: batch shape of each latent sample, defaults to ().
596-
:return: a dict containing samples drawn the this guide.
597-
:rtype: dict
598-
"""
599605
latent_sample = handlers.substitute(
600606
handlers.seed(self._sample_latent, rng_key), params
601607
)(sample_shape=sample_shape)
602608
return self._unpack_and_constrain(latent_sample, params)
603609

604-
def median(self, params):
605-
"""
606-
Returns the posterior median value of each latent variable.
607-
608-
:param dict params: A dict containing parameter values.
609-
The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
610-
method from :class:`~numpyro.infer.svi.SVI`.
611-
:return: A dict mapping sample site name to median tensor.
612-
:rtype: dict
613-
"""
614-
raise NotImplementedError
615-
616-
def quantiles(self, params, quantiles):
617-
"""
618-
Returns posterior quantiles each latent variable. Example::
619-
620-
print(guide.quantiles(opt_state, [0.05, 0.5, 0.95]))
621-
622-
:param dict params: A dict containing parameter values.
623-
The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
624-
method from :class:`~numpyro.infer.svi.SVI`.
625-
:param list quantiles: A list of requested quantiles between 0 and 1.
626-
:return: A dict mapping sample site name to a list of quantile values.
627-
:rtype: dict
628-
"""
629-
raise NotImplementedError
630-
631610

632611
class AutoDiagonalNormal(AutoContinuous):
633612
"""

test/infer/test_autoguide.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def body_fn(i, val):
8181
)
8282
assert_allclose(jnp.mean(posterior_samples["beta"], 0), true_coefs, atol=0.05)
8383

84+
if auto_class not in [AutoDelta, AutoIAFNormal, AutoBNAFNormal]:
85+
quantiles = guide.quantiles(params, [0.2, 0.5, 0.8])
86+
assert quantiles["beta"].shape == (3, 2)
87+
8488
# Predictive can be instantiated from posterior samples...
8589
predictive = Predictive(model, posterior_samples=posterior_samples)
8690
predictive_samples = predictive(random.PRNGKey(1), None)

0 commit comments

Comments
 (0)