Skip to content

Commit 1fb9df1

Browse files
author
Goose
committed
refactor init_jitter inputs
1 parent 3996a06 commit 1fb9df1

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

pymc/sampling/jax.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def _get_batched_jittered_initial_points(
213213
chains: int,
214214
initvals: StartDict | Sequence[StartDict | None] | None,
215215
random_seed: RandomSeed,
216-
logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray],
216+
logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None,
217217
jitter: bool = True,
218218
jitter_max_retries: int = 10,
219219
) -> np.ndarray | list[np.ndarray]:
@@ -230,14 +230,18 @@ def _get_batched_jittered_initial_points(
230230
list with one item per variable and number of chains as batch dimension.
231231
Each item has shape `(chains, *var.shape)`
232232
"""
233+
if logp_fn is None:
234+
eval_logp_initial_point = None
235+
236+
else:
233237

234-
def eval_logp_initial_point(point: dict[str, np.ndarray]) -> np.ndarray:
235-
"""Wrap logp_fn to conform to _init_jitter logic.
238+
def eval_logp_initial_point(point: dict[str, np.ndarray]) -> np.ndarray:
239+
"""Wrap logp_fn to conform to _init_jitter logic.
236240
237-
Wraps jaxified logp function to accept a dict of
238-
{model_variable: np.array} key:value pairs.
239-
"""
240-
return logp_fn(point.values())
241+
Wraps jaxified logp function to accept a dict of
242+
{model_variable: np.array} key:value pairs.
243+
"""
244+
return logp_fn(point.values())
241245

242246
initial_points = _init_jitter(
243247
model,

pymc/sampling/mcmc.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,7 +1338,6 @@ def _init_jitter(
13381338
seeds: Sequence[int] | np.ndarray,
13391339
jitter: bool,
13401340
jitter_max_retries: int,
1341-
logp_dlogp_func=None,
13421341
logp_fn: Callable[[PointType], np.ndarray] | None = None,
13431342
) -> list[PointType]:
13441343
"""Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
@@ -1354,8 +1353,9 @@ def _init_jitter(
13541353
Whether to apply jitter or not.
13551354
jitter_max_retries : int
13561355
Maximum number of repeated attempts at initializing values (per chain).
1357-
logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray]
1356+
logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray] | None
13581357
Jaxified logp function that takes the output of the initial point functions as input.
1358+
If None, will use the results of model.compile_logp().
13591359
13601360
Returns
13611361
-------
@@ -1372,19 +1372,10 @@ def _init_jitter(
13721372
if not jitter:
13731373
return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]
13741374

1375-
model_logp_fn: Callable[[PointType], np.ndarray]
1376-
if logp_dlogp_func is None:
1377-
if logp_fn is None:
1378-
# pymc NUTS path
1379-
model_logp_fn = model.compile_logp()
1380-
else:
1381-
# Jax path
1382-
model_logp_fn = logp_fn
1375+
if logp_fn is None:
1376+
model_logp_fn = model.compile_logp()
13831377
else:
1384-
1385-
def model_logp_fn(ip: PointType) -> np.ndarray:
1386-
q, _ = DictToArrayBijection.map(ip)
1387-
return logp_dlogp_func([q], extra_vars={})[0]
1378+
model_logp_fn = logp_fn
13881379

13891380
initial_points = []
13901381
for ipfn, seed in zip(ipfns, seeds):
@@ -1509,13 +1500,18 @@ def init_nuts(
15091500

15101501
logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs)
15111502
logp_dlogp_func.trust_input = True
1503+
1504+
def model_logp_fn(ip: PointType) -> np.ndarray:
1505+
q, _ = DictToArrayBijection.map(ip)
1506+
return logp_dlogp_func([q], extra_vars={})[0]
1507+
15121508
initial_points = _init_jitter(
15131509
model,
15141510
initvals,
15151511
seeds=random_seed_list,
15161512
jitter="jitter" in init,
15171513
jitter_max_retries=jitter_max_retries,
1518-
logp_dlogp_func=logp_dlogp_func,
1514+
logp_fn=model_logp_fn,
15191515
)
15201516

15211517
apoints = [DictToArrayBijection.map(point) for point in initial_points]

0 commit comments

Comments
 (0)