Skip to content

Commit af68b5f

Browse files
authored
Unify sample_posterior() signatures (#1979)
1 parent fa2ecb3 commit af68b5f

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

numpyro/infer/autoguide.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __call__(self, *args, **kwargs):
132132
raise NotImplementedError
133133

134134
@abstractmethod
135-
def sample_posterior(self, rng_key, params, *, sample_shape=()):
135+
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
136136
"""
137137
Generate samples from the approximate posterior over the latent
138138
sites in the model.
@@ -141,7 +141,9 @@ def sample_posterior(self, rng_key, params, *, sample_shape=()):
141141
:param dict params: Current parameters of model and autoguide.
142142
The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
143143
method from :class:`~numpyro.infer.svi.SVI`.
144+
:param args: Arguments to be provided to the model / guide.
144145
:param tuple sample_shape: sample shape of each latent site, defaults to ().
146+
:param kwargs: Keyword arguments to be provided to the model / guide.
145147
:return: a dict containing samples drawn the this guide.
146148
:rtype: dict
147149
"""
@@ -317,18 +319,11 @@ def __iter__(self):
317319
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
318320
result = {}
319321
for part in self._guides:
320-
# TODO: remove this when sample_posterior() signatures are consistent
321-
# we know part is not AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS
322-
if isinstance(part, numpyro.infer.autoguide.AutoDelta):
323-
result.update(
324-
part.sample_posterior(
325-
rng_key, params, *args, sample_shape=sample_shape, **kwargs
326-
)
327-
)
328-
else:
329-
result.update(
330-
part.sample_posterior(rng_key, params, sample_shape=sample_shape)
322+
result.update(
323+
part.sample_posterior(
324+
rng_key, params, *args, sample_shape=sample_shape, **kwargs
331325
)
326+
)
332327
return result
333328

334329
def median(self, params):
@@ -469,7 +464,7 @@ def _constrain(self, latent_samples):
469464
else:
470465
return self._postprocess_fn(latent_samples)
471466

472-
def sample_posterior(self, rng_key, params, *, sample_shape=()):
467+
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
473468
locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs}
474469
scales = {k: params["{}_{}_scale".format(k, self.prefix)] for k in locs}
475470
with handlers.seed(rng_seed=rng_key):
@@ -810,7 +805,7 @@ def get_posterior(self, params):
810805
transform = self.get_transform(params)
811806
return dist.TransformedDistribution(base_dist, transform)
812807

813-
def sample_posterior(self, rng_key, params, *, sample_shape=()):
808+
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
814809
latent_sample = handlers.substitute(
815810
handlers.seed(self._sample_latent, rng_key), params
816811
)(sample_shape=sample_shape)
@@ -999,7 +994,7 @@ def scan_body(carry, eps_beta):
999994

1000995
return z
1001996

1002-
def sample_posterior(self, rng_key, params, *, sample_shape=()):
997+
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
1003998
def _single_sample(_rng_key):
1004999
latent_sample = handlers.substitute(
10051000
handlers.seed(self._sample_latent, _rng_key), params
@@ -2175,7 +2170,7 @@ def get_posterior(self, params):
21752170
transform = self.get_transform(params)
21762171
return dist.MultivariateNormal(transform.loc, scale_tril=transform.scale_tril)
21772172

2178-
def sample_posterior(self, rng_key, params, *, sample_shape=()):
2173+
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
21792174
latent_sample = self.get_posterior(params).sample(rng_key, sample_shape)
21802175
return self._unpack_and_constrain(latent_sample, params)
21812176

0 commit comments

Comments
 (0)