Skip to content
14 changes: 12 additions & 2 deletions pymc/initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from pymc.logprob.transforms import Transform
from pymc.pytensorf import (
SeedSequenceSeed,
compile,
find_rng_nodes,
replace_rng_nodes,
Expand Down Expand Up @@ -67,7 +68,7 @@ def make_initial_point_fns_per_chain(
overrides: StartDict | Sequence[StartDict | None] | None,
jitter_rvs: set[TensorVariable] | None = None,
chains: int,
) -> list[Callable]:
) -> list[Callable[[SeedSequenceSeed], PointType]]:
"""Create an initial point function for each chain, as defined by initvals.

If a single initval dictionary is passed, the function is replicated for each
Expand All @@ -82,6 +83,11 @@ def make_initial_point_fns_per_chain(
Random variable tensors for which U(-1, 1) jitter shall be applied.
(To the transformed space if applicable.)

Returns
-------
ipfns : list[Callable[[SeedSequenceSeed], dict[str, np.ndarray]]]
list of functions that return initial points for each chain.

Raises
------
ValueError
Expand Down Expand Up @@ -124,7 +130,7 @@ def make_initial_point_fn(
jitter_rvs: set[TensorVariable] | None = None,
default_strategy: str = "support_point",
return_transformed: bool = True,
) -> Callable:
) -> Callable[[SeedSequenceSeed], PointType]:
"""Create seeded function that computes initial values for all free model variables.

Parameters
Expand All @@ -138,6 +144,10 @@ def make_initial_point_fn(
Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
return_transformed : bool
If `True` the returned variables will correspond to transformed initial values.

Returns
-------
initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]
"""
sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
initval_strats = {
Expand Down
4 changes: 2 additions & 2 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import types
import warnings

from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
from typing import (
Literal,
cast,
Expand Down Expand Up @@ -585,7 +585,7 @@ def compile_logp(
jacobian: bool = True,
sum: bool = True,
**compile_kwargs,
) -> PointFunc:
) -> Callable[[PointType], np.ndarray]:
"""Compiled log probability density function.

Parameters
Expand Down
67 changes: 48 additions & 19 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,15 @@ def get_jaxified_graph(
return jax_funcify(fgraph)


def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
def get_jaxified_logp(
model: Model, negative_logp=True
) -> Callable[[Sequence[np.ndarray]], np.ndarray]:
model_logp = model.logp()
if not negative_logp:
model_logp = -model_logp
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

def logp_fn_wrap(x):
def logp_fn_wrap(x: Sequence[np.ndarray]) -> np.ndarray:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not correct, it takes jax arrays and outputs jax arrays

Copy link
Contributor Author

@nataziel nataziel Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's 100% true. Checking with the interactive debugger confirms that the return type is jax.Array, but the initial point functions return a dict[str, np.ndarray], and we can successfully pass the .values() of that dict into the jaxified function. So it can seemingly accept anything that's coercible to an array. Maybe it's more correct to annotate it like this:

def logp_fn_wrap(x: ArrayLike) -> jax.Array:

ArrayLike is from numpy.typing: https://numpy.org/devdocs/reference/typing.html#numpy.typing.ArrayLike

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just pushed a commit to improve this, it's a bit tricky to annotate at the interface with _init_jitter given that jax is an optional dependency. I've left the type annotation as returning a np.ndarray but included that it may return a jax.Array in the docstring.

return logp_fn(*x)[0]

return logp_fn_wrap
Expand Down Expand Up @@ -211,23 +213,43 @@ def _get_batched_jittered_initial_points(
chains: int,
initvals: StartDict | Sequence[StartDict | None] | None,
random_seed: RandomSeed,
logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None,
jitter: bool = True,
jitter_max_retries: int = 10,
) -> np.ndarray | list[np.ndarray]:
"""Get jittered initial point in format expected by NumPyro MCMC kernel.
"""Get jittered initial point in format expected by Jax MCMC kernel.

Parameters
----------
logp_fn : Callable[Sequence[np.ndarray]], np.ndarray]
Jaxified logp function

Returns
-------
out: list of ndarrays
out: list[np.ndarray]
list with one item per variable and number of chains as batch dimension.
Each item has shape `(chains, *var.shape)`
"""
if logp_fn is None:
eval_logp_initial_point = None

else:

def eval_logp_initial_point(point: dict[str, np.ndarray]) -> np.ndarray:
"""Wrap logp_fn to conform to _init_jitter logic.

