We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 47503bf commit 9365663Copy full SHA for 9365663
pymc/sampling_jax.py
@@ -151,6 +151,7 @@ def sample_numpyro_nuts(
151
keep_untransformed=False,
152
chain_method="parallel",
153
idata_kwargs=None,
154
+ nuts_kwargs=None,
155
):
156
from numpyro.infer import MCMC, NUTS
157
@@ -185,12 +186,15 @@ def sample_numpyro_nuts(
185
186
187
logp_fn = get_jaxified_logp(model)
188
189
+ if nuts_kwargs is None:
190
+ nuts_kwargs = {}
191
nuts_kernel = NUTS(
192
potential_fn=logp_fn,
193
target_accept_prob=target_accept,
194
adapt_step_size=True,
195
adapt_mass_matrix=True,
196
dense_mass=False,
197
+ **nuts_kwargs,
198
)
199
200
pmap_numpyro = MCMC(
0 commit comments