Skip to content

Commit 2e9d7db

Browse files
author
Goose
committed
improved docstrings & type annotations
1 parent 85996a1 commit 2e9d7db

File tree

1 file changed

+80
-32
lines changed

1 file changed

+80
-32
lines changed

pymc/sampling/jax.py

Lines changed: 80 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from collections.abc import Callable, Sequence
1919
from datetime import datetime
2020
from functools import partial
21-
from typing import Any, Literal
21+
from types import ModuleType
22+
from typing import TYPE_CHECKING, Any, Literal
2223

2324
import arviz as az
2425
import jax
@@ -69,6 +70,9 @@
6970
"sample_numpyro_nuts",
7071
)
7172

73+
if TYPE_CHECKING:
74+
from numpyro.infer import MCMC
75+
7276

7377
@jax_funcify.register(Assert)
7478
@jax_funcify.register(CheckParameterValue)
@@ -310,50 +314,48 @@ def _sample_blackjax_nuts(
310314
tune: int,
311315
draws: int,
312316
chains: int,
313-
chain_method: str | None,
317+
chain_method: Literal["parallel", "vectorized"],
314318
progressbar: bool,
315319
random_seed: int,
316320
initial_points: np.ndarray | list[np.ndarray],
317321
nuts_kwargs,
318-
logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None,
319-
) -> az.InferenceData:
322+
logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
323+
) -> tuple[Any, dict[str, Any], ModuleType]:
320324
"""
321325
Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
322326
323327
Parameters
324328
----------
325-
draws : int, default 1000
326-
The number of samples to draw. The number of tuned samples are discarded by
327-
default.
329+
model : Model, optional
330+
Model to sample from. The model needs to have free random variables. When inside
331+
a ``with`` model context, it defaults to that model, otherwise the model must be
332+
passed explicitly.
333+
target_accept : float in [0, 1].
334+
The step size is tuned such that we approximate this acceptance rate. Higher
335+
values like 0.9 or 0.95 often work better for problematic posteriors.
328336
tune : int, default 1000
329337
Number of iterations to tune. Samplers adjust the step sizes, scalings or
330338
similar during tuning. Tuning samples will be drawn in addition to the number
331339
specified in the ``draws`` argument.
340+
draws : int, default 1000
341+
The number of samples to draw. The number of tuned samples are discarded by
342+
default.
332343
chains : int, default 4
333344
The number of chains to sample.
334-
target_accept : float in [0, 1].
335-
The step size is tuned such that we approximate this acceptance rate. Higher
336-
values like 0.9 or 0.95 often work better for problematic posteriors.
345+
chain_method : str, default "parallel"
346+
Specify how samples should be drawn. The choices include "parallel", and
347+
"vectorized".
348+
progressbar : bool
349+
Whether to show progressbar or not during sampling.
337350
random_seed : int, RandomState or Generator, optional
338351
Random seed used by the sampling steps.
339-
initvals: StartDict or Sequence[Optional[StartDict]], optional
340-
Initial values for random variables provided as a dictionary (or sequence of
341-
dictionaries) mapping the random variable (by name or reference) to desired
342-
starting values.
343-
jitter: bool, default True
344-
If True, add jitter to initial points.
345-
model : Model, optional
346-
Model to sample from. The model needs to have free random variables. When inside
347-
a ``with`` model context, it defaults to that model, otherwise the model must be
348-
passed explicitly.
352+
initial_points : np.ndarray | list[np.ndarray]
353+
Initial point(s) for sampler to begin sampling from.
349354
var_names : sequence of str, optional
350355
Names of variables for which to compute the posterior samples. Defaults to all
351356
variables in the posterior.
352357
keep_untransformed : bool, default False
353358
Include untransformed variables in the posterior samples. Defaults to False.
354-
chain_method : str, default "parallel"
355-
Specify how samples should be drawn. The choices include "parallel", and
356-
"vectorized".
357359
postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
358360
Specify how postprocessing should be computed. gpu or cpu
359361
postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
@@ -365,13 +367,17 @@ def _sample_blackjax_nuts(
365367
``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
366368
the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
367369
``dims`` are provided, they are used to update the inferred dictionaries.
370+
logp_fn : Callable[[ArrayLike], jax.Array] | None:
371+
jaxified logp function. If not passed in it will compute it here.
368372
369373
Returns
370374
-------
371-
InferenceData
372-
ArviZ ``InferenceData`` object that contains the posterior samples, together
373-
with their respective sample stats and pointwise log likeihood values (unless
374-
skipped with ``idata_kwargs``).
375+
Tuple containing:
376+
raw_mcmc_samples
377+
Datastructure containing raw mcmc samples
378+
sample_stats : dict[str, Any]
379+
Dictionary containing sample stats
380+
Module("blackjax")
375381
"""
376382
import blackjax
377383

