Skip to content

Commit fa7934f

Browse files
committed
Add NestedToMCMCAdapter to enable compatibility with ArviZ and MCMC workflows (#2391)
1 parent 529d795 commit fa7934f

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed

arviz/data/io_numpyro.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,96 @@
1212

1313
_log = logging.getLogger(__name__)
1414

15+
class NestedToMCMCAdapter:
16+
"""
17+
Adapter to convert a NestedSampler object into an MCMC-compatible interface.
18+
19+
This class reshapes posterior samples from a NestedSampler into a chain-and-draw
20+
structure expected by MCMC workflows, providing compatibility with downstream
21+
tools like ArviZ for posterior analysis.
22+
23+
Parameters
24+
----------
25+
nested_sampler : numpyro.contrib.nested_sampling.NestedSampler
26+
The NestedSampler object containing posterior samples.
27+
rng_key : jax.random.PRNGKey
28+
The random key used for sampling.
29+
num_samples : int
30+
The total number of posterior samples to draw.
31+
num_chains : int, optional
32+
The number of artificial chains to create for MCMC compatibility (default is 1).
33+
*args : tuple
34+
Additional positional arguments required by the model (e.g., data, labels).
35+
**kwargs : dict
36+
Additional keyword arguments required by the model.
37+
38+
Attributes
39+
----------
40+
samples : dict
41+
Reshaped posterior samples organized by variable name.
42+
thinning : int
43+
Dummy thinning attribute for compatibility with MCMC.
44+
sampler : NestedToMCMCAdapter
45+
Mimics the sampler attribute of an MCMC object.
46+
model : callable
47+
The probabilistic model used in the NestedSampler.
48+
_args : tuple
49+
Positional arguments passed to the model.
50+
_kwargs : dict
51+
Keyword arguments passed to the model.
52+
53+
Methods
54+
-------
55+
get_samples(group_by_chain=True)
56+
Returns posterior samples reshaped by chain or flattened if `group_by_chain` is False.
57+
get_extra_fields(group_by_chain=True)
58+
Provides dummy sampling statistics like accept probabilities, step sizes, and num_steps.
59+
"""
60+
def __init__(self, nested_sampler, rng_key, num_samples, *args, num_chains=1, **kwargs):
61+
self.nested_sampler = nested_sampler
62+
self.rng_key = rng_key
63+
self.num_samples = num_samples
64+
self.num_chains = num_chains
65+
self.samples = self._reshape_samples()
66+
self.thinning = 1
67+
self.sampler = self
68+
self.model = nested_sampler.model
69+
self._args = args
70+
self._kwargs = kwargs
71+
72+
def _reshape_samples(self):
73+
raw_samples = self.nested_sampler.get_samples(self.rng_key, self.num_samples)
74+
samples_per_chain = self.num_samples // self.num_chains
75+
return {
76+
k: np.reshape(v[:samples_per_chain * self.num_chains],
77+
(self.num_chains, samples_per_chain, *v.shape[1:]))
78+
for k, v in raw_samples.items()
79+
}
80+
81+
def get_samples(self, group_by_chain=True):
82+
if group_by_chain:
83+
return self.samples
84+
else:
85+
# Flatten chains into a single dimension
86+
return {k: v.reshape(-1, *v.shape[2:]) for k, v in self.samples.items()}
87+
88+
def get_extra_fields(self, group_by_chain=True):
89+
# Generate dummy fields since NestedSampler does not produce these
90+
n_chains = self.num_chains
91+
n_samples = self.num_samples // self.num_chains
92+
93+
# Create dummy values for extra fields
94+
extra_fields = {
95+
"accept_prob": np.full((n_chains, n_samples), 1.0), # Assume all proposals are accepted
96+
"step_size": np.full((n_chains, n_samples), 0.1), # Dummy step size
97+
"num_steps": np.full((n_chains, n_samples), 10), # Dummy number of steps
98+
}
99+
100+
if not group_by_chain:
101+
# Flatten the chains into a single dimension
102+
extra_fields = {k: v.reshape(-1, *v.shape[2:]) for k, v in extra_fields.items()}
103+
104+
return extra_fields
15105

16106
class NumPyroConverter:
17107
"""Encapsulate NumPyro specific logic."""
@@ -37,6 +127,10 @@ def __init__(
37127
dims=None,
38128
pred_dims=None,
39129
num_chains=1,
130+
rng_key=None,
131+
num_samples=1000,
132+
data=None,
133+
labels=None,
40134
):
41135
"""Convert NumPyro data into an InferenceData object.
42136
@@ -68,6 +162,14 @@ def __init__(
68162
import numpyro
69163

70164
self.posterior = posterior
165+
self.rng_key = rng_key
166+
self.num_samples = num_samples
167+
168+
if isinstance(posterior, numpyro.contrib.nested_sampling.NestedSampler):
169+
posterior = NestedToMCMCAdapter(posterior, rng_key, num_samples,
170+
num_chains=num_chains, data=data, labels=labels)
171+
self.posterior = posterior
172+
71173
self.prior = jax.device_get(prior)
72174
self.posterior_predictive = jax.device_get(posterior_predictive)
73175
self.predictions = predictions
@@ -340,6 +442,10 @@ def from_numpyro(
340442
dims=None,
341443
pred_dims=None,
342444
num_chains=1,
445+
rng_key=None,
446+
num_samples=1000,
447+
data=None,
448+
labels=None,
343449
):
344450
"""Convert NumPyro data into an InferenceData object.
345451
@@ -383,4 +489,8 @@ def from_numpyro(
383489
dims=dims,
384490
pred_dims=pred_dims,
385491
num_chains=num_chains,
492+
rng_key=rng_key,
493+
num_samples=num_samples,
494+
data=data,
495+
labels=labels,
386496
).to_inference_data()

0 commit comments

Comments
 (0)