Skip to content

Commit 31bf864

Browse files
author
Goose
committed
use jaxified logp for initial point evaluation when sampling via Jax
1 parent 6cdfc30 commit 31bf864

File tree

5 files changed

+79
-31
lines changed

5 files changed

+79
-31
lines changed

pymc/initial_point.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def make_initial_point_fns_per_chain(
6767
overrides: StartDict | Sequence[StartDict | None] | None,
6868
jitter_rvs: set[TensorVariable] | None = None,
6969
chains: int,
70-
) -> list[Callable]:
70+
) -> list[Callable[[int], PointType]]:
7171
"""Create an initial point function for each chain, as defined by initvals.
7272
7373
If a single initval dictionary is passed, the function is replicated for each
@@ -82,6 +82,11 @@ def make_initial_point_fns_per_chain(
8282
Random variable tensors for which U(-1, 1) jitter shall be applied.
8383
(To the transformed space if applicable.)
8484
85+
Returns
86+
-------
87+
ipfns : list[Callable[[int], dict[str, np.ndarray]]]
88+
list of functions that return initial points for each chain.
89+
8590
Raises
8691
------
8792
ValueError
@@ -124,7 +129,7 @@ def make_initial_point_fn(
124129
jitter_rvs: set[TensorVariable] | None = None,
125130
default_strategy: str = "support_point",
126131
return_transformed: bool = True,
127-
) -> Callable:
132+
) -> Callable[[int], PointType]:
128133
"""Create seeded function that computes initial values for all free model variables.
129134
130135
Parameters
@@ -138,6 +143,10 @@ def make_initial_point_fn(
138143
Initial value (strategies) to use instead of what's specified in `Model.initial_values`.
139144
return_transformed : bool
140145
If `True` the returned variables will correspond to transformed initial values.
146+
147+
Returns
148+
-------
149+
initial_point_fn : Callable[[int], dict[str, np.ndarray]]
141150
"""
142151
sdict_overrides = convert_str_to_rv_dict(model, overrides or {})
143152
initval_strats = {

pymc/model/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import types
2020
import warnings
2121

22-
from collections.abc import Iterable, Sequence
22+
from collections.abc import Callable, Iterable, Sequence
2323
from typing import (
2424
Literal,
2525
cast,
@@ -585,7 +585,7 @@ def compile_logp(
585585
jacobian: bool = True,
586586
sum: bool = True,
587587
**compile_kwargs,
588-
) -> PointFunc:
588+
) -> Callable[[PointType], np.ndarray]:
589589
"""Compiled log probability density function.
590590
591591
Parameters

pymc/sampling/jax.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,15 @@ def get_jaxified_graph(
144144
return jax_funcify(fgraph)
145145

146146

147-
def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
147+
def get_jaxified_logp(
148+
model: Model, negative_logp=True
149+
) -> Callable[[Sequence[np.ndarray]], np.ndarray]:
148150
model_logp = model.logp()
149151
if not negative_logp:
150152
model_logp = -model_logp
151153
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
152154

153-
def logp_fn_wrap(x):
155+
def logp_fn_wrap(x: Sequence[np.ndarray]) -> np.ndarray:
154156
return logp_fn(*x)[0]
155157

156158
return logp_fn_wrap
@@ -211,23 +213,39 @@ def _get_batched_jittered_initial_points(
211213
chains: int,
212214
initvals: StartDict | Sequence[StartDict | None] | None,
213215
random_seed: RandomSeed,
216+
logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray],
214217
jitter: bool = True,
215218
jitter_max_retries: int = 10,
216219
) -> np.ndarray | list[np.ndarray]:
217-
"""Get jittered initial point in format expected by NumPyro MCMC kernel.
220+
"""Get jittered initial point in format expected by Jax MCMC kernel.
221+
222+
Parameters
223+
----------
224+
logp_fn : Callable[Sequence[np.ndarray]], np.ndarray]
225+
Jaxified logp function
218226
219227
Returns
220228
-------
221-
out: list of ndarrays
229+
out: list[np.ndarray]
222230
list with one item per variable and number of chains as batch dimension.
223231
Each item has shape `(chains, *var.shape)`
224232
"""
233+
234+
def eval_logp_initial_point(point: dict[str, np.ndarray]) -> np.ndarray:
235+
"""Wrap logp_fn to conform to _init_jitter logic.
236+
237+
Wraps jaxified logp function to accept a dict of
238+
{model_variable: np.array} key:value pairs.
239+
"""
240+
return logp_fn(point.values())
241+
225242
initial_points = _init_jitter(
226243
model,
227244
initvals,
228245
seeds=_get_seeds_per_chain(random_seed, chains),
229246
jitter=jitter,
230247
jitter_max_retries=jitter_max_retries,
248+
logp_fn=eval_logp_initial_point,
231249
)
232250
initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
233251
if chains == 1:
@@ -236,7 +254,7 @@ def _get_batched_jittered_initial_points(
236254

237255

238256
def _blackjax_inference_loop(
239-
seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
257+
seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs
240258
):
241259
import blackjax
242260

@@ -252,13 +270,13 @@ def _blackjax_inference_loop(
252270

253271
adapt = blackjax.window_adaptation(
254272
algorithm=algorithm,
255-
logdensity_fn=logprob_fn,
273+
logdensity_fn=logp_fn,
256274
target_acceptance_rate=target_accept,
257275
adaptation_info_fn=get_filter_adapt_info_fn(),
258276
**adaptation_kwargs,
259277
)
260278
(last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
261-
kernel = algorithm(logprob_fn, **tuned_params).step
279+
kernel = algorithm(logp_fn, **tuned_params).step
262280

263281
def _one_step(state, xs):
264282
_, rng_key = xs
@@ -292,8 +310,9 @@ def _sample_blackjax_nuts(
292310
chain_method: str | None,
293311
progressbar: bool,
294312
random_seed: int,
295-
initial_points,
313+
initial_points: np.ndarray | list[np.ndarray],
296314
nuts_kwargs,
315+
logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None,
297316
) -> az.InferenceData:
298317
"""
299318
Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
@@ -366,15 +385,16 @@ def _sample_blackjax_nuts(
366385
if chains == 1:
367386
initial_points = [np.stack(init_state) for init_state in zip(initial_points)]
368387

369-
logprob_fn = get_jaxified_logp(model)
388+
if logp_fn is None:
389+
logp_fn = get_jaxified_logp(model)
370390

371391
seed = jax.random.PRNGKey(random_seed)
372392
keys = jax.random.split(seed, chains)
373393

374394
nuts_kwargs["progress_bar"] = progressbar
375395
get_posterior_samples = partial(
376396
_blackjax_inference_loop,
377-
logprob_fn=logprob_fn,
397+
logp_fn=logp_fn,
378398
tune=tune,
379399
draws=draws,
380400
target_accept=target_accept,
@@ -415,14 +435,16 @@ def _sample_numpyro_nuts(
415435
chain_method: str | None,
416436
progressbar: bool,
417437
random_seed: int,
418-
initial_points,
438+
initial_points: np.ndarray | list[np.ndarray],
419439
nuts_kwargs: dict[str, Any],
440+
logp_fn: Callable | None = None,
420441
):
421442
import numpyro
422443

423444
from numpyro.infer import MCMC, NUTS
424445

425-
logp_fn = get_jaxified_logp(model, negative_logp=False)
446+
if logp_fn is None:
447+
logp_fn = get_jaxified_logp(model, negative_logp=False)
426448

427449
nuts_kwargs.setdefault("adapt_step_size", True)
428450
nuts_kwargs.setdefault("adapt_mass_matrix", True)
@@ -590,6 +612,15 @@ def sample_jax_nuts(
590612
get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
591613
)
592614

615+
if nuts_sampler == "numpyro":
616+
sampler_fn = _sample_numpyro_nuts
617+
logp_fn = get_jaxified_logp(model, negative_logp=False)
618+
elif nuts_sampler == "blackjax":
619+
sampler_fn = _sample_blackjax_nuts
620+
logp_fn = get_jaxified_logp(model)
621+
else:
622+
raise ValueError(f"{nuts_sampler=} not recognized")
623+
593624
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
594625

595626
initial_points = _get_batched_jittered_initial_points(
@@ -598,15 +629,9 @@ def sample_jax_nuts(
598629
initvals=initvals,
599630
random_seed=random_seed,
600631
jitter=jitter,
632+
logp_fn=logp_fn,
601633
)
602634

603-
if nuts_sampler == "numpyro":
604-
sampler_fn = _sample_numpyro_nuts
605-
elif nuts_sampler == "blackjax":
606-
sampler_fn = _sample_blackjax_nuts
607-
else:
608-
raise ValueError(f"{nuts_sampler=} not recognized")
609-
610635
tic1 = datetime.now()
611636
raw_mcmc_samples, sample_stats, library = sampler_fn(
612637
model=model,

pymc/sampling/mcmc.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,6 +1339,7 @@ def _init_jitter(
13391339
jitter: bool,
13401340
jitter_max_retries: int,
13411341
logp_dlogp_func=None,
1342+
logp_fn: Callable[[PointType], np.ndarray] | None = None,
13421343
) -> list[PointType]:
13431344
"""Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
13441345
@@ -1353,11 +1354,13 @@ def _init_jitter(
13531354
Whether to apply jitter or not.
13541355
jitter_max_retries : int
13551356
Maximum number of repeated attempts at initializing values (per chain).
1357+
logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray]
1358+
Jaxified logp function that takes the output of the initial point functions as input.
13561359
13571360
Returns
13581361
-------
1359-
start : ``pymc.model.Point``
1360-
Starting point for sampler
1362+
initial_points : list[dict[str, np.ndarray]]
1363+
List of starting points for the sampler
13611364
"""
13621365
ipfns = make_initial_point_fns_per_chain(
13631366
model=model,
@@ -1369,12 +1372,17 @@ def _init_jitter(
13691372
if not jitter:
13701373
return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]
13711374

1372-
model_logp_fn: Callable
1375+
model_logp_fn: Callable[[PointType], np.ndarray]
13731376
if logp_dlogp_func is None:
1374-
model_logp_fn = model.compile_logp()
1377+
if logp_fn is None:
1378+
# pymc NUTS path
1379+
model_logp_fn = model.compile_logp()
1380+
else:
1381+
# Jax path
1382+
model_logp_fn = logp_fn
13751383
else:
13761384

1377-
def model_logp_fn(ip):
1385+
def model_logp_fn(ip: PointType) -> np.ndarray:
13781386
q, _ = DictToArrayBijection.map(ip)
13791387
return logp_dlogp_func([q], extra_vars={})[0]
13801388

tests/sampling/test_jax.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,20 +333,26 @@ def test_get_batched_jittered_initial_points():
333333
with pm.Model() as model:
334334
x = pm.MvNormal("x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3)))
335335

336+
logp_fn = get_jaxified_logp(model)
337+
336338
# No jitter
337339
ips = _get_batched_jittered_initial_points(
338-
model=model, chains=1, random_seed=1, initvals=None, jitter=False
340+
model=model, chains=1, random_seed=1, initvals=None, jitter=False, logp_fn=logp_fn
339341
)
340342
assert np.all(ips[0] == 0)
341343

342344
# Single chain
343-
ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)
345+
ips = _get_batched_jittered_initial_points(
346+
model=model, chains=1, random_seed=1, initvals=None, logp_fn=logp_fn
347+
)
344348

345349
assert ips[0].shape == (2, 3)
346350
assert np.all(ips[0] != 0)
347351

348352
# Multiple chains
349-
ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)
353+
ips = _get_batched_jittered_initial_points(
354+
model=model, chains=2, random_seed=1, initvals=None, logp_fn=logp_fn
355+
)
350356

351357
assert ips[0].shape == (2, 2, 3)
352358
assert np.all(ips[0][0] != ips[0][1])

0 commit comments

Comments
 (0)