@@ -132,7 +132,7 @@ def __call__(self, *args, **kwargs):
132
132
raise NotImplementedError
133
133
134
134
@abstractmethod
135
- def sample_posterior (self , rng_key , params , * , sample_shape = ()):
135
+ def sample_posterior (self , rng_key , params , * args , sample_shape = (), ** kwargs ):
136
136
"""
137
137
Generate samples from the approximate posterior over the latent
138
138
sites in the model.
@@ -141,7 +141,9 @@ def sample_posterior(self, rng_key, params, *, sample_shape=()):
141
141
:param dict params: Current parameters of model and autoguide.
142
142
The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
143
143
method from :class:`~numpyro.infer.svi.SVI`.
144
+ :param args: Arguments to be provided to the model / guide.
144
145
:param tuple sample_shape: sample shape of each latent site, defaults to ().
146
+ :param kwargs: Keyword arguments to be provided to the model / guide.
145
147
:return: a dict containing samples drawn the this guide.
146
148
:rtype: dict
147
149
"""
@@ -317,18 +319,11 @@ def __iter__(self):
317
319
def sample_posterior (self , rng_key , params , * args , sample_shape = (), ** kwargs ):
318
320
result = {}
319
321
for part in self ._guides :
320
- # TODO: remove this when sample_posterior() signatures are consistent
321
- # we know part is not AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS
322
- if isinstance (part , numpyro .infer .autoguide .AutoDelta ):
323
- result .update (
324
- part .sample_posterior (
325
- rng_key , params , * args , sample_shape = sample_shape , ** kwargs
326
- )
327
- )
328
- else :
329
- result .update (
330
- part .sample_posterior (rng_key , params , sample_shape = sample_shape )
322
+ result .update (
323
+ part .sample_posterior (
324
+ rng_key , params , * args , sample_shape = sample_shape , ** kwargs
331
325
)
326
+ )
332
327
return result
333
328
334
329
def median (self , params ):
@@ -469,7 +464,7 @@ def _constrain(self, latent_samples):
469
464
else :
470
465
return self ._postprocess_fn (latent_samples )
471
466
472
- def sample_posterior (self , rng_key , params , * , sample_shape = ()):
467
+ def sample_posterior (self , rng_key , params , * args , sample_shape = (), ** kwargs ):
473
468
locs = {k : params ["{}_{}_loc" .format (k , self .prefix )] for k in self ._init_locs }
474
469
scales = {k : params ["{}_{}_scale" .format (k , self .prefix )] for k in locs }
475
470
with handlers .seed (rng_seed = rng_key ):
@@ -810,7 +805,7 @@ def get_posterior(self, params):
810
805
transform = self .get_transform (params )
811
806
return dist .TransformedDistribution (base_dist , transform )
812
807
813
- def sample_posterior (self , rng_key , params , * , sample_shape = ()):
808
+ def sample_posterior (self , rng_key , params , * args , sample_shape = (), ** kwargs ):
814
809
latent_sample = handlers .substitute (
815
810
handlers .seed (self ._sample_latent , rng_key ), params
816
811
)(sample_shape = sample_shape )
@@ -999,7 +994,7 @@ def scan_body(carry, eps_beta):
999
994
1000
995
return z
1001
996
1002
- def sample_posterior (self , rng_key , params , * , sample_shape = ()):
997
+ def sample_posterior (self , rng_key , params , * args , sample_shape = (), ** kwargs ):
1003
998
def _single_sample (_rng_key ):
1004
999
latent_sample = handlers .substitute (
1005
1000
handlers .seed (self ._sample_latent , _rng_key ), params
@@ -2175,7 +2170,7 @@ def get_posterior(self, params):
2175
2170
transform = self .get_transform (params )
2176
2171
return dist .MultivariateNormal (transform .loc , scale_tril = transform .scale_tril )
2177
2172
2178
- def sample_posterior (self , rng_key , params , * , sample_shape = ()):
2173
+ def sample_posterior (self , rng_key , params , * args , sample_shape = (), ** kwargs ):
2179
2174
latent_sample = self .get_posterior (params ).sample (rng_key , sample_shape )
2180
2175
return self ._unpack_and_constrain (latent_sample , params )
2181
2176
0 commit comments