1313_log = logging .getLogger (__name__ )
1414
1515
16+ class NestedToMCMCAdapter :
17+ """
18+ Adapter to convert a NestedSampler object into an MCMC-compatible interface.
19+
20+ This class reshapes posterior samples from a NestedSampler into a chain-and-draw
21+ structure expected by MCMC workflows, providing compatibility with downstream
22+ tools like ArviZ for posterior analysis.
23+
24+ Parameters
25+ ----------
26+ nested_sampler : numpyro.contrib.nested_sampling.NestedSampler
27+ The NestedSampler object containing posterior samples.
28+ rng_key : jax.random.PRNGKey
29+ The random key used for sampling.
30+ num_samples : int
31+ The total number of posterior samples to draw.
32+ num_chains : int, optional
33+ The number of artificial chains to create for MCMC compatibility (default is 1).
34+ *args : tuple
35+ Additional positional arguments required by the model (e.g., data, labels).
36+ **kwargs : dict
37+ Additional keyword arguments required by the model.
38+
39+ Attributes
40+ ----------
41+ samples : dict
42+ Reshaped posterior samples organized by variable name.
43+ thinning : int
44+ Dummy thinning attribute for compatibility with MCMC.
45+ sampler : NestedToMCMCAdapter
46+ Mimics the sampler attribute of an MCMC object.
47+ model : callable
48+ The probabilistic model used in the NestedSampler.
49+ _args : tuple
50+ Positional arguments passed to the model.
51+ _kwargs : dict
52+ Keyword arguments passed to the model.
53+
54+ Methods
55+ -------
56+ get_samples(group_by_chain=True)
57+ Returns posterior samples reshaped by chain or flattened if `group_by_chain` is False.
58+ get_extra_fields(group_by_chain=True)
59+ Provides dummy sampling statistics like accept probabilities, step sizes, and num_steps.
60+ """
61+
62+ def __init__ (self , nested_sampler , rng_key , num_samples , * args , num_chains = 1 , ** kwargs ):
63+ self .nested_sampler = nested_sampler
64+ self .rng_key = rng_key
65+ self .num_samples = num_samples
66+ self .num_chains = num_chains
67+ self .samples = self ._reshape_samples ()
68+ self .thinning = 1
69+ self .sampler = self
70+ self .model = nested_sampler .model
71+ self ._args = args
72+ self ._kwargs = kwargs
73+
74+ def _reshape_samples (self ):
75+ raw_samples = self .nested_sampler .get_samples (self .rng_key , self .num_samples )
76+ samples_per_chain = self .num_samples // self .num_chains
77+ return {
78+ k : np .reshape (
79+ v [: samples_per_chain * self .num_chains ],
80+ (self .num_chains , samples_per_chain , * v .shape [1 :]),
81+ )
82+ for k , v in raw_samples .items ()
83+ }
84+
85+ def get_samples (self , group_by_chain = True ):
86+ if group_by_chain :
87+ return self .samples
88+ else :
89+ # Flatten chains into a single dimension
90+ return {k : v .reshape (- 1 , * v .shape [2 :]) for k , v in self .samples .items ()}
91+
92+ def get_extra_fields (self , group_by_chain = True ):
93+ # Generate dummy fields since NestedSampler does not produce these
94+ n_chains = self .num_chains
95+ n_samples = self .num_samples // self .num_chains
96+
97+ # Create dummy values for extra fields
98+ extra_fields = {
99+ "accept_prob" : np .full ((n_chains , n_samples ), 1.0 ), # Assume all proposals are accepted
100+ "step_size" : np .full ((n_chains , n_samples ), 0.1 ), # Dummy step size
101+ "num_steps" : np .full ((n_chains , n_samples ), 10 ), # Dummy number of steps
102+ }
103+
104+ if not group_by_chain :
105+ # Flatten the chains into a single dimension
106+ extra_fields = {k : v .reshape (- 1 , * v .shape [2 :]) for k , v in extra_fields .items ()}
107+
108+ return extra_fields
109+
110+
16111class NumPyroConverter :
17112 """Encapsulate NumPyro specific logic."""
18113
@@ -37,6 +132,10 @@ def __init__(
37132 dims = None ,
38133 pred_dims = None ,
39134 num_chains = 1 ,
135+ rng_key = None ,
136+ num_samples = 1000 ,
137+ data = None ,
138+ labels = None ,
40139 ):
41140 """Convert NumPyro data into an InferenceData object.
42141
@@ -68,6 +167,15 @@ def __init__(
68167 import numpyro
69168
70169 self .posterior = posterior
170+ self .rng_key = rng_key
171+ self .num_samples = num_samples
172+
173+ if isinstance (posterior , numpyro .contrib .nested_sampling .NestedSampler ):
174+ posterior = NestedToMCMCAdapter (
175+ posterior , rng_key , num_samples , num_chains = num_chains , data = data , labels = labels
176+ )
177+ self .posterior = posterior
178+
71179 self .prior = jax .device_get (prior )
72180 self .posterior_predictive = jax .device_get (posterior_predictive )
73181 self .predictions = predictions
@@ -340,6 +448,10 @@ def from_numpyro(
340448 dims = None ,
341449 pred_dims = None ,
342450 num_chains = 1 ,
451+ rng_key = None ,
452+ num_samples = 1000 ,
453+ data = None ,
454+ labels = None ,
343455):
344456 """Convert NumPyro data into an InferenceData object.
345457
@@ -383,4 +495,8 @@ def from_numpyro(
383495 dims = dims ,
384496 pred_dims = pred_dims ,
385497 num_chains = num_chains ,
498+ rng_key = rng_key ,
499+ num_samples = num_samples ,
500+ data = data ,
501+ labels = labels ,
386502 ).to_inference_data ()
0 commit comments