Skip to content

Commit 3f9d2e2

Browse files
Sample stats for blackjax nuts (#6264)
* tests for added sample statistics * sample stats test with more draws * redesigned test for older blackjax version * record blackjax sample stats * moved sample stats argument to partial call
1 parent f577c2c commit 3f9d2e2

File tree

2 files changed

+87
-3
lines changed

2 files changed

+87
-3
lines changed

pymc/sampling/jax.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,40 @@ def _sample_stats_to_xarray(posterior):
142142
return data
143143

144144

145+
def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
    """Map blackjax NUTS sampler statistics onto PyMC/ArviZ stat names.

    Parameters
    ----------
    sample_stats: NUTSInfo
        Blackjax NUTSInfo object containing sampler statistics. Only
        attribute access is used, so any attribute that is missing (or
        set to ``None``) is left out of the result.
    potential_energy: ArrayLike
        Potential energy values of sampled positions. Stored unchanged
        under the ``"lp"`` key.

    Returns
    -------
    Dict[str, ArrayLike]
        Dictionary of sampler statistics keyed by the converted names.
    """
    # (blackjax attribute, converted name) pairs, checked in this order.
    # The acceptance statistic is listed under both attribute names because
    # its name depends on the blackjax version; if both are present the
    # later entry wins.
    name_pairs = (
        ("is_divergent", "diverging"),
        ("energy", "energy"),
        ("num_trajectory_expansions", "tree_depth"),
        ("num_integration_steps", "n_steps"),
        ("acceptance_rate", "acceptance_rate"),
        ("acceptance_probability", "acceptance_rate"),
    )

    stats_out = {"lp": potential_energy}
    for blackjax_name, converted_name in name_pairs:
        stat = getattr(sample_stats, blackjax_name, None)
        if stat is not None:
            stats_out[converted_name] = stat
    return stats_out
177+
178+
145179
def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
146180
"""Compute log-likelihood for all observations"""
147181
elemwise_logp = model.logp(model.observed_RVs, sum=False)
@@ -360,9 +394,9 @@ def sample_blackjax_nuts(
360394
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
361395
)
362396

363-
states, _ = map_fn(get_posterior_samples)(keys, init_params)
397+
states, stats = map_fn(get_posterior_samples)(keys, init_params)
364398
raw_mcmc_samples = states.position
365-
399+
potential_energy = states.potential_energy
366400
tic3 = datetime.now()
367401
print("Sampling time = ", tic3 - tic2, file=sys.stdout)
368402

@@ -372,7 +406,7 @@ def sample_blackjax_nuts(
372406
*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
373407
)
374408
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
375-
409+
mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
376410
tic4 = datetime.now()
377411
print("Transformation time = ", tic4 - tic3, file=sys.stdout)
378412

@@ -406,6 +440,7 @@ def sample_blackjax_nuts(
406440
log_likelihood=log_likelihood,
407441
observed_data=find_observations(model),
408442
constant_data=find_constants(model),
443+
sample_stats=mcmc_stats,
409444
coords=coords,
410445
dims=dims,
411446
attrs=make_attrs(attrs, library=blackjax),

pymc/tests/sampling/test_jax.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,52 @@ def test_numpyro_nuts_kwargs_are_used(mocked: mock.MagicMock):
365365
assert nuts_sampler._adapt_step_size == adapt_step_size
366366
assert nuts_sampler._adapt_mass_matrix
367367
assert nuts_sampler._target_accept_prob == target_accept
368+
369+
370+
@pytest.mark.parametrize(
    "sampler_name",
    [
        "sample_blackjax_nuts",
        "sample_numpyro_nuts",
    ],
)
def test_idata_contains_stats(sampler_name: str):
    """Tests whether sampler statistics were written to the sample_stats
    group of InferenceData, each with shape (n_chains, n_draws)."""
    # Table lookup: an unexpected parameter fails with a KeyError here
    # instead of an unbound local further down.
    samplers = {
        "sample_blackjax_nuts": sample_blackjax_nuts,
        "sample_numpyro_nuts": sample_numpyro_nuts,
    }
    sampler = samplers[sampler_name]

    with pm.Model():
        pm.Normal("a")
        idata = sampler(tune=50, draws=50)

    stats = idata.get("sample_stats")
    assert stats is not None
    # ``Dataset.sizes`` is the name -> length mapping; using ``Dataset.dims``
    # as a mapping is deprecated in xarray.
    n_chains = stats.sizes["chain"]
    n_draws = stats.sizes["draw"]
    expected_shape = (n_chains, n_draws)

    # Stats vars expected for both samplers
    common_stat_vars = [
        "acceptance_rate",
        "diverging",
        "energy",
        "tree_depth",
        "lp",
    ]
    # Stats vars only expected for one specific sampler
    sampler_specific_vars = {
        "sample_blackjax_nuts": [],
        "sample_numpyro_nuts": ["step_size", "n_steps"],
    }
    stat_vars = common_stat_vars + sampler_specific_vars[sampler_name]

    # test existence and dimensionality
    for stat_var in stat_vars:
        assert stat_var in stats.variables
        assert stats.get(stat_var).values.shape == expected_shape

0 commit comments

Comments
 (0)