Skip to content

Commit df93e9d

Browse files
committed
MAINT Forward *kwargs to init_nuts to NUTS.
1 parent 808d7dd commit df93e9d

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

pymc3/sampling.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None, random_see
393393
return {k: np.asarray(v) for k, v in ppc.items()}
394394

395395

396-
def init_nuts(init='advi', n_init=500000, model=None):
396+
def init_nuts(init='advi', n_init=500000, model=None, **kwargs):
397397
"""Initialize and sample from posterior of a continuous model.
398398
399399
This is a convenience function. NUTS convergence and sampling speed is extremely
@@ -413,6 +413,8 @@ def init_nuts(init='advi', n_init=500000, model=None):
413413
Number of iterations of initializer
414414
If 'advi', number of iterations, if 'metropolis', number of draws.
415415
model : Model (optional if in `with` context)
416+
**kwargs : keyword arguments
417+
Extra keyword arguments are forwarded to pymc3.NUTS.
416418
417419
Returns
418420
-------
@@ -448,6 +450,6 @@ def init_nuts(init='advi', n_init=500000, model=None):
448450
else:
449451
raise NotImplemented('Initializer {} is not supported.'.format(init))
450452

451-
step = pm.NUTS(scaling=cov, is_cov=True)
453+
step = pm.NUTS(scaling=cov, is_cov=True, **kwargs)
452454

453455
return start, step

0 commit comments

Comments
 (0)