1818from collections .abc import Callable , Sequence
1919from datetime import datetime
2020from functools import partial
21- from typing import Any , Literal
21+ from types import ModuleType
22+ from typing import TYPE_CHECKING , Any , Literal
2223
2324import arviz as az
2425import jax
6970 "sample_numpyro_nuts" ,
7071)
7172
73+ if TYPE_CHECKING :
74+ from numpyro .infer import MCMC
75+
7276
7377@jax_funcify .register (Assert )
7478@jax_funcify .register (CheckParameterValue )
@@ -310,50 +314,48 @@ def _sample_blackjax_nuts(
310314 tune : int ,
311315 draws : int ,
312316 chains : int ,
313- chain_method : str | None ,
317+ chain_method : Literal [ "parallel" , "vectorized" ] ,
314318 progressbar : bool ,
315319 random_seed : int ,
316320 initial_points : np .ndarray | list [np .ndarray ],
317321 nuts_kwargs ,
318- logp_fn : Callable [[Sequence [ np . ndarray ]], np . ndarray ] | None = None ,
319- ) -> az . InferenceData :
322+ logp_fn : Callable [[ArrayLike ], jax . Array ] | None = None ,
323+ ) -> tuple [ Any , dict [ str , Any ], ModuleType ] :
320324 """
321325 Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
322326
323327 Parameters
324328 ----------
325- draws : int, default 1000
326- The number of samples to draw. The number of tuned samples are discarded by
327- default.
329+ model : Model, optional
330+ Model to sample from. The model needs to have free random variables. When inside
331+ a ``with`` model context, it defaults to that model, otherwise the model must be
332+ passed explicitly.
333+ target_accept : float in [0, 1].
334+ The step size is tuned such that we approximate this acceptance rate. Higher
335+ values like 0.9 or 0.95 often work better for problematic posteriors.
328336 tune : int, default 1000
329337 Number of iterations to tune. Samplers adjust the step sizes, scalings or
330338 similar during tuning. Tuning samples will be drawn in addition to the number
331339 specified in the ``draws`` argument.
340+ draws : int, default 1000
341+ The number of samples to draw. The number of tuned samples are discarded by
342+ default.
332343 chains : int, default 4
333344 The number of chains to sample.
334- target_accept : float in [0, 1].
335- The step size is tuned such that we approximate this acceptance rate. Higher
336- values like 0.9 or 0.95 often work better for problematic posteriors.
345+ chain_method : str, default "parallel"
346+ Specify how samples should be drawn. The choices include "parallel", and
347+ "vectorized".
348+ progressbar : bool
349+ Whether to show progressbar or not during sampling.
337350 random_seed : int, RandomState or Generator, optional
338351 Random seed used by the sampling steps.
339- initvals: StartDict or Sequence[Optional[StartDict]], optional
340- Initial values for random variables provided as a dictionary (or sequence of
341- dictionaries) mapping the random variable (by name or reference) to desired
342- starting values.
343- jitter: bool, default True
344- If True, add jitter to initial points.
345- model : Model, optional
346- Model to sample from. The model needs to have free random variables. When inside
347- a ``with`` model context, it defaults to that model, otherwise the model must be
348- passed explicitly.
352+ initial_points : np.ndarray | list[np.ndarray]
353+ Initial point(s) for sampler to begin sampling from.
349354 var_names : sequence of str, optional
350355 Names of variables for which to compute the posterior samples. Defaults to all
351356 variables in the posterior.
352357 keep_untransformed : bool, default False
353358 Include untransformed variables in the posterior samples. Defaults to False.
354- chain_method : str, default "parallel"
355- Specify how samples should be drawn. The choices include "parallel", and
356- "vectorized".
357359 postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
358360 Specify how postprocessing should be computed. gpu or cpu
359361 postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
@@ -365,13 +367,17 @@ def _sample_blackjax_nuts(
365367 ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
366368 the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
367369 ``dims`` are provided, they are used to update the inferred dictionaries.
370+ logp_fn : Callable[[ArrayLike], jax.Array] | None:
371+ jaxified logp function. If not passed in it will compute it here.
368372
369373 Returns
370374 -------
371- InferenceData
372- ArviZ ``InferenceData`` object that contains the posterior samples, together
373- with their respective sample stats and pointwise log likeihood values (unless
374- skipped with ``idata_kwargs``).
375+ Tuple containing:
376+ raw_mcmc_samples
377+ Datastructure containing raw mcmc samples
378+ sample_stats : dict[str, Any]
379+ Dictionary containing sample stats
380+ Module("blackjax")
375381 """
376382 import blackjax
377383
@@ -409,7 +415,7 @@ def _sample_blackjax_nuts(
409415
410416
411417# Adopted from arviz numpyro extractor
412- def _numpyro_stats_to_dict (posterior ) :
418+ def _numpyro_stats_to_dict (posterior : MCMC ) -> dict [ str , Any ] :
413419 """Extract sample_stats from NumPyro posterior."""
414420 rename_key = {
415421 "potential_energy" : "lp" ,
@@ -440,8 +446,50 @@ def _sample_numpyro_nuts(
440446 random_seed : int ,
441447 initial_points : np .ndarray | list [np .ndarray ],
442448 nuts_kwargs : dict [str , Any ],
443- logp_fn : Callable | None = None ,
444- ):
449+ logp_fn : Callable [[ArrayLike ], jax .Array ] | None = None ,
450+ ) -> tuple [Any , dict [str , Any ], ModuleType ]:
451+ """
452+ Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
453+
454+ Parameters
455+ ----------
456+ model : Model, optional
457+ Model to sample from. The model needs to have free random variables. When inside
458+ a ``with`` model context, it defaults to that model, otherwise the model must be
459+ passed explicitly.
460+ target_accept : float in [0, 1].
461+ The step size is tuned such that we approximate this acceptance rate. Higher
462+ values like 0.9 or 0.95 often work better for problematic posteriors.
463+ tune : int, default 1000
464+ Number of iterations to tune. Samplers adjust the step sizes, scalings or
465+ similar during tuning. Tuning samples will be drawn in addition to the number
466+ specified in the ``draws`` argument.
467+ draws : int, default 1000
468+ The number of samples to draw. The number of tuned samples are discarded by
469+ default.
470+ chains : int, default 4
471+ The number of chains to sample.
472+ chain_method : str, default "parallel"
473+ Specify how samples should be drawn. The choices include "parallel", and
474+ "vectorized".
475+ progressbar : bool
476+ Whether to show progressbar or not during sampling.
477+ random_seed : int, RandomState or Generator, optional
478+ Random seed used by the sampling steps.
479+ initial_points : np.ndarray | list[np.ndarray]
480+ Initial point(s) for sampler to begin sampling from.
481+ logp_fn : Callable[[ArrayLike], jax.Array] | None:
482+ jaxified logp function. If not passed in it will compute it here.
483+
484+ Returns
485+ -------
486+ Tuple containing:
487+ raw_mcmc_samples
488+ Datastructure containing raw mcmc samples
489+ sample_stats : dict[str, Any]
490+ Dictionary containing sample stats
491+ Module("numpyro")
492+ """
445493 import numpyro
446494
447495 from numpyro .infer import MCMC , NUTS
@@ -505,7 +553,7 @@ def sample_jax_nuts(
505553 nuts_kwargs : dict | None = None ,
506554 progressbar : bool = True ,
507555 keep_untransformed : bool = False ,
508- chain_method : str = "parallel" ,
556+ chain_method : Literal [ "parallel" , "vectorized" ] = "parallel" ,
509557 postprocessing_backend : Literal ["cpu" , "gpu" ] | None = None ,
510558 postprocessing_vectorize : Literal ["vmap" , "scan" ] | None = None ,
511559 postprocessing_chunks = None ,
@@ -551,7 +599,7 @@ def sample_jax_nuts(
551599 If True, display a progressbar while sampling
552600 keep_untransformed : bool, default False
553601 Include untransformed variables in the posterior samples.
554- chain_method : str , default "parallel"
602+ chain_method : Literal["parallel", "vectorized"] , default "parallel"
555603 Specify how samples should be drawn. The choices include "parallel", and
556604 "vectorized".
557605 postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None,
0 commit comments