21
21
import pytensor .tensor as at
22
22
23
23
from arviz .data .base import make_attrs
24
+ from jax .experimental .maps import SerialLoop , xmap
24
25
from pytensor .compile import SharedVariable , Supervisor , mode
25
26
from pytensor .graph .basic import graph_inputs
26
27
from pytensor .graph .fg import FunctionGraph
@@ -143,6 +144,27 @@ def _sample_stats_to_xarray(posterior):
143
144
return data
144
145
145
146
147
+ def _postprocess_samples (
148
+ jax_fn : List [TensorVariable ],
149
+ raw_mcmc_samples : List [TensorVariable ],
150
+ postprocessing_backend : str ,
151
+ num_chunks : Optional [int ] = None ,
152
+ ) -> List [TensorVariable ]:
153
+ if num_chunks is not None :
154
+ loop = xmap (
155
+ jax_fn ,
156
+ in_axes = ["chain" , "samples" , ...],
157
+ out_axes = ["chain" , "samples" , ...],
158
+ axis_resources = {"samples" : SerialLoop (num_chunks )},
159
+ )
160
+ f = xmap (loop , in_axes = [...], out_axes = [...])
161
+ return f (* jax .device_put (raw_mcmc_samples , jax .devices (postprocessing_backend )[0 ]))
162
+ else :
163
+ return jax .vmap (jax .vmap (jax_fn ))(
164
+ * jax .device_put (raw_mcmc_samples , jax .devices (postprocessing_backend )[0 ])
165
+ )
166
+
167
+
146
168
def _blackjax_stats_to_dict (sample_stats , potential_energy ) -> Dict :
147
169
"""Extract compatible stats from blackjax NUTS sampler
148
170
with PyMC/Arviz naming conventions.
@@ -177,11 +199,13 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
177
199
return converted_stats
178
200
179
201
180
def _get_log_likelihood(
    model: Model, samples, backend=None, num_chunks: Optional[int] = None
) -> Dict:
    """Compute the pointwise log-likelihood of every observed variable.

    Builds the unsummed log-probability graph for the model's observed RVs,
    jaxifies it, evaluates it over the posterior samples (optionally in
    chunks to limit memory), and returns the results keyed by variable name.
    """
    observed = model.observed_RVs
    # Per-observation (unsummed) log-probability terms for the observed RVs.
    logp_terms = model.logp(observed, sum=False)
    compiled_fn = get_jaxified_graph(inputs=model.value_vars, outputs=logp_terms)
    evaluated = _postprocess_samples(compiled_fn, samples, backend, num_chunks=num_chunks)
    loglik: Dict = {}
    for rv, values in zip(observed, evaluated):
        loglik[rv.name] = values
    return loglik
186
210
187
211
@@ -275,6 +299,7 @@ def sample_blackjax_nuts(
275
299
keep_untransformed : bool = False ,
276
300
chain_method : str = "parallel" ,
277
301
postprocessing_backend : Optional [str ] = None ,
302
+ postprocessing_chunks : Optional [int ] = None ,
278
303
idata_kwargs : Optional [Dict [str , Any ]] = None ,
279
304
) -> az .InferenceData :
280
305
"""
@@ -314,6 +339,10 @@ def sample_blackjax_nuts(
314
339
"vectorized".
315
340
postprocessing_backend : str, optional
316
341
Specify how postprocessing should be computed. gpu or cpu
342
+ postprocessing_chunks: Optional[int], default None
343
+ Specify the number of chunks the postprocessing should be computed in. More
344
+ chunks reduce memory usage at the cost of losing some vectorization; None
345
+ uses jax.vmap
317
346
idata_kwargs : dict, optional
318
347
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
319
348
value for the ``log_likelihood`` key to indicate that the pointwise log
@@ -400,8 +429,8 @@ def sample_blackjax_nuts(
400
429
401
430
print ("Transforming variables..." , file = sys .stdout )
402
431
jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
403
- result = jax . vmap ( jax . vmap ( jax_fn )) (
404
- * jax . device_put ( raw_mcmc_samples , jax . devices ( postprocessing_backend )[ 0 ])
432
+ result = _postprocess_samples (
433
+ jax_fn , raw_mcmc_samples , postprocessing_backend , num_chunks = postprocessing_chunks
405
434
)
406
435
mcmc_samples = {v .name : r for v , r in zip (vars_to_sample , result )}
407
436
mcmc_stats = _blackjax_stats_to_dict (stats , potential_energy )
@@ -417,7 +446,10 @@ def sample_blackjax_nuts(
417
446
tic5 = datetime .now ()
418
447
print ("Computing Log Likelihood..." , file = sys .stdout )
419
448
log_likelihood = _get_log_likelihood (
420
- model , raw_mcmc_samples , backend = postprocessing_backend
449
+ model ,
450
+ raw_mcmc_samples ,
451
+ backend = postprocessing_backend ,
452
+ num_chunks = postprocessing_chunks ,
421
453
)
422
454
tic6 = datetime .now ()
423
455
print ("Log Likelihood time = " , tic6 - tic5 , file = sys .stdout )
@@ -478,6 +510,7 @@ def sample_numpyro_nuts(
478
510
keep_untransformed : bool = False ,
479
511
chain_method : str = "parallel" ,
480
512
postprocessing_backend : Optional [str ] = None ,
513
+ postprocessing_chunks : Optional [int ] = None ,
481
514
idata_kwargs : Optional [Dict ] = None ,
482
515
nuts_kwargs : Optional [Dict ] = None ,
483
516
) -> az .InferenceData :
@@ -522,6 +555,10 @@ def sample_numpyro_nuts(
522
555
"parallel", and "vectorized".
523
556
postprocessing_backend : Optional[str]
524
557
Specify how postprocessing should be computed. gpu or cpu
558
+ postprocessing_chunks: Optional[int], default None
559
+ Specify the number of chunks the postprocessing should be computed in. More
560
+ chunks reduce memory usage at the cost of losing some vectorization; None
561
+ uses jax.vmap
525
562
idata_kwargs : dict, optional
526
563
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
527
564
value for the ``log_likelihood`` key to indicate that the pointwise log
@@ -622,8 +659,8 @@ def sample_numpyro_nuts(
622
659
623
660
print ("Transforming variables..." , file = sys .stdout )
624
661
jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
625
- result = jax . vmap ( jax . vmap ( jax_fn )) (
626
- * jax . device_put ( raw_mcmc_samples , jax . devices ( postprocessing_backend )[ 0 ])
662
+ result = _postprocess_samples (
663
+ jax_fn , raw_mcmc_samples , postprocessing_backend , num_chunks = postprocessing_chunks
627
664
)
628
665
mcmc_samples = {v .name : r for v , r in zip (vars_to_sample , result )}
629
666
@@ -639,7 +676,10 @@ def sample_numpyro_nuts(
639
676
tic5 = datetime .now ()
640
677
print ("Computing Log Likelihood..." , file = sys .stdout )
641
678
log_likelihood = _get_log_likelihood (
642
- model , raw_mcmc_samples , backend = postprocessing_backend
679
+ model ,
680
+ raw_mcmc_samples ,
681
+ backend = postprocessing_backend ,
682
+ num_chunks = postprocessing_chunks ,
643
683
)
644
684
tic6 = datetime .now ()
645
685
print ("Log Likelihood time = " , tic6 - tic5 , file = sys .stdout )
0 commit comments