@@ -144,13 +144,15 @@ def get_jaxified_graph(
     return jax_funcify(fgraph)


-def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
+def get_jaxified_logp(
+    model: Model, negative_logp=True
+) -> Callable[[Sequence[np.ndarray]], np.ndarray]:
     model_logp = model.logp()
     if not negative_logp:
         model_logp = -model_logp
     logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

-    def logp_fn_wrap(x):
+    def logp_fn_wrap(x: Sequence[np.ndarray]) -> np.ndarray:
         return logp_fn(*x)[0]

     return logp_fn_wrap
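For orientation, here is a minimal usage sketch of the jaxified logp callable annotated in this hunk. It is not part of the diff; the toy model and values are made up, and it assumes the file under change is pymc/sampling/jax.py so that get_jaxified_logp is importable from pymc.sampling.jax.

import numpy as np
import pymc as pm

from pymc.sampling.jax import get_jaxified_logp

with pm.Model() as toy_model:
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu, sigma, observed=np.array([0.1, -0.3, 0.2]))

# The returned callable matches the new annotation: it takes one array per
# value variable (in model.value_vars order, on the transformed scale) and
# returns a scalar log-density array.
logp_fn = get_jaxified_logp(toy_model)
point = [np.array(0.0), np.array(0.0)]  # mu, sigma_log__
print(logp_fn(point))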
@@ -211,23 +213,39 @@ def _get_batched_jittered_initial_points(
     chains: int,
     initvals: StartDict | Sequence[StartDict | None] | None,
     random_seed: RandomSeed,
+    logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray],
     jitter: bool = True,
     jitter_max_retries: int = 10,
 ) -> np.ndarray | list[np.ndarray]:
-    """Get jittered initial point in format expected by NumPyro MCMC kernel.
+    """Get jittered initial point in format expected by JAX MCMC kernels.
+
+    Parameters
+    ----------
+    logp_fn : Callable[[Sequence[np.ndarray]], np.ndarray]
+        Jaxified logp function
     Returns
     -------
-    out: list of ndarrays
+    out: list[np.ndarray]
         list with one item per variable and number of chains as batch dimension.
         Each item has shape `(chains, *var.shape)`
     """
+
+    def eval_logp_initial_point(point: dict[str, np.ndarray]) -> np.ndarray:
+        """Wrap logp_fn to conform to _init_jitter logic.
+
+        Wraps jaxified logp function to accept a dict of
+        {model_variable: np.array} key:value pairs.
+        """
+        return logp_fn(point.values())
+
     initial_points = _init_jitter(
         model,
         initvals,
         seeds=_get_seeds_per_chain(random_seed, chains),
         jitter=jitter,
         jitter_max_retries=jitter_max_retries,
+        logp_fn=eval_logp_initial_point,
     )
     initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
     if chains == 1:
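As a side note, the wrapper added above is a small adapter: _init_jitter evaluates candidate start points stored as dicts, while the jaxified logp takes a positional sequence of arrays. A standalone sketch of that idea, with hypothetical names and a toy stand-in for the jaxified function, might look like this:

import numpy as np

def make_point_evaluator(logp_fn):
    # Adapt a sequence-based logp callable so it accepts {name: array} dicts.
    def eval_point(point: dict[str, np.ndarray]) -> np.ndarray:
        # dicts preserve insertion order, which is assumed to match the
        # order of model.value_vars used when compiling logp_fn
        return logp_fn(list(point.values()))

    return eval_point

# toy stand-in for a jaxified logp: an isotropic Gaussian log-density
toy_logp = lambda values: np.array(-0.5 * sum(np.sum(v**2) for v in values))

eval_point = make_point_evaluator(toy_logp)
print(eval_point({"mu": np.array(0.5), "sigma_log__": np.array(-1.0)}))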
@@ -236,7 +254,7 @@ def _get_batched_jittered_initial_points(


 def _blackjax_inference_loop(
-    seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
+    seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs
 ):
     import blackjax

@@ -252,13 +270,13 @@ def _blackjax_inference_loop(

     adapt = blackjax.window_adaptation(
         algorithm=algorithm,
-        logdensity_fn=logprob_fn,
+        logdensity_fn=logp_fn,
         target_acceptance_rate=target_accept,
         adaptation_info_fn=get_filter_adapt_info_fn(),
         **adaptation_kwargs,
     )
     (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
-    kernel = algorithm(logprob_fn, **tuned_params).step
+    kernel = algorithm(logp_fn, **tuned_params).step

     def _one_step(state, xs):
         _, rng_key = xs
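For readers less familiar with this helper: _one_step is consumed by a jax.lax.scan-based inference loop, which applies the tuned kernel once per draw and stacks the (state, info) pairs along a leading draws dimension. A self-contained sketch of that scan pattern, using a toy random-walk kernel as a stand-in for the blackjax NUTS kernel, is:

import jax
import jax.numpy as jnp

def toy_kernel(rng_key, position):
    # placeholder "kernel": a random-walk step standing in for NUTS
    step = jax.random.normal(rng_key, shape=position.shape) * 0.1
    return position + step, {"accepted": jnp.array(True)}

def one_step(state, xs):
    _, rng_key = xs
    state, info = toy_kernel(rng_key, state)
    return state, (state, info)

draws = 5
keys = jax.random.split(jax.random.PRNGKey(0), draws)
_, (samples, infos) = jax.lax.scan(one_step, jnp.zeros(2), (jnp.arange(draws), keys))
print(samples.shape)  # (draws, 2)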
@@ -292,8 +310,9 @@ def _sample_blackjax_nuts(
     chain_method: str | None,
     progressbar: bool,
     random_seed: int,
-    initial_points,
+    initial_points: np.ndarray | list[np.ndarray],
     nuts_kwargs,
+    logp_fn: Callable[[Sequence[np.ndarray]], np.ndarray] | None = None,
 ) -> az.InferenceData:
     """
     Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
@@ -366,15 +385,16 @@ def _sample_blackjax_nuts(
     if chains == 1:
         initial_points = [np.stack(init_state) for init_state in zip(initial_points)]

-    logprob_fn = get_jaxified_logp(model)
+    if logp_fn is None:
+        logp_fn = get_jaxified_logp(model)

     seed = jax.random.PRNGKey(random_seed)
     keys = jax.random.split(seed, chains)

     nuts_kwargs["progress_bar"] = progressbar
     get_posterior_samples = partial(
         _blackjax_inference_loop,
-        logprob_fn=logprob_fn,
+        logp_fn=logp_fn,
         tune=tune,
         draws=draws,
         target_accept=target_accept,
@@ -415,14 +435,16 @@ def _sample_numpyro_nuts(
     chain_method: str | None,
     progressbar: bool,
     random_seed: int,
-    initial_points,
+    initial_points: np.ndarray | list[np.ndarray],
     nuts_kwargs: dict[str, Any],
+    logp_fn: Callable | None = None,
 ):
     import numpyro

     from numpyro.infer import MCMC, NUTS

-    logp_fn = get_jaxified_logp(model, negative_logp=False)
+    if logp_fn is None:
+        logp_fn = get_jaxified_logp(model, negative_logp=False)

     nuts_kwargs.setdefault("adapt_step_size", True)
     nuts_kwargs.setdefault("adapt_mass_matrix", True)
@@ -590,6 +612,15 @@ def sample_jax_nuts(
         get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
     )

+    if nuts_sampler == "numpyro":
+        sampler_fn = _sample_numpyro_nuts
+        logp_fn = get_jaxified_logp(model, negative_logp=False)
+    elif nuts_sampler == "blackjax":
+        sampler_fn = _sample_blackjax_nuts
+        logp_fn = get_jaxified_logp(model)
+    else:
+        raise ValueError(f"{nuts_sampler=} not recognized")
+
     (random_seed,) = _get_seeds_per_chain(random_seed, 1)
     initial_points = _get_batched_jittered_initial_points(
         model=model,
@@ -598,15 +629,9 @@ def sample_jax_nuts(
         initvals=initvals,
         random_seed=random_seed,
         jitter=jitter,
+        logp_fn=logp_fn,
     )

-    if nuts_sampler == "numpyro":
-        sampler_fn = _sample_numpyro_nuts
-    elif nuts_sampler == "blackjax":
-        sampler_fn = _sample_blackjax_nuts
-    else:
-        raise ValueError(f"{nuts_sampler=} not recognized")
-
     tic1 = datetime.now()
     raw_mcmc_samples, sample_stats, library = sampler_fn(
         model=model,
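Taken together, the last two hunks move the sampler dispatch ahead of the initial-point computation, so a single jaxified logp is built once and reused both for validating the jittered starting points and for sampling. A hedged end-to-end usage sketch, assuming sample_jax_nuts is the module-level entry point shown in this diff and that numpyro is installed (the model below is made up):

import numpy as np
import pymc as pm

from pymc.sampling.jax import sample_jax_nuts

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=np.array([0.2, -0.1, 0.4]))
    # dispatches to _sample_numpyro_nuts; the same logp_fn checks the
    # jittered starting points and drives the NUTS kernel
    idata = sample_jax_nuts(draws=200, tune=200, chains=2, nuts_sampler="numpyro")

print(idata.posterior["mu"].shape)  # (2, 200)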