Skip to content

Commit caa9501

Browse files
fix: deep copy nuts_sampler_kwarg to prevent pop side effects
1 parent e6767ab commit caa9501

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

pymc/sampling/mcmc.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,11 @@ def _sample_external_nuts(
309309
nuts_sampler_kwargs: dict | None,
310310
**kwargs,
311311
):
312-
if nuts_sampler_kwargs is None:
313-
nuts_sampler_kwargs = {}
312+
import copy
313+
314+
nuts_sampler_kwargs_copy = copy.deepcopy(nuts_sampler_kwargs)
315+
if nuts_sampler_kwargs_copy is None:
316+
nuts_sampler_kwargs_copy = {}
314317

315318
if sampler == "nutpie":
316319
try:
@@ -339,8 +342,8 @@ def _sample_external_nuts(
339342
)
340343
compile_kwargs = {}
341344
for kwarg in ("backend", "gradient_backend"):
342-
if kwarg in nuts_sampler_kwargs:
343-
compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)
345+
if kwarg in nuts_sampler_kwargs_copy:
346+
compile_kwargs[kwarg] = nuts_sampler_kwargs_copy.pop(kwarg)
344347
compiled_model = nutpie.compile_pymc_model(
345348
model,
346349
**compile_kwargs,
@@ -354,7 +357,7 @@ def _sample_external_nuts(
354357
target_accept=target_accept,
355358
seed=_get_seeds_per_chain(random_seed, 1)[0],
356359
progress_bar=progressbar,
357-
**nuts_sampler_kwargs,
360+
**nuts_sampler_kwargs_copy,
358361
)
359362
t_sample = time.time() - t_start
360363
# Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
@@ -406,7 +409,7 @@ def _sample_external_nuts(
406409
nuts_sampler=sampler,
407410
idata_kwargs=idata_kwargs,
408411
compute_convergence_checks=compute_convergence_checks,
409-
**nuts_sampler_kwargs,
412+
**nuts_sampler_kwargs_copy,
410413
)
411414
return idata
412415

@@ -686,6 +689,9 @@ def sample(
686689
mean sd hdi_3% hdi_97%
687690
p 0.609 0.047 0.528 0.699
688691
"""
692+
import copy
693+
694+
nuts_sampler_kwargs_copy = copy.deepcopy(nuts_sampler_kwargs)
689695
if "start" in kwargs:
690696
if initvals is not None:
691697
raise ValueError("Passing both `start` and `initvals` is not supported.")
@@ -695,8 +701,8 @@ def sample(
695701
stacklevel=2,
696702
)
697703
initvals = kwargs.pop("start")
698-
if nuts_sampler_kwargs is None:
699-
nuts_sampler_kwargs = {}
704+
if nuts_sampler_kwargs_copy is None:
705+
nuts_sampler_kwargs_copy = {}
700706
if "target_accept" in kwargs:
701707
if "nuts" in kwargs and "target_accept" in kwargs["nuts"]:
702708
raise ValueError(
@@ -808,7 +814,7 @@ def joined_blas_limiter():
808814
progressbar=progressbar,
809815
idata_kwargs=idata_kwargs,
810816
compute_convergence_checks=compute_convergence_checks,
811-
nuts_sampler_kwargs=nuts_sampler_kwargs,
817+
nuts_sampler_kwargs=nuts_sampler_kwargs_copy,
812818
**kwargs,
813819
)
814820

0 commit comments

Comments
 (0)