1212
1313_log = logging .getLogger (__name__ )
1414
15+ class NestedToMCMCAdapter :
16+ """
17+ Adapter to convert a NestedSampler object into an MCMC-compatible interface.
18+
19+ This class reshapes posterior samples from a NestedSampler into a chain-and-draw
20+ structure expected by MCMC workflows, providing compatibility with downstream
21+ tools like ArviZ for posterior analysis.
22+
23+ Parameters
24+ ----------
25+ nested_sampler : numpyro.contrib.nested_sampling.NestedSampler
26+ The NestedSampler object containing posterior samples.
27+ rng_key : jax.random.PRNGKey
28+ The random key used for sampling.
29+ num_samples : int
30+ The total number of posterior samples to draw.
31+ num_chains : int, optional
32+ The number of artificial chains to create for MCMC compatibility (default is 1).
33+ *args : tuple
34+ Additional positional arguments required by the model (e.g., data, labels).
35+ **kwargs : dict
36+ Additional keyword arguments required by the model.
37+
38+ Attributes
39+ ----------
40+ samples : dict
41+ Reshaped posterior samples organized by variable name.
42+ thinning : int
43+ Dummy thinning attribute for compatibility with MCMC.
44+ sampler : NestedToMCMCAdapter
45+ Mimics the sampler attribute of an MCMC object.
46+ model : callable
47+ The probabilistic model used in the NestedSampler.
48+ _args : tuple
49+ Positional arguments passed to the model.
50+ _kwargs : dict
51+ Keyword arguments passed to the model.
52+
53+ Methods
54+ -------
55+ get_samples(group_by_chain=True)
56+ Returns posterior samples reshaped by chain or flattened if `group_by_chain` is False.
57+ get_extra_fields(group_by_chain=True)
58+ Provides dummy sampling statistics like accept probabilities, step sizes, and num_steps.
59+ """
60+ def __init__ (self , nested_sampler , rng_key , num_samples , * args , num_chains = 1 , ** kwargs ):
61+ self .nested_sampler = nested_sampler
62+ self .rng_key = rng_key
63+ self .num_samples = num_samples
64+ self .num_chains = num_chains
65+ self .samples = self ._reshape_samples ()
66+ self .thinning = 1
67+ self .sampler = self
68+ self .model = nested_sampler .model
69+ self ._args = args
70+ self ._kwargs = kwargs
71+
72+ def _reshape_samples (self ):
73+ raw_samples = self .nested_sampler .get_samples (self .rng_key , self .num_samples )
74+ samples_per_chain = self .num_samples // self .num_chains
75+ return {
76+ k : np .reshape (v [:samples_per_chain * self .num_chains ],
77+ (self .num_chains , samples_per_chain , * v .shape [1 :]))
78+ for k , v in raw_samples .items ()
79+ }
80+
81+ def get_samples (self , group_by_chain = True ):
82+ if group_by_chain :
83+ return self .samples
84+ else :
85+ # Flatten chains into a single dimension
86+ return {k : v .reshape (- 1 , * v .shape [2 :]) for k , v in self .samples .items ()}
87+
88+ def get_extra_fields (self , group_by_chain = True ):
89+ # Generate dummy fields since NestedSampler does not produce these
90+ n_chains = self .num_chains
91+ n_samples = self .num_samples // self .num_chains
92+
93+ # Create dummy values for extra fields
94+ extra_fields = {
95+ "accept_prob" : np .full ((n_chains , n_samples ), 1.0 ), # Assume all proposals are accepted
96+ "step_size" : np .full ((n_chains , n_samples ), 0.1 ), # Dummy step size
97+ "num_steps" : np .full ((n_chains , n_samples ), 10 ), # Dummy number of steps
98+ }
99+
100+ if not group_by_chain :
101+ # Flatten the chains into a single dimension
102+ extra_fields = {k : v .reshape (- 1 , * v .shape [2 :]) for k , v in extra_fields .items ()}
103+
104+ return extra_fields
15105
16106class NumPyroConverter :
17107 """Encapsulate NumPyro specific logic."""
@@ -37,6 +127,10 @@ def __init__(
37127 dims = None ,
38128 pred_dims = None ,
39129 num_chains = 1 ,
130+ rng_key = None ,
131+ num_samples = 1000 ,
132+ data = None ,
133+ labels = None ,
40134 ):
41135 """Convert NumPyro data into an InferenceData object.
42136
@@ -68,6 +162,14 @@ def __init__(
68162 import numpyro
69163
70164 self .posterior = posterior
165+ self .rng_key = rng_key
166+ self .num_samples = num_samples
167+
168+ if isinstance (posterior , numpyro .contrib .nested_sampling .NestedSampler ):
169+ posterior = NestedToMCMCAdapter (posterior , rng_key , num_samples ,
170+ num_chains = num_chains , data = data , labels = labels )
171+ self .posterior = posterior
172+
71173 self .prior = jax .device_get (prior )
72174 self .posterior_predictive = jax .device_get (posterior_predictive )
73175 self .predictions = predictions
@@ -340,6 +442,10 @@ def from_numpyro(
340442 dims = None ,
341443 pred_dims = None ,
342444 num_chains = 1 ,
445+ rng_key = None ,
446+ num_samples = 1000 ,
447+ data = None ,
448+ labels = None ,
343449):
344450 """Convert NumPyro data into an InferenceData object.
345451
@@ -383,4 +489,8 @@ def from_numpyro(
383489 dims = dims ,
384490 pred_dims = pred_dims ,
385491 num_chains = num_chains ,
492+ rng_key = rng_key ,
493+ num_samples = num_samples ,
494+ data = data ,
495+ labels = labels ,
386496 ).to_inference_data ()
0 commit comments