-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Use jaxified logp for initial point evaluation when sampling via Jax #7610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
31bf864
3996a06
1fb9df1
f71aedc
2855587
85996a1
2e9d7db
e50b5b4
f6fbb0a
deea64c
ae0fb96
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -144,13 +144,15 @@ def get_jaxified_graph( | |
return jax_funcify(fgraph) | ||
|
||
|
||
def get_jaxified_logp(model: Model, negative_logp=True) -> Callable: | ||
def get_jaxified_logp( | ||
model: Model, negative_logp=True | ||
) -> Callable[[Sequence[np.ndarray]], np.ndarray]: | ||
model_logp = model.logp() | ||
if not negative_logp: | ||
model_logp = -model_logp | ||
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp]) | ||
|
||
def logp_fn_wrap(x): | ||
def logp_fn_wrap(x: Sequence[np.ndarray]) -> np.ndarray: | ||
|
||
return logp_fn(*x)[0] | ||
|
||
return logp_fn_wrap | ||
|
@@ -211,23 +213,39 @@ def _get_batched_jittered_initial_points( | |
chains: int, | ||
initvals: StartDict | Sequence[StartDict | None] | None, | ||
random_seed: RandomSeed, | ||
logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray], | ||
jitter: bool = True, | ||
jitter_max_retries: int = 10, | ||
) -> np.ndarray | list[np.ndarray]: | ||
"""Get jittered initial point in format expected by NumPyro MCMC kernel. | ||
"""Get jittered initial point in format expected by Jax MCMC kernel. | ||
|
||
Parameters | ||
---------- | ||
logp_fn : Callable[Sequence[np.ndarray]], np.ndarray] | ||
Jaxified logp function | ||
|
||
Returns | ||
------- | ||
out: list of ndarrays | ||
out: list[np.ndarray] | ||
nataziel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
list with one item per variable and number of chains as batch dimension. | ||
Each item has shape `(chains, *var.shape)` | ||
""" | ||
|
||
def eval_logp_initial_point(point: dict[str, np.ndarray]) -> np.ndarray: | ||
"""Wrap logp_fn to conform to _init_jitter logic. | ||
|
||
Wraps jaxified logp function to accept a dict of | ||
{model_variable: np.array} key:value pairs. | ||
""" | ||
return logp_fn(point.values()) | ||
|
||
initial_points = _init_jitter( | ||
model, | ||
initvals, | ||
seeds=_get_seeds_per_chain(random_seed, chains), | ||
jitter=jitter, | ||
jitter_max_retries=jitter_max_retries, | ||
logp_fn=eval_logp_initial_point, | ||
) | ||
initial_points_values = [list(initial_point.values()) for initial_point in initial_points] | ||
if chains == 1: | ||
|
@@ -236,7 +254,7 @@ def _get_batched_jittered_initial_points( | |
|
||
|
||
def _blackjax_inference_loop( | ||
seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs | ||
seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs | ||
): | ||
import blackjax | ||
|
||
|
@@ -252,13 +270,13 @@ def _blackjax_inference_loop( | |
|
||
adapt = blackjax.window_adaptation( | ||
algorithm=algorithm, | ||
logdensity_fn=logprob_fn, | ||
logdensity_fn=logp_fn, | ||
target_acceptance_rate=target_accept, | ||
adaptation_info_fn=get_filter_adapt_info_fn(), | ||
**adaptation_kwargs, | ||
) | ||
(last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune) | ||
kernel = algorithm(logprob_fn, **tuned_params).step | ||
kernel = algorithm(logp_fn, **tuned_params).step | ||
|
||
def _one_step(state, xs): | ||
_, rng_key = xs | ||
|
@@ -292,8 +310,9 @@ def _sample_blackjax_nuts( | |
chain_method: str | None, | ||
progressbar: bool, | ||
random_seed: int, | ||
initial_points, | ||
initial_points: np.ndarray | list[np.ndarray], | ||
nuts_kwargs, | ||
logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None, | ||
) -> az.InferenceData: | ||
""" | ||
Draw samples from the posterior using the NUTS method from the ``blackjax`` library. | ||
|
@@ -366,15 +385,16 @@ def _sample_blackjax_nuts( | |
if chains == 1: | ||
initial_points = [np.stack(init_state) for init_state in zip(initial_points)] | ||
|
||
logprob_fn = get_jaxified_logp(model) | ||
if logp_fn is None: | ||
logp_fn = get_jaxified_logp(model) | ||
|
||
seed = jax.random.PRNGKey(random_seed) | ||
keys = jax.random.split(seed, chains) | ||
|
||
nuts_kwargs["progress_bar"] = progressbar | ||
get_posterior_samples = partial( | ||
_blackjax_inference_loop, | ||
logprob_fn=logprob_fn, | ||
logp_fn=logp_fn, | ||
tune=tune, | ||
draws=draws, | ||
target_accept=target_accept, | ||
|
@@ -415,14 +435,16 @@ def _sample_numpyro_nuts( | |
chain_method: str | None, | ||
progressbar: bool, | ||
random_seed: int, | ||
initial_points, | ||
initial_points: np.ndarray | list[np.ndarray], | ||
nuts_kwargs: dict[str, Any], | ||
logp_fn: Callable | None = None, | ||
): | ||
import numpyro | ||
|
||
from numpyro.infer import MCMC, NUTS | ||
|
||
logp_fn = get_jaxified_logp(model, negative_logp=False) | ||
if logp_fn is None: | ||
logp_fn = get_jaxified_logp(model, negative_logp=False) | ||
|
||
nuts_kwargs.setdefault("adapt_step_size", True) | ||
nuts_kwargs.setdefault("adapt_mass_matrix", True) | ||
|
@@ -590,6 +612,15 @@ def sample_jax_nuts( | |
get_default_varnames(filtered_var_names, include_transformed=keep_untransformed) | ||
) | ||
|
||
if nuts_sampler == "numpyro": | ||
sampler_fn = _sample_numpyro_nuts | ||
logp_fn = get_jaxified_logp(model, negative_logp=False) | ||
elif nuts_sampler == "blackjax": | ||
sampler_fn = _sample_blackjax_nuts | ||
logp_fn = get_jaxified_logp(model) | ||
else: | ||
raise ValueError(f"{nuts_sampler=} not recognized") | ||
|
||
(random_seed,) = _get_seeds_per_chain(random_seed, 1) | ||
|
||
initial_points = _get_batched_jittered_initial_points( | ||
|
@@ -598,15 +629,9 @@ def sample_jax_nuts( | |
initvals=initvals, | ||
random_seed=random_seed, | ||
jitter=jitter, | ||
logp_fn=logp_fn, | ||
) | ||
|
||
if nuts_sampler == "numpyro": | ||
sampler_fn = _sample_numpyro_nuts | ||
elif nuts_sampler == "blackjax": | ||
sampler_fn = _sample_blackjax_nuts | ||
else: | ||
raise ValueError(f"{nuts_sampler=} not recognized") | ||
|
||
tic1 = datetime.now() | ||
raw_mcmc_samples, sample_stats, library = sampler_fn( | ||
model=model, | ||
|
Uh oh!
There was an error while loading. Please reload this page.