|
3 | 3 |
|
4 | 4 | from functools import singledispatch |
5 | 5 |
|
6 | | -from jax import random |
| 6 | +from jax import random, tree |
7 | 7 | import jax.numpy as jnp |
8 | 8 |
|
9 | 9 | try: |
@@ -302,22 +302,25 @@ def prior_model(): |
302 | 302 | # replace base samples in jaxns results by transformed samples |
303 | 303 | self._results = results._replace(samples=samples) |
304 | 304 |
|
305 | | - def get_samples(self, rng_key, num_samples): |
| 305 | + def get_samples(self, rng_key, num_samples, *, group_by_chain=False): |
306 | 306 | """ |
307 | 307 | Draws samples from the weighted samples collected from the run. |
308 | 308 |
|
309 | 309 | :param random.PRNGKey rng_key: Random number generator key to be used to draw samples. |
310 | 310 | :param int num_samples: The number of samples. |
| 311 | + :param bool group_by_chain: If True, a leading chain dimension of 1 is added to the output arrays. |
311 | 312 | :return: a dict of posterior samples |
312 | 313 | """ |
313 | 314 | if self._results is None: |
314 | 315 | raise RuntimeError( |
315 | 316 | "NestedSampler.run(...) method should be called first to obtain results." |
316 | 317 | ) |
317 | 318 | weighted_samples, sample_weights = self.get_weighted_samples() |
318 | | - return resample( |
| 319 | + samples = resample( |
319 | 320 | rng_key, weighted_samples, sample_weights, S=num_samples, replace=True |
320 | 321 | ) |
| 322 | + chain_dim_sel = None if group_by_chain else Ellipsis |
| 323 | + return tree.map(lambda x: x[chain_dim_sel], samples) |
321 | 324 |
|
322 | 325 | def get_weighted_samples(self): |
323 | 326 | """ |
|
0 commit comments