Skip to content

Commit 18bbcbb

Browse files
Fix problems with specifying target_accept and nuts kwargs (#6018)
Fix problem that target_accept is overwritten if nuts kwargs are specified Raise error when target_accept is specified twice; directly and in nuts kwargs Co-authored-by: Ricardo Vieira <[email protected]>
1 parent c8db06b commit 18bbcbb

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

pymc/sampling.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,14 @@ def sample(
474474
)
475475
initvals = kwargs.pop("start")
476476
if "target_accept" in kwargs:
477-
kwargs.setdefault("nuts", {"target_accept": kwargs.pop("target_accept")})
477+
if "nuts" in kwargs and "target_accept" in kwargs["nuts"]:
478+
raise ValueError(
479+
"`target_accept` was defined twice. Please specify it either as a direct keyword argument or in the `nuts` kwarg."
480+
)
481+
if "nuts" in kwargs:
482+
kwargs["nuts"]["target_accept"] = kwargs.pop("target_accept")
483+
else:
484+
kwargs = {"nuts": {"target_accept": kwargs.pop("target_accept")}}
478485

479486
model = modelcontext(model)
480487
if not model.free_RVs:

pymc/tests/test_sampling.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,9 +1450,14 @@ def test_step_args():
14501450
a = pm.Normal("a")
14511451
idata0 = pm.sample(target_accept=0.5, random_seed=1410)
14521452
idata1 = pm.sample(nuts={"target_accept": 0.5}, random_seed=1410 * 2)
1453+
idata2 = pm.sample(target_accept=0.5, nuts={"max_treedepth": 10}, random_seed=1410)
1454+
1455+
with pytest.raises(ValueError, match="`target_accept` was defined twice."):
1456+
pm.sample(target_accept=0.5, nuts={"target_accept": 0.95}, random_seed=1410)
14531457

14541458
npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
14551459
npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
1460+
npt.assert_almost_equal(idata2.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
14561461

14571462
with pm.Model() as model:
14581463
a = pm.Normal("a")

0 commit comments

Comments
 (0)