Skip to content

Commit 9365663

Browse files
lucianopaztwiecki
authored andcommitted
Enable users to pass extra options to numpyro.infer.NUTS in sampling_jax
1 parent 47503bf commit 9365663

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

pymc/sampling_jax.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def sample_numpyro_nuts(
151151
keep_untransformed=False,
152152
chain_method="parallel",
153153
idata_kwargs=None,
154+
nuts_kwargs=None,
154155
):
155156
from numpyro.infer import MCMC, NUTS
156157

@@ -185,12 +186,15 @@ def sample_numpyro_nuts(
185186

186187
logp_fn = get_jaxified_logp(model)
187188

189+
if nuts_kwargs is None:
190+
nuts_kwargs = {}
188191
nuts_kernel = NUTS(
189192
potential_fn=logp_fn,
190193
target_accept_prob=target_accept,
191194
adapt_step_size=True,
192195
adapt_mass_matrix=True,
193196
dense_mass=False,
197+
**nuts_kwargs,
194198
)
195199

196200
pmap_numpyro = MCMC(

0 commit comments

Comments
 (0)