14 changes: 12 additions & 2 deletions pymc/initial_point.py
@@ -26,6 +26,7 @@

from pymc.logprob.transforms import Transform
from pymc.pytensorf import (
SeedSequenceSeed,
compile,
find_rng_nodes,
replace_rng_nodes,
@@ -67,7 +68,7 @@ def make_initial_point_fns_per_chain(
overrides: StartDict | Sequence[StartDict | None] | None,
jitter_rvs: set[TensorVariable] | None = None,
chains: int,
) -> list[Callable]:
) -> list[Callable[[SeedSequenceSeed], PointType]]:
"""Create an initial point function for each chain, as defined by initvals.

If a single initval dictionary is passed, the function is replicated for each
@@ -82,6 +83,11 @@
Random variable tensors for which U(-1, 1) jitter shall be applied.
(To the transformed space if applicable.)

Returns
-------
ipfns : list[Callable[[SeedSequenceSeed], dict[str, np.ndarray]]]
List of functions that return an initial point for each chain.

Raises
------
ValueError
@@ -124,7 +130,7 @@ def make_initial_point_fn(
jitter_rvs: set[TensorVariable] | None = None,
default_strategy: str = "support_point",
return_transformed: bool = True,
) -> Callable:
) -> Callable[[SeedSequenceSeed], PointType]:
"""Create seeded function that computes initial values for all free model variables.

Parameters
@@ -138,6 +144,10 @@
Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
return_transformed : bool
If `True` the returned variables will correspond to transformed initial values.

Returns
-------
initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
Seeded function that maps a random seed to a dict of initial values.
"""
sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
initval_strats = {
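The new `Callable[[SeedSequenceSeed], PointType]` annotations pin down the calling convention: each function maps a seed to a `{value-var name: np.ndarray}` point. A minimal usage sketch (the one-variable model and the integer seeds are illustrative assumptions, not part of this diff):

```python
import pymc as pm

from pymc.initial_point import make_initial_point_fns_per_chain

with pm.Model() as model:
    x = pm.Normal("x")

# One function per chain; each accepts a SeedSequenceSeed (e.g. an int).
ipfns = make_initial_point_fns_per_chain(
    model=model,
    overrides=None,
    jitter_rvs=set(),
    chains=4,
)
points = [ipfn(seed) for ipfn, seed in zip(ipfns, [0, 1, 2, 3])]
# points[0] is a dict like {"x": array(0.)}
```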
189 changes: 120 additions & 69 deletions pymc/sampling/jax.py
@@ -18,6 +18,7 @@
from collections.abc import Callable, Sequence
from datetime import datetime
from functools import partial
from types import ModuleType
from typing import Any, Literal

import arviz as az
@@ -28,6 +29,7 @@

from arviz.data.base import make_attrs
from jax.lax import scan
from numpy.typing import ArrayLike
from pytensor.compile import SharedVariable, Supervisor, mode
from pytensor.graph.basic import graph_inputs
from pytensor.graph.fg import FunctionGraph
@@ -121,7 +123,7 @@
def get_jaxified_graph(
inputs: list[TensorVariable] | None = None,
outputs: list[TensorVariable] | None = None,
) -> list[TensorVariable]:
) -> Callable[[list[TensorVariable]], list[TensorVariable]]:
"""Compile a PyTensor graph into an optimized JAX function."""
graph = _replace_shared_variables(outputs) if outputs is not None else None

@@ -144,13 +146,13 @@
return jax_funcify(fgraph)


def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
def get_jaxified_logp(model: Model, negative_logp: bool = True) -> Callable[[ArrayLike], jax.Array]:
model_logp = model.logp()
if not negative_logp:
model_logp = -model_logp
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

def logp_fn_wrap(x):
def logp_fn_wrap(x: ArrayLike) -> jax.Array:
return logp_fn(*x)[0]

return logp_fn_wrap
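
With the sharpened return type, `get_jaxified_logp` documents that the wrapper takes a single sequence of arrays (one per value variable, ordered like `model.value_vars`) and returns a JAX scalar. A usage sketch under that reading (the toy model is an assumption for illustration; the sign of the result follows the `negative_logp` flag as implemented above):

```python
import pymc as pm

from pymc.sampling.jax import get_jaxified_logp

with pm.Model() as model:
    x = pm.Normal("x")

logp_fn = get_jaxified_logp(model)

point = model.initial_point()  # {"x": array(0.)}
value = logp_fn(list(point.values()))  # a jax.Array scalar
```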
@@ -211,23 +213,43 @@
chains: int,
initvals: StartDict | Sequence[StartDict | None] | None,
random_seed: RandomSeed,
logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
jitter: bool = True,
jitter_max_retries: int = 10,
) -> np.ndarray | list[np.ndarray]:
"""Get jittered initial point in format expected by NumPyro MCMC kernel.
"""Get jittered initial point in format expected by Jax MCMC kernel.

Parameters
----------
logp_fn : Callable[[Sequence[np.ndarray]], np.ndarray], optional
Jaxified logp function.

Returns
-------
out : list of ndarrays
List with one item per variable, with the number of chains as the batch dimension.
Each item has shape `(chains, *var.shape)`.
"""
if logp_fn is None:
eval_logp_initial_point = None

else:

def eval_logp_initial_point(point: dict[str, np.ndarray]) -> jax.Array:
"""Wrap logp_fn to conform to _init_jitter logic.

Wraps the jaxified logp function to accept a dict of
{model_variable: np.ndarray} key-value pairs.
"""
return logp_fn(point.values())

initial_points = _init_jitter(
model,
initvals,
seeds=_get_seeds_per_chain(random_seed, chains),
jitter=jitter,
jitter_max_retries=jitter_max_retries,
logp_fn=eval_logp_initial_point,
)
initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
if chains == 1:
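
The multi-chain branch of `_get_batched_jittered_initial_points` is collapsed in this hunk, but the Returns section above fixes the contract: one array per variable with `chains` as the leading dimension. The batching can be pictured with plain NumPy (hypothetical values, assuming the hidden branch stacks per variable across chains):

```python
import numpy as np

# Two hypothetical chains over a scalar x and a length-3 vector y.
chain_points = [
    {"x": np.array(0.1), "y": np.zeros(3)},  # chain 0
    {"x": np.array(-0.2), "y": np.ones(3)},  # chain 1
]

values = [list(p.values()) for p in chain_points]
batched = [np.stack(per_var) for per_var in zip(*values)]

assert batched[0].shape == (2,)    # x: (chains,)
assert batched[1].shape == (2, 3)  # y: (chains, 3)
```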
@@ -236,7 +258,7 @@


def _blackjax_inference_loop(
seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs
):
import blackjax

@@ -252,13 +274,13 @@

adapt = blackjax.window_adaptation(
algorithm=algorithm,
logdensity_fn=logprob_fn,
logdensity_fn=logp_fn,
target_acceptance_rate=target_accept,
adaptation_info_fn=get_filter_adapt_info_fn(),
**adaptation_kwargs,
)
(last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
kernel = algorithm(logprob_fn, **tuned_params).step
kernel = algorithm(logp_fn, **tuned_params).step

def _one_step(state, xs):
_, rng_key = xs
@@ -289,67 +311,51 @@
tune: int,
draws: int,
chains: int,
chain_method: str | None,
chain_method: Literal["parallel", "vectorized"],
progressbar: bool,
random_seed: int,
initial_points,
nuts_kwargs,
) -> az.InferenceData:
initial_points: np.ndarray | list[np.ndarray],
nuts_kwargs: dict[str, Any],
logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
) -> tuple[Any, dict[str, Any], ModuleType]:
"""
Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

Parameters
----------
draws : int, default 1000
The number of samples to draw. The number of tuned samples are discarded by
default.
tune : int, default 1000
model : Model
Model to sample from. The model needs to have free random variables.
target_accept : float in [0, 1].
The step size is tuned such that we approximate this acceptance rate. Higher
values like 0.9 or 0.95 often work better for problematic posteriors.
tune : int
Number of iterations to tune. Samplers adjust the step sizes, scalings or
similar during tuning. Tuning samples will be drawn in addition to the number
specified in the ``draws`` argument.
chains : int, default 4
draws : int
The number of samples to draw. Tuned samples are discarded by default.
chains : int
The number of chains to sample.
target_accept : float in [0, 1].
The step size is tuned such that we approximate this acceptance rate. Higher
values like 0.9 or 0.95 often work better for problematic posteriors.
random_seed : int, RandomState or Generator, optional
chain_method : "parallel" or "vectorized"
Specify how samples should be drawn.
progressbar : bool
Whether to show progressbar or not during sampling.
random_seed : int, RandomState or Generator
Random seed used by the sampling steps.
initvals: StartDict or Sequence[Optional[StartDict]], optional
Initial values for random variables provided as a dictionary (or sequence of
dictionaries) mapping the random variable (by name or reference) to desired
starting values.
jitter: bool, default True
If True, add jitter to initial points.
model : Model, optional
Model to sample from. The model needs to have free random variables. When inside
a ``with`` model context, it defaults to that model, otherwise the model must be
passed explicitly.
var_names : sequence of str, optional
Names of variables for which to compute the posterior samples. Defaults to all
variables in the posterior.
keep_untransformed : bool, default False
Include untransformed variables in the posterior samples. Defaults to False.
chain_method : str, default "parallel"
Specify how samples should be drawn. The choices include "parallel", and
"vectorized".
postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
Specify how postprocessing should be computed. gpu or cpu
postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
How to vectorize the postprocessing: vmap or sequential scan
idata_kwargs : dict, optional
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
value for the ``log_likelihood`` key to indicate that the pointwise log
likelihood should not be included in the returned object. Values for
``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
``dims`` are provided, they are used to update the inferred dictionaries.
initial_points : np.ndarray or list[np.ndarray]
Initial point(s) from which the sampler begins sampling.
nuts_kwargs : dict
Keyword arguments for the ``blackjax`` NUTS sampler.
logp_fn : Callable[[ArrayLike], jax.Array], optional
Jaxified logp function. If not provided, it is created from the model.

Returns
-------
InferenceData
ArviZ ``InferenceData`` object that contains the posterior samples, together
with their respective sample stats and pointwise log likeihood values (unless
skipped with ``idata_kwargs``).
raw_mcmc_samples
Data structure containing the raw MCMC samples.
sample_stats : dict[str, Any]
Dictionary containing sample stats.
blackjax : ModuleType
The imported ``blackjax`` module.
"""
import blackjax

@@ -366,15 +372,16 @@
if chains == 1:
initial_points = [np.stack(init_state) for init_state in zip(initial_points)]

logprob_fn = get_jaxified_logp(model)
if logp_fn is None:
logp_fn = get_jaxified_logp(model)

seed = jax.random.PRNGKey(random_seed)
keys = jax.random.split(seed, chains)

nuts_kwargs["progress_bar"] = progressbar
get_posterior_samples = partial(
_blackjax_inference_loop,
logprob_fn=logprob_fn,
logp_fn=logp_fn,
tune=tune,
draws=draws,
target_accept=target_accept,
@@ -386,7 +393,7 @@


# Adopted from arviz numpyro extractor
def _numpyro_stats_to_dict(posterior):
def _numpyro_stats_to_dict(posterior) -> dict[str, Any]:
"""Extract sample_stats from NumPyro posterior."""
rename_key = {
"potential_energy": "lp",
@@ -412,17 +419,58 @@
tune: int,
draws: int,
chains: int,
chain_method: str | None,
chain_method: Literal["parallel", "vectorized"],
progressbar: bool,
random_seed: int,
initial_points,
initial_points: np.ndarray | list[np.ndarray],
nuts_kwargs: dict[str, Any],
):
logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
) -> tuple[Any, dict[str, Any], ModuleType]:
"""
Draw samples from the posterior using the NUTS method from the ``numpyro`` library.

Parameters
----------
model : Model
Model to sample from. The model needs to have free random variables.
target_accept : float in [0, 1].
The step size is tuned such that we approximate this acceptance rate. Higher
values like 0.9 or 0.95 often work better for problematic posteriors.
tune : int
Number of iterations to tune. Samplers adjust the step sizes, scalings or
similar during tuning. Tuning samples will be drawn in addition to the number
specified in the ``draws`` argument.
draws : int
The number of samples to draw. Tuned samples are discarded by default.
chains : int
The number of chains to sample.
chain_method : "parallel" or "vectorized"
Specify how samples should be drawn.
progressbar : bool
Whether to show progressbar or not during sampling.
random_seed : int, RandomState or Generator
Random seed used by the sampling steps.
initial_points : np.ndarray or list[np.ndarray]
Initial point(s) from which the sampler begins sampling.
nuts_kwargs : dict
Keyword arguments for the underlying ``numpyro`` NUTS sampler.
logp_fn : Callable[[ArrayLike], jax.Array], optional
Jaxified logp function. If not provided, it is created from the model.

Returns
-------
raw_mcmc_samples
Data structure containing the raw MCMC samples.
sample_stats : dict[str, Any]
Dictionary containing sample stats.
numpyro : ModuleType
The imported ``numpyro`` module.
"""
import numpyro

from numpyro.infer import MCMC, NUTS

logp_fn = get_jaxified_logp(model, negative_logp=False)
if logp_fn is None:
logp_fn = get_jaxified_logp(model, negative_logp=False)

nuts_kwargs.setdefault("adapt_step_size", True)
nuts_kwargs.setdefault("adapt_mass_matrix", True)
@@ -480,7 +528,7 @@
nuts_kwargs: dict | None = None,
progressbar: bool = True,
keep_untransformed: bool = False,
chain_method: str = "parallel",
chain_method: Literal["parallel", "vectorized"] = "parallel",
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
postprocessing_chunks=None,
@@ -526,7 +574,7 @@
If True, display a progressbar while sampling
keep_untransformed : bool, default False
Include untransformed variables in the posterior samples.
chain_method : str, default "parallel"
chain_method : Literal["parallel", "vectorized"], default "parallel"
Specify how samples should be drawn. The choices are "parallel" and
"vectorized".
postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None,
@@ -590,6 +638,15 @@
get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
)

if nuts_sampler == "numpyro":
sampler_fn = _sample_numpyro_nuts
logp_fn = get_jaxified_logp(model, negative_logp=False)
elif nuts_sampler == "blackjax":
sampler_fn = _sample_blackjax_nuts
logp_fn = get_jaxified_logp(model)
else:
raise ValueError(f"{nuts_sampler=} not recognized")

(random_seed,) = _get_seeds_per_chain(random_seed, 1)

initial_points = _get_batched_jittered_initial_points(
@@ -598,15 +655,9 @@
initvals=initvals,
random_seed=random_seed,
jitter=jitter,
logp_fn=logp_fn,
)

if nuts_sampler == "numpyro":
sampler_fn = _sample_numpyro_nuts
elif nuts_sampler == "blackjax":
sampler_fn = _sample_blackjax_nuts
else:
raise ValueError(f"{nuts_sampler=} not recognized")

tic1 = datetime.now()
raw_mcmc_samples, sample_stats, library = sampler_fn(
model=model,
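Taken together, the refactor builds the jaxified logp once, hands it to `_get_batched_jittered_initial_points` so jittered starting points are validated with the same function the sampler uses, and passes it on to the chosen sampler. User-facing behavior is unchanged; an end-to-end sketch (assumes `numpyro` is installed; the model is illustrative):

```python
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu")
    pm.Normal("obs", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

    # Routes through the JAX path edited above.
    idata = pm.sample(
        draws=500,
        tune=500,
        chains=2,
        nuts_sampler="numpyro",
    )
```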