@@ -142,6 +142,40 @@ def _sample_stats_to_xarray(posterior):
142
142
return data
143
143
144
144
145
+ def _blackjax_stats_to_dict (sample_stats , potential_energy ) -> Dict :
146
+ """Extract compatible stats from blackjax NUTS sampler
147
+ with PyMC/Arviz naming conventions.
148
+
149
+ Parameters
150
+ ----------
151
+ sample_stats: NUTSInfo
152
+ Blackjax NUTSInfo object containing sampler statistics
153
+ potential_energy: ArrayLike
154
+ Potential energy values of sampled positions.
155
+
156
+ Returns
157
+ -------
158
+ Dict[str, ArrayLike]
159
+ Dictionary of sampler statistics.
160
+ """
161
+ rename_key = {
162
+ "is_divergent" : "diverging" ,
163
+ "energy" : "energy" ,
164
+ "num_trajectory_expansions" : "tree_depth" ,
165
+ "num_integration_steps" : "n_steps" ,
166
+ "acceptance_rate" : "acceptance_rate" , # naming here is
167
+ "acceptance_probability" : "acceptance_rate" , # depending on blackjax version
168
+ }
169
+ converted_stats = {}
170
+ converted_stats ["lp" ] = potential_energy
171
+ for old_name , new_name in rename_key .items ():
172
+ value = getattr (sample_stats , old_name , None )
173
+ if value is None :
174
+ continue
175
+ converted_stats [new_name ] = value
176
+ return converted_stats
177
+
178
+
145
179
def _get_log_likelihood (model : Model , samples , backend = None ) -> Dict :
146
180
"""Compute log-likelihood for all observations"""
147
181
elemwise_logp = model .logp (model .observed_RVs , sum = False )
@@ -360,9 +394,9 @@ def sample_blackjax_nuts(
360
394
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
361
395
)
362
396
363
- states , _ = map_fn (get_posterior_samples )(keys , init_params )
397
+ states , stats = map_fn (get_posterior_samples )(keys , init_params )
364
398
raw_mcmc_samples = states .position
365
-
399
+ potential_energy = states . potential_energy
366
400
tic3 = datetime .now ()
367
401
print ("Sampling time = " , tic3 - tic2 , file = sys .stdout )
368
402
@@ -372,7 +406,7 @@ def sample_blackjax_nuts(
372
406
* jax .device_put (raw_mcmc_samples , jax .devices (postprocessing_backend )[0 ])
373
407
)
374
408
mcmc_samples = {v .name : r for v , r in zip (vars_to_sample , result )}
375
-
409
+ mcmc_stats = _blackjax_stats_to_dict ( stats , potential_energy )
376
410
tic4 = datetime .now ()
377
411
print ("Transformation time = " , tic4 - tic3 , file = sys .stdout )
378
412
@@ -406,6 +440,7 @@ def sample_blackjax_nuts(
406
440
log_likelihood = log_likelihood ,
407
441
observed_data = find_observations (model ),
408
442
constant_data = find_constants (model ),
443
+ sample_stats = mcmc_stats ,
409
444
coords = coords ,
410
445
dims = dims ,
411
446
attrs = make_attrs (attrs , library = blackjax ),
0 commit comments