
Commit f231d13

added postprocessing_chunks option to sample_blackjax_nuts and sample_numpyro_nuts (#6388)
* added postprocessing_chunks option to sample_blackjax_nuts and sample_numpyro_nuts
* make chunking optional, add chunking argument to _get_loglikelihood
* update tests for jax postprocessing chunking
* update docs
* Run pre-commit

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 98ccc68 commit f231d13

File tree: 2 files changed (+51, -9 lines)

pymc/sampling/jax.py

Lines changed: 48 additions & 8 deletions
@@ -21,6 +21,7 @@
 import pytensor.tensor as at

 from arviz.data.base import make_attrs
+from jax.experimental.maps import SerialLoop, xmap
 from pytensor.compile import SharedVariable, Supervisor, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
@@ -143,6 +144,27 @@ def _sample_stats_to_xarray(posterior):
     return data


+def _postprocess_samples(
+    jax_fn: List[TensorVariable],
+    raw_mcmc_samples: List[TensorVariable],
+    postprocessing_backend: str,
+    num_chunks: Optional[int] = None,
+) -> List[TensorVariable]:
+    if num_chunks is not None:
+        loop = xmap(
+            jax_fn,
+            in_axes=["chain", "samples", ...],
+            out_axes=["chain", "samples", ...],
+            axis_resources={"samples": SerialLoop(num_chunks)},
+        )
+        f = xmap(loop, in_axes=[...], out_axes=[...])
+        return f(*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0]))
+    else:
+        return jax.vmap(jax.vmap(jax_fn))(
+            *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
+        )
+
+
 def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
     """Extract compatible stats from blackjax NUTS sampler
     with PyMC/Arviz naming conventions.
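For context on the chunked branch above: JAX's experimental xmap, given a SerialLoop resource for the named "samples" axis, evaluates that axis in num_chunks sequential pieces rather than one fully vectorized pass, which lowers peak memory. Below is a minimal, self-contained sketch of the two code paths; per_draw_fn and the array shapes are hypothetical stand-ins for the jaxified PyMC graph and the raw MCMC draws.

    import jax
    import jax.numpy as jnp
    from jax.experimental.maps import SerialLoop, xmap

    def per_draw_fn(x):
        # Hypothetical stand-in for the jaxified model graph on one draw.
        return jnp.exp(x)

    raw = jnp.ones((4, 1000, 3))  # (chain, samples, param); shapes made up

    # num_chunks=None path: vectorize over both chains and draws at once.
    vectorized = jax.vmap(jax.vmap(per_draw_fn))(raw)

    # Chunked path: the "samples" axis runs as 10 serial chunks of 100 draws,
    # trading some vectorization for lower peak memory.
    loop = xmap(
        per_draw_fn,
        in_axes=["chain", "samples", ...],
        out_axes=["chain", "samples", ...],
        axis_resources={"samples": SerialLoop(10)},
    )
    f = xmap(loop, in_axes=[...], out_axes=[...])  # trivial outer xmap, as in the diff
    chunked = f(raw)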
@@ -177,11 +199,13 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
     return converted_stats


-def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
+def _get_log_likelihood(
+    model: Model, samples, backend=None, num_chunks: Optional[int] = None
+) -> Dict:
     """Compute log-likelihood for all observations"""
     elemwise_logp = model.logp(model.observed_RVs, sum=False)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
-    result = jax.vmap(jax.vmap(jax_fn))(*jax.device_put(samples, jax.devices(backend)[0]))
+    result = _postprocess_samples(jax_fn, samples, backend, num_chunks=num_chunks)
     return {v.name: r for v, r in zip(model.observed_RVs, result)}

@@ -275,6 +299,7 @@ def sample_blackjax_nuts(
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
     postprocessing_backend: Optional[str] = None,
+    postprocessing_chunks: Optional[int] = None,
     idata_kwargs: Optional[Dict[str, Any]] = None,
 ) -> az.InferenceData:
     """
@@ -314,6 +339,10 @@ def sample_blackjax_nuts(
         "vectorized".
     postprocessing_backend : str, optional
         Specify how postprocessing should be computed. gpu or cpu
+    postprocessing_chunks: Optional[int], default None
+        Specify the number of chunks the postprocessing should be computed in. More
+        chunks reduce memory usage at the cost of losing some vectorization; None
+        uses jax.vmap.
     idata_kwargs : dict, optional
         Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
         value for the ``log_likelihood`` key to indicate that the pointwise log
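For illustration, a minimal usage sketch of the new argument; the model and data here are hypothetical, and postprocessing_chunks=10 is an arbitrary choice:

    import numpy as np
    import pymc as pm
    from pymc.sampling.jax import sample_blackjax_nuts

    data = np.random.normal(loc=1.0, size=1000)

    with pm.Model():
        mu = pm.Normal("mu")
        sigma = pm.HalfNormal("sigma")
        pm.Normal("obs", mu, sigma, observed=data)
        # Run the postprocessing (variable transforms and log-likelihood)
        # in 10 serial chunks instead of a single vmap pass:
        idata = sample_blackjax_nuts(postprocessing_chunks=10)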
@@ -400,8 +429,8 @@ def sample_blackjax_nuts(

     print("Transforming variables...", file=sys.stdout)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
-    result = jax.vmap(jax.vmap(jax_fn))(
-        *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
+    result = _postprocess_samples(
+        jax_fn, raw_mcmc_samples, postprocessing_backend, num_chunks=postprocessing_chunks
     )
     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
     mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
@@ -417,7 +446,10 @@ def sample_blackjax_nuts(
     tic5 = datetime.now()
     print("Computing Log Likelihood...", file=sys.stdout)
     log_likelihood = _get_log_likelihood(
-        model, raw_mcmc_samples, backend=postprocessing_backend
+        model,
+        raw_mcmc_samples,
+        backend=postprocessing_backend,
+        num_chunks=postprocessing_chunks,
     )
     tic6 = datetime.now()
     print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
@@ -478,6 +510,7 @@ def sample_numpyro_nuts(
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
     postprocessing_backend: Optional[str] = None,
+    postprocessing_chunks: Optional[int] = None,
     idata_kwargs: Optional[Dict] = None,
     nuts_kwargs: Optional[Dict] = None,
 ) -> az.InferenceData:
@@ -522,6 +555,10 @@ def sample_numpyro_nuts(
         "parallel", and "vectorized".
     postprocessing_backend : Optional[str]
         Specify how postprocessing should be computed. gpu or cpu
+    postprocessing_chunks: Optional[int], default None
+        Specify the number of chunks the postprocessing should be computed in. More
+        chunks reduce memory usage at the cost of losing some vectorization; None
+        uses jax.vmap.
     idata_kwargs : dict, optional
         Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
         value for the ``log_likelihood`` key to indicate that the pointwise log
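sample_numpyro_nuts takes the same argument. One practical note: as we read the SerialLoop-based mapping, the chunked axis should divide evenly into chunks, so pick postprocessing_chunks to divide the number of draws. A short sketch with a hypothetical model:

    import numpy as np
    import pymc as pm
    from pymc.sampling.jax import sample_numpyro_nuts

    with pm.Model():
        sigma = pm.HalfNormal("sigma")
        pm.Normal("obs", 0.0, sigma, observed=np.random.normal(size=500))
        # 1000 draws per chain postprocessed in 10 serial chunks of 100 each:
        idata = sample_numpyro_nuts(draws=1000, postprocessing_chunks=10)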
@@ -622,8 +659,8 @@ def sample_numpyro_nuts(

     print("Transforming variables...", file=sys.stdout)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
-    result = jax.vmap(jax.vmap(jax_fn))(
-        *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
+    result = _postprocess_samples(
+        jax_fn, raw_mcmc_samples, postprocessing_backend, num_chunks=postprocessing_chunks
     )
     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

@@ -639,7 +676,10 @@ def sample_numpyro_nuts(
     tic5 = datetime.now()
     print("Computing Log Likelihood...", file=sys.stdout)
     log_likelihood = _get_log_likelihood(
-        model, raw_mcmc_samples, backend=postprocessing_backend
+        model,
+        raw_mcmc_samples,
+        backend=postprocessing_backend,
+        num_chunks=postprocessing_chunks,
     )
     tic6 = datetime.now()
     print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)

pymc/tests/sampling/test_jax.py

Lines changed: 3 additions & 1 deletion
@@ -55,7 +55,8 @@ def test_old_import_route():
         ),
     ],
 )
-def test_transform_samples(sampler, postprocessing_backend, chains):
+@pytest.mark.parametrize("postprocessing_chunks", [None, 10])
+def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_chunks):
     pytensor.config.on_opt_error = "raise"
     np.random.seed(13244)

@@ -71,6 +72,7 @@ def test_transform_samples(sampler, postprocessing_backend, chains):
         random_seed=1322,
         keep_untransformed=True,
         postprocessing_backend=postprocessing_backend,
+        postprocessing_chunks=postprocessing_chunks,
     )

     log_vals = trace.posterior["sigma_log__"].values