Wraps jaxified logp function to accept a dict of
{model_variable: np.array} key:value pairs.
"""
return logp_fn(point.values())

initial_points = _init_jitter(
model,
initvals,
seeds=_get_seeds_per_chain(random_seed, chains),
jitter=jitter,
jitter_max_retries=jitter_max_retries,
logp_fn=eval_logp_initial_point,
)
initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
if chains == 1:
Expand All @@ -236,7 +258,7 @@ def _get_batched_jittered_initial_points(


def _blackjax_inference_loop(
seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs
):
import blackjax

Expand All @@ -252,13 +274,13 @@ def _blackjax_inference_loop(

adapt = blackjax.window_adaptation(
algorithm=algorithm,
logdensity_fn=logprob_fn,
logdensity_fn=logp_fn,
target_acceptance_rate=target_accept,
adaptation_info_fn=get_filter_adapt_info_fn(),
**adaptation_kwargs,
)
(last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
kernel = algorithm(logprob_fn, **tuned_params).step
kernel = algorithm(logp_fn, **tuned_params).step

def _one_step(state, xs):
_, rng_key = xs
Expand Down Expand Up @@ -292,8 +314,9 @@ def _sample_blackjax_nuts(
chain_method: str | None,
progressbar: bool,
random_seed: int,
initial_points,
initial_points: np.ndarray | list[np.ndarray],
nuts_kwargs,
logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None,
) -> az.InferenceData:
"""
Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
Expand Down Expand Up @@ -366,15 +389,16 @@ def _sample_blackjax_nuts(
if chains == 1:
initial_points = [np.stack(init_state) for init_state in zip(initial_points)]

logprob_fn = get_jaxified_logp(model)
if logp_fn is None:
logp_fn = get_jaxified_logp(model)

seed = jax.random.PRNGKey(random_seed)
keys = jax.random.split(seed, chains)

nuts_kwargs["progress_bar"] = progressbar
get_posterior_samples = partial(
_blackjax_inference_loop,
logprob_fn=logprob_fn,
logp_fn=logp_fn,
tune=tune,
draws=draws,
target_accept=target_accept,
Expand Down Expand Up @@ -415,14 +439,16 @@ def _sample_numpyro_nuts(
chain_method: str | None,
progressbar: bool,
random_seed: int,
initial_points,
initial_points: np.ndarray | list[np.ndarray],
nuts_kwargs: dict[str, Any],
logp_fn: Callable | None = None,
):
import numpyro

from numpyro.infer import MCMC, NUTS

logp_fn = get_jaxified_logp(model, negative_logp=False)
if logp_fn is None:
logp_fn = get_jaxified_logp(model, negative_logp=False)

nuts_kwargs.setdefault("adapt_step_size", True)
nuts_kwargs.setdefault("adapt_mass_matrix", True)
Expand Down Expand Up @@ -590,6 +616,15 @@ def sample_jax_nuts(
get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
)

if nuts_sampler == "numpyro":
sampler_fn = _sample_numpyro_nuts
logp_fn = get_jaxified_logp(model, negative_logp=False)
elif nuts_sampler == "blackjax":
sampler_fn = _sample_blackjax_nuts
logp_fn = get_jaxified_logp(model)
else:
raise ValueError(f"{nuts_sampler=} not recognized")

(random_seed,) = _get_seeds_per_chain(random_seed, 1)

initial_points = _get_batched_jittered_initial_points(
Expand All @@ -598,15 +633,9 @@ def sample_jax_nuts(
initvals=initvals,
random_seed=random_seed,
jitter=jitter,
logp_fn=logp_fn,
)

if nuts_sampler == "numpyro":
sampler_fn = _sample_numpyro_nuts
elif nuts_sampler == "blackjax":
sampler_fn = _sample_blackjax_nuts
else:
raise ValueError(f"{nuts_sampler=} not recognized")

tic1 = datetime.now()
raw_mcmc_samples, sample_stats, library = sampler_fn(
model=model,
Expand Down
24 changes: 14 additions & 10 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,7 +1338,7 @@ def _init_jitter(
seeds: Sequence[int] | np.ndarray,
jitter: bool,
jitter_max_retries: int,
logp_dlogp_func=None,
logp_fn: Callable[[PointType], np.ndarray] | None = None,
) -> list[PointType]:
"""Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.

Expand All @@ -1353,11 +1353,14 @@ def _init_jitter(
Whether to apply jitter or not.
jitter_max_retries : int
Maximum number of repeated attempts at initializing values (per chain).
logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray] | None
Jaxified logp function that takes the output of the initial point functions as input.
If None, will use the results of model.compile_logp().

Returns
-------
start : ``pymc.model.Point``
Starting point for sampler
initial_points : list[dict[str, np.ndarray]]
List of starting points for the sampler
"""
ipfns = make_initial_point_fns_per_chain(
model=model,
Expand All @@ -1369,14 +1372,10 @@ def _init_jitter(
if not jitter:
return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]

model_logp_fn: Callable
if logp_dlogp_func is None:
if logp_fn is None:
model_logp_fn = model.compile_logp()
else:

def model_logp_fn(ip):
q, _ = DictToArrayBijection.map(ip)
return logp_dlogp_func([q], extra_vars={})[0]
model_logp_fn = logp_fn

initial_points = []
for ipfn, seed in zip(ipfns, seeds):
Expand Down Expand Up @@ -1501,13 +1500,18 @@ def init_nuts(

logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs)
logp_dlogp_func.trust_input = True

def model_logp_fn(ip: PointType) -> np.ndarray:
q, _ = DictToArrayBijection.map(ip)
return logp_dlogp_func([q], extra_vars={})[0]

initial_points = _init_jitter(
model,
initvals,
seeds=random_seed_list,
jitter="jitter" in init,
jitter_max_retries=jitter_max_retries,
logp_dlogp_func=logp_dlogp_func,
logp_fn=model_logp_fn,
)

apoints = [DictToArrayBijection.map(point) for point in initial_points]
Expand Down
Loading