@@ -1338,7 +1338,6 @@ def _init_jitter(
13381338 seeds : Sequence [int ] | np .ndarray ,
13391339 jitter : bool ,
13401340 jitter_max_retries : int ,
1341- logp_dlogp_func = None ,
13421341 logp_fn : Callable [[PointType ], np .ndarray ] | None = None ,
13431342) -> list [PointType ]:
13441343 """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
@@ -1354,8 +1353,9 @@ def _init_jitter(
13541353 Whether to apply jitter or not.
13551354 jitter_max_retries : int
13561355 Maximum number of repeated attempts at initializing values (per chain).
1357- logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray]
1356+ logp_fn: Callable[[dict[str, np.ndarray]], np.ndarray] | None
13581357 Jaxified logp function that takes the output of the initial point functions as input.
1358+ If None, will use the results of model.compile_logp().
13591359
13601360 Returns
13611361 -------
@@ -1372,19 +1372,10 @@ def _init_jitter(
13721372 if not jitter :
13731373 return [ipfn (seed ) for ipfn , seed in zip (ipfns , seeds )]
13741374
1375- model_logp_fn : Callable [[PointType ], np .ndarray ]
1376- if logp_dlogp_func is None :
1377- if logp_fn is None :
1378- # pymc NUTS path
1379- model_logp_fn = model .compile_logp ()
1380- else :
1381- # Jax path
1382- model_logp_fn = logp_fn
1375+ if logp_fn is None :
1376+ model_logp_fn = model .compile_logp ()
13831377 else :
1384-
1385- def model_logp_fn (ip : PointType ) -> np .ndarray :
1386- q , _ = DictToArrayBijection .map (ip )
1387- return logp_dlogp_func ([q ], extra_vars = {})[0 ]
1378+ model_logp_fn = logp_fn
13881379
13891380 initial_points = []
13901381 for ipfn , seed in zip (ipfns , seeds ):
@@ -1509,13 +1500,18 @@ def init_nuts(
15091500
15101501 logp_dlogp_func = model .logp_dlogp_function (ravel_inputs = True , ** compile_kwargs )
15111502 logp_dlogp_func .trust_input = True
1503+
1504+ def model_logp_fn (ip : PointType ) -> np .ndarray :
1505+ q , _ = DictToArrayBijection .map (ip )
1506+ return logp_dlogp_func ([q ], extra_vars = {})[0 ]
1507+
15121508 initial_points = _init_jitter (
15131509 model ,
15141510 initvals ,
15151511 seeds = random_seed_list ,
15161512 jitter = "jitter" in init ,
15171513 jitter_max_retries = jitter_max_retries ,
1518- logp_dlogp_func = logp_dlogp_func ,
1514+ logp_fn = model_logp_fn ,
15191515 )
15201516
15211517 apoints = [DictToArrayBijection .map (point ) for point in initial_points ]
0 commit comments