Skip to content

Commit aa860f7

Browse files
authored
feat: add api compat arg to NS.get_samples (#1880)
* feat: add api compat arg to NS.get_samples Signed-off-by: nstarman <[email protected]> * feat: add leading dimension Signed-off-by: nstarman <[email protected]> * fix: lint errors Signed-off-by: nstarman <[email protected]> --------- Signed-off-by: nstarman <[email protected]>
1 parent 41755a1 commit aa860f7

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

numpyro/contrib/nested_sampling.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from functools import singledispatch
55

6-
from jax import random
6+
from jax import random, tree
77
import jax.numpy as jnp
88

99
try:
@@ -302,22 +302,25 @@ def prior_model():
302302
# replace base samples in jaxns results by transformed samples
303303
self._results = results._replace(samples=samples)
304304

305-
def get_samples(self, rng_key, num_samples):
305+
def get_samples(self, rng_key, num_samples, *, group_by_chain=False):
306306
"""
307307
Draws samples from the weighted samples collected from the run.
308308
309309
:param random.PRNGKey rng_key: Random number generator key to be used to draw samples.
310310
:param int num_samples: The number of samples.
311+
:param bool group_by_chain: If True, a leading chain dimension of 1 is added to the output arrays.
311312
:return: a dict of posterior samples
312313
"""
313314
if self._results is None:
314315
raise RuntimeError(
315316
"NestedSampler.run(...) method should be called first to obtain results."
316317
)
317318
weighted_samples, sample_weights = self.get_weighted_samples()
318-
return resample(
319+
samples = resample(
319320
rng_key, weighted_samples, sample_weights, S=num_samples, replace=True
320321
)
322+
chain_dim_sel = None if group_by_chain else Ellipsis
323+
return tree.map(lambda x: x[chain_dim_sel], samples)
321324

322325
def get_weighted_samples(self):
323326
"""

0 commit comments

Comments
 (0)