@@ -409,7 +415,7 @@ def _sample_blackjax_nuts(
409415

410416

411417
# Adopted from arviz numpyro extractor
412-
def _numpyro_stats_to_dict(posterior):
418+
def _numpyro_stats_to_dict(posterior: MCMC) -> dict[str, Any]:
413419
"""Extract sample_stats from NumPyro posterior."""
414420
rename_key = {
415421
"potential_energy": "lp",
@@ -440,8 +446,50 @@ def _sample_numpyro_nuts(
440446
random_seed: int,
441447
initial_points: np.ndarray | list[np.ndarray],
442448
nuts_kwargs: dict[str, Any],
443-
logp_fn: Callable | None = None,
444-
):
449+
logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
450+
) -> tuple[Any, dict[str, Any], ModuleType]:
451+
"""
452+
Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
453+
454+
Parameters
455+
----------
456+
model : Model, optional
457+
Model to sample from. The model needs to have free random variables. When inside
458+
a ``with`` model context, it defaults to that model, otherwise the model must be
459+
passed explicitly.
460+
target_accept : float in [0, 1].
461+
The step size is tuned such that we approximate this acceptance rate. Higher
462+
values like 0.9 or 0.95 often work better for problematic posteriors.
463+
tune : int, default 1000
464+
Number of iterations to tune. Samplers adjust the step sizes, scalings or
465+
similar during tuning. Tuning samples will be drawn in addition to the number
466+
specified in the ``draws`` argument.
467+
draws : int, default 1000
468+
The number of samples to draw. The number of tuned samples are discarded by
469+
default.
470+
chains : int, default 4
471+
The number of chains to sample.
472+
chain_method : str, default "parallel"
473+
Specify how samples should be drawn. The choices include "parallel", and
474+
"vectorized".
475+
progressbar : bool
476+
Whether to show progressbar or not during sampling.
477+
random_seed : int, RandomState or Generator, optional
478+
Random seed used by the sampling steps.
479+
initial_points : np.ndarray | list[np.ndarray]
480+
Initial point(s) for sampler to begin sampling from.
481+
logp_fn : Callable[[ArrayLike], jax.Array] | None:
482+
jaxified logp function. If not passed in it will compute it here.
483+
484+
Returns
485+
-------
486+
Tuple containing:
487+
raw_mcmc_samples
488+
Datastructure containing raw mcmc samples
489+
sample_stats : dict[str, Any]
490+
Dictionary containing sample stats
491+
Module("numpyro")
492+
"""
445493
import numpyro
446494

447495
from numpyro.infer import MCMC, NUTS
@@ -505,7 +553,7 @@ def sample_jax_nuts(
505553
nuts_kwargs: dict | None = None,
506554
progressbar: bool = True,
507555
keep_untransformed: bool = False,
508-
chain_method: str = "parallel",
556+
chain_method: Literal["parallel", "vectorized"] = "parallel",
509557
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
510558
postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
511559
postprocessing_chunks=None,
@@ -551,7 +599,7 @@ def sample_jax_nuts(
551599
If True, display a progressbar while sampling
552600
keep_untransformed : bool, default False
553601
Include untransformed variables in the posterior samples.
554-
chain_method : str, default "parallel"
602+
chain_method : Literal["parallel", "vectorized"], default "parallel"
555603
Specify how samples should be drawn. The choices include "parallel", and
556604
"vectorized".
557605
postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None,

0 commit comments

Comments
 (0)