diff --git a/pymc/initial_point.py b/pymc/initial_point.py
index 241409f68..da55fc381 100644
--- a/pymc/initial_point.py
+++ b/pymc/initial_point.py
@@ -26,6 +26,7 @@
 from pymc.logprob.transforms import Transform
 from pymc.pytensorf import (
+    SeedSequenceSeed,
     compile,
     find_rng_nodes,
     replace_rng_nodes,
@@ -67,7 +68,7 @@ def make_initial_point_fns_per_chain(
     overrides: StartDict | Sequence[StartDict | None] | None,
     jitter_rvs: set[TensorVariable] | None = None,
     chains: int,
-) -> list[Callable]:
+) -> list[Callable[[SeedSequenceSeed], PointType]]:
     """Create an initial point function for each chain, as defined by initvals.
 
     If a single initval dictionary is passed, the function is replicated for each
@@ -82,6 +83,11 @@ def make_initial_point_fns_per_chain(
         Random variable tensors for which U(-1, 1) jitter shall be applied.
         (To the transformed space if applicable.)
 
+    Returns
+    -------
+    ipfns : list[Callable[[SeedSequenceSeed], dict[str, np.ndarray]]]
+        List of functions that return initial points for each chain.
+
     Raises
     ------
     ValueError
@@ -124,7 +130,7 @@ def make_initial_point_fn(
     jitter_rvs: set[TensorVariable] | None = None,
     default_strategy: str = "support_point",
     return_transformed: bool = True,
-) -> Callable:
+) -> Callable[[SeedSequenceSeed], PointType]:
     """Create seeded function that computes initial values for all free model variables.
 
     Parameters
@@ -138,6 +144,10 @@ def make_initial_point_fn(
         Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
     return_transformed : bool
         If `True` the returned variables will correspond to transformed initial values.
+
+    Returns
+    -------
+    initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
     """
     sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
     initval_strats = {
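The tightened annotations above make the contract concrete: an initial point function maps a seed to a `{variable name: np.ndarray}` dict. A minimal sketch of how such a function is obtained and called; the one-variable model here is hypothetical:

```python
import numpy as np
import pymc as pm

from pymc.initial_point import make_initial_point_fn

with pm.Model() as model:
    x = pm.Normal("x", 0.0, 1.0)

# Callable[[SeedSequenceSeed], PointType]: pass a seed, get a point dict back.
ipfn = make_initial_point_fn(model=model)
point = ipfn(42)  # e.g. {"x": array(...)}
assert all(isinstance(value, np.ndarray) for value in point.values())
```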
diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 43e1baa87..89adbf233 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -18,6 +18,7 @@
 from collections.abc import Callable, Sequence
 from datetime import datetime
 from functools import partial
+from types import ModuleType
 from typing import Any, Literal
 
 import arviz as az
@@ -28,6 +29,7 @@
 from arviz.data.base import make_attrs
 from jax.lax import scan
+from numpy.typing import ArrayLike
 from pytensor.compile import SharedVariable, Supervisor, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
@@ -121,7 +123,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariabl
 def get_jaxified_graph(
     inputs: list[TensorVariable] | None = None,
     outputs: list[TensorVariable] | None = None,
-) -> list[TensorVariable]:
+) -> Callable[[list[TensorVariable]], list[TensorVariable]]:
     """Compile a PyTensor graph into an optimized JAX function."""
     graph = _replace_shared_variables(outputs) if outputs is not None else None
@@ -144,13 +146,13 @@ def get_jaxified_graph(
     return jax_funcify(fgraph)
 
 
-def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
+def get_jaxified_logp(model: Model, negative_logp: bool = True) -> Callable[[ArrayLike], jax.Array]:
     model_logp = model.logp()
     if not negative_logp:
         model_logp = -model_logp
     logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
 
-    def logp_fn_wrap(x):
+    def logp_fn_wrap(x: ArrayLike) -> jax.Array:
         return logp_fn(*x)[0]
 
     return logp_fn_wrap
@@ -211,10 +213,16 @@ def _get_batched_jittered_initial_points(
     chains: int,
     initvals: StartDict | Sequence[StartDict | None] | None,
     random_seed: RandomSeed,
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
     jitter: bool = True,
     jitter_max_retries: int = 10,
 ) -> np.ndarray | list[np.ndarray]:
-    """Get jittered initial point in format expected by NumPyro MCMC kernel.
+    """Get jittered initial point in the format expected by a JAX MCMC kernel.
+
+    Parameters
+    ----------
+    logp_fn : Callable[[Sequence[np.ndarray]], np.ndarray]
+        Jaxified logp function.
 
     Returns
     -------
@@ -222,12 +230,26 @@ def _get_batched_jittered_initial_points(
         list with one item per variable and number of chains as batch dimension.
         Each item has shape `(chains, *var.shape)`
     """
+    if logp_fn is None:
+        eval_logp_initial_point = None
+
+    else:
+
+        def eval_logp_initial_point(point: dict[str, np.ndarray]) -> jax.Array:
+            """Wrap logp_fn to conform to _init_jitter logic.
+
+            Wraps the jaxified logp function so it accepts a dict of
+            {model_variable: np.ndarray} pairs as input.
+            """
+            return logp_fn(point.values())
+
     initial_points = _init_jitter(
         model,
         initvals,
         seeds=_get_seeds_per_chain(random_seed, chains),
         jitter=jitter,
         jitter_max_retries=jitter_max_retries,
+        logp_fn=eval_logp_initial_point,
     )
     initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
     if chains == 1:
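`eval_logp_initial_point` is a thin adapter: `_init_jitter` scores candidate points as name-to-value dicts, while the jaxified function takes the value arrays as a positional sequence, so the wrapper relies on the point dict preserving the order of `model.value_vars`. A standalone sketch of the same pattern with a toy function; all names here are illustrative only:

```python
from collections.abc import Callable

import numpy as np


def positional_logp(*values: np.ndarray) -> np.ndarray:
    """Stand-in for a compiled logp that takes value arrays positionally."""
    return -0.5 * sum((v**2).sum() for v in values)


def as_point_fn(fn: Callable[..., np.ndarray]) -> Callable[[dict[str, np.ndarray]], np.ndarray]:
    """Adapt a positional logp to accept {name: value} point dicts."""

    def point_fn(point: dict[str, np.ndarray]) -> np.ndarray:
        # dicts preserve insertion order, so the values line up with the
        # positional arguments of the wrapped function.
        return fn(*point.values())

    return point_fn


point_logp = as_point_fn(positional_logp)
point_logp({"x": np.array([0.5]), "y": np.array([1.0, -1.0])})
```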
@@ -236,7 +258,7 @@ def _get_batched_jittered_initial_points(
 
 
 def _blackjax_inference_loop(
-    seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
+    seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs
 ):
     import blackjax
 
@@ -252,13 +274,13 @@ def _blackjax_inference_loop(
 
     adapt = blackjax.window_adaptation(
         algorithm=algorithm,
-        logdensity_fn=logprob_fn,
+        logdensity_fn=logp_fn,
         target_acceptance_rate=target_accept,
         adaptation_info_fn=get_filter_adapt_info_fn(),
         **adaptation_kwargs,
     )
     (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
-    kernel = algorithm(logprob_fn, **tuned_params).step
+    kernel = algorithm(logp_fn, **tuned_params).step
 
     def _one_step(state, xs):
         _, rng_key = xs
@@ -289,67 +311,51 @@ def _sample_blackjax_nuts(
     tune: int,
     draws: int,
     chains: int,
-    chain_method: str | None,
+    chain_method: Literal["parallel", "vectorized"],
     progressbar: bool,
     random_seed: int,
-    initial_points,
-    nuts_kwargs,
-) -> az.InferenceData:
+    initial_points: np.ndarray | list[np.ndarray],
+    nuts_kwargs: dict[str, Any],
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
+) -> tuple[Any, dict[str, Any], ModuleType]:
     """
     Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
 
     Parameters
     ----------
-    draws : int, default 1000
-        The number of samples to draw. The number of tuned samples are discarded by
-        default.
-    tune : int, default 1000
+    model : Model
+        Model to sample from. The model needs to have free random variables.
+    target_accept : float in [0, 1].
+        The step size is tuned such that we approximate this acceptance rate. Higher
+        values like 0.9 or 0.95 often work better for problematic posteriors.
+    tune : int
         Number of iterations to tune. Samplers adjust the step sizes, scalings or
         similar during tuning. Tuning samples will be drawn in addition to the number
         specified in the ``draws`` argument.
-    chains : int, default 4
+    draws : int
+        The number of samples to draw. The number of tuned samples are discarded by default.
+    chains : int
         The number of chains to sample.
-    target_accept : float in [0, 1].
-        The step size is tuned such that we approximate this acceptance rate. Higher
-        values like 0.9 or 0.95 often work better for problematic posteriors.
-    random_seed : int, RandomState or Generator, optional
+    chain_method : "parallel" or "vectorized"
+        Specify how samples should be drawn.
+    progressbar : bool
+        Whether to show the progress bar during sampling.
+    random_seed : int, RandomState or Generator
         Random seed used by the sampling steps.
-    initvals: StartDict or Sequence[Optional[StartDict]], optional
-        Initial values for random variables provided as a dictionary (or sequence of
-        dictionaries) mapping the random variable (by name or reference) to desired
-        starting values.
-    jitter: bool, default True
-        If True, add jitter to initial points.
-    model : Model, optional
-        Model to sample from. The model needs to have free random variables. When inside
-        a ``with`` model context, it defaults to that model, otherwise the model must be
-        passed explicitly.
-    var_names : sequence of str, optional
-        Names of variables for which to compute the posterior samples. Defaults to all
-        variables in the posterior.
-    keep_untransformed : bool, default False
-        Include untransformed variables in the posterior samples. Defaults to False.
-    chain_method : str, default "parallel"
-        Specify how samples should be drawn. The choices include "parallel", and
-        "vectorized".
-    postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
-        Specify how postprocessing should be computed. gpu or cpu
-    postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
-        How to vectorize the postprocessing: vmap or sequential scan
-    idata_kwargs : dict, optional
-        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
-        value for the ``log_likelihood`` key to indicate that the pointwise log
-        likelihood should not be included in the returned object. Values for
-        ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
-        the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
-        ``dims`` are provided, they are used to update the inferred dictionaries.
+    initial_points : np.ndarray or list[np.ndarray]
+        Initial point(s) from which the sampler begins sampling.
+    nuts_kwargs : dict
+        Keyword arguments for the blackjax NUTS sampler.
+    logp_fn : Callable[[ArrayLike], jax.Array], optional, default None
+        Jaxified logp function. If not provided, it is created from the model.
 
     Returns
     -------
-    InferenceData
-        ArviZ ``InferenceData`` object that contains the posterior samples, together
-        with their respective sample stats and pointwise log likeihood values (unless
-        skipped with ``idata_kwargs``).
+    raw_mcmc_samples
+        Data structure containing the raw MCMC samples.
+    sample_stats : dict[str, Any]
+        Dictionary containing sample stats.
+    blackjax : ModuleType["blackjax"]
     """
     import blackjax
 
@@ -366,7 +372,8 @@ def _sample_blackjax_nuts(
     if chains == 1:
         initial_points = [np.stack(init_state) for init_state in zip(initial_points)]
 
-    logprob_fn = get_jaxified_logp(model)
+    if logp_fn is None:
+        logp_fn = get_jaxified_logp(model)
 
     seed = jax.random.PRNGKey(random_seed)
     keys = jax.random.split(seed, chains)
@@ -374,7 +381,7 @@ def _sample_blackjax_nuts(
     nuts_kwargs["progress_bar"] = progressbar
     get_posterior_samples = partial(
         _blackjax_inference_loop,
-        logprob_fn=logprob_fn,
+        logp_fn=logp_fn,
         tune=tune,
         draws=draws,
         target_accept=target_accept,
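For readers unfamiliar with the loop being renamed here, this is the adapt-then-scan pattern that `_blackjax_inference_loop` implements: window adaptation tunes the step size and mass matrix, then a `jax.lax.scan` drives the tuned kernel. A minimal self-contained sketch against a toy log-density, assuming the blackjax 1.x API:

```python
import blackjax
import jax
import jax.numpy as jnp


def logdensity_fn(x):
    """Toy target: standard normal log-density."""
    return -0.5 * jnp.sum(x**2)


rng_key = jax.random.PRNGKey(0)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn, target_acceptance_rate=0.8)
(last_state, tuned_params), _ = adapt.run(rng_key, jnp.zeros(2), num_steps=500)
kernel = blackjax.nuts(logdensity_fn, **tuned_params).step


def one_step(state, key):
    state, _ = kernel(key, state)
    return state, state.position


keys = jax.random.split(jax.random.PRNGKey(1), 1000)
_, positions = jax.lax.scan(one_step, last_state, keys)  # (1000, 2) draws
```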
+    raw_mcmc_samples
+        Data structure containing the raw MCMC samples.
+    sample_stats : dict[str, Any]
+        Dictionary containing sample stats.
+    numpyro : ModuleType["numpyro"]
+    """
     import numpyro
 
     from numpyro.infer import MCMC, NUTS
 
-    logp_fn = get_jaxified_logp(model, negative_logp=False)
+    if logp_fn is None:
+        logp_fn = get_jaxified_logp(model, negative_logp=False)
 
     nuts_kwargs.setdefault("adapt_step_size", True)
     nuts_kwargs.setdefault("adapt_mass_matrix", True)
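The sign convention explains the two `get_jaxified_logp` calls in this PR: numpyro's `NUTS(potential_fn=...)` expects a potential energy, i.e. the negative log density, hence `negative_logp=False` on the numpyro path, while blackjax consumes the log density directly. A hedged sketch of the numpyro side with a toy potential:

```python
import jax
import jax.numpy as jnp

from numpyro.infer import MCMC, NUTS


def potential_fn(x):
    """Negative log density of a standard normal (the potential energy)."""
    return 0.5 * jnp.sum(x**2)


kernel = NUTS(potential_fn=potential_fn)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000, num_chains=1, progress_bar=False)
# With a raw potential_fn there is no model to infer shapes from, so
# initial parameters must be supplied explicitly.
mcmc.run(jax.random.PRNGKey(0), init_params=jnp.zeros(2))
samples = mcmc.get_samples()
```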
+
+    Returns
+    -------
+    raw_mcmc_samples
+        Data structure containing the raw MCMC samples.
+    sample_stats : dict[str, Any]
+        Dictionary containing sample stats.
+    numpyro : ModuleType["numpyro"]
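Hoisting the `logp_fn` construction ahead of `_get_batched_jittered_initial_points` means the model graph is jaxified once and shared between the jittered initial-point checks and the NUTS kernel, instead of being compiled twice. From the user's side nothing changes; a typical call, assuming `numpyro` is installed:

```python
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.3, 0.7])

    # Internally this now builds the jaxified logp once and reuses it
    # for both jitter evaluation and sampling.
    idata = pm.sample(nuts_sampler="numpyro", tune=500, draws=500, chains=2)
```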
 
     Returns
     -------
-    start : ``pymc.model.Point``
-        Starting point for sampler
+    initial_points : list[dict[str, np.ndarray]]
+        List of starting points for the sampler.
     """
     ipfns = make_initial_point_fns_per_chain(
         model=model,
@@ -1369,14 +1372,10 @@ def _init_jitter(
     if not jitter:
         return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]
 
-    model_logp_fn: Callable
-    if logp_dlogp_func is None:
-        model_logp_fn = model.compile_logp()
+    if logp_fn is None:
+        model_logp_fn: Callable[[PointType], np.ndarray] = model.compile_logp()
     else:
-
-        def model_logp_fn(ip):
-            q, _ = DictToArrayBijection.map(ip)
-            return logp_dlogp_func([q], extra_vars={})[0]
+        model_logp_fn = logp_fn
 
     initial_points = []
     for ipfn, seed in zip(ipfns, seeds):
@@ -1501,13 +1500,18 @@ def init_nuts(
     logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs)
     logp_dlogp_func.trust_input = True
+
+    def model_logp_fn(ip: PointType) -> np.ndarray:
+        q, _ = DictToArrayBijection.map(ip)
+        return logp_dlogp_func([q], extra_vars={})[0]
+
     initial_points = _init_jitter(
         model,
         initvals,
         seeds=random_seed_list,
         jitter="jitter" in init,
         jitter_max_retries=jitter_max_retries,
-        logp_dlogp_func=logp_dlogp_func,
+        logp_fn=model_logp_fn,
     )
 
     apoints = [DictToArrayBijection.map(point) for point in initial_points]
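The `model_logp_fn` closure added to `init_nuts` bridges the two interfaces: initial points live as dicts, while `logp_dlogp_func` with `ravel_inputs=True` wants a single flat vector. `DictToArrayBijection` does that packing; a small sketch of the round trip, with hypothetical variable names:

```python
import numpy as np

from pymc.blocking import DictToArrayBijection

point = {"mu": np.array(0.5), "sigma_log__": np.array([0.1, -0.2])}

# map() ravels the dict into RaveledVars: a flat data vector plus the
# shape/dtype info needed to undo the mapping.
raveled = DictToArrayBijection.map(point)
print(raveled.data)  # [ 0.5  0.1 -0.2]

# rmap() reverses it, recovering the original point dict.
roundtrip = DictToArrayBijection.rmap(raveled)
```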