@@ -309,8 +309,11 @@ def _sample_external_nuts(
309309 nuts_sampler_kwargs : dict | None ,
310310 ** kwargs ,
311311):
312- if nuts_sampler_kwargs is None :
313- nuts_sampler_kwargs = {}
312+ import copy
313+
314+ nuts_sampler_kwargs_copy = copy .deepcopy (nuts_sampler_kwargs )
315+ if nuts_sampler_kwargs_copy is None :
316+ nuts_sampler_kwargs_copy = {}
314317
315318 if sampler == "nutpie" :
316319 try :
@@ -339,8 +342,8 @@ def _sample_external_nuts(
339342 )
340343 compile_kwargs = {}
341344 for kwarg in ("backend" , "gradient_backend" ):
342- if kwarg in nuts_sampler_kwargs :
343- compile_kwargs [kwarg ] = nuts_sampler_kwargs .pop (kwarg )
345+ if kwarg in nuts_sampler_kwargs_copy :
346+ compile_kwargs [kwarg ] = nuts_sampler_kwargs_copy .pop (kwarg )
344347 compiled_model = nutpie .compile_pymc_model (
345348 model ,
346349 ** compile_kwargs ,
@@ -354,7 +357,7 @@ def _sample_external_nuts(
354357 target_accept = target_accept ,
355358 seed = _get_seeds_per_chain (random_seed , 1 )[0 ],
356359 progress_bar = progressbar ,
357- ** nuts_sampler_kwargs ,
360+ ** nuts_sampler_kwargs_copy ,
358361 )
359362 t_sample = time .time () - t_start
360363 # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
@@ -406,7 +409,7 @@ def _sample_external_nuts(
406409 nuts_sampler = sampler ,
407410 idata_kwargs = idata_kwargs ,
408411 compute_convergence_checks = compute_convergence_checks ,
409- ** nuts_sampler_kwargs ,
412+ ** nuts_sampler_kwargs_copy ,
410413 )
411414 return idata
412415
@@ -686,6 +689,9 @@ def sample(
686689 mean sd hdi_3% hdi_97%
687690 p 0.609 0.047 0.528 0.699
688691 """
692+ import copy
693+
694+ nuts_sampler_kwargs_copy = copy .deepcopy (nuts_sampler_kwargs )
689695 if "start" in kwargs :
690696 if initvals is not None :
691697 raise ValueError ("Passing both `start` and `initvals` is not supported." )
@@ -695,8 +701,8 @@ def sample(
695701 stacklevel = 2 ,
696702 )
697703 initvals = kwargs .pop ("start" )
698- if nuts_sampler_kwargs is None :
699- nuts_sampler_kwargs = {}
704+ if nuts_sampler_kwargs_copy is None :
705+ nuts_sampler_kwargs_copy = {}
700706 if "target_accept" in kwargs :
701707 if "nuts" in kwargs and "target_accept" in kwargs ["nuts" ]:
702708 raise ValueError (
@@ -808,7 +814,7 @@ def joined_blas_limiter():
808814 progressbar = progressbar ,
809815 idata_kwargs = idata_kwargs ,
810816 compute_convergence_checks = compute_convergence_checks ,
811- nuts_sampler_kwargs = nuts_sampler_kwargs ,
817+ nuts_sampler_kwargs = nuts_sampler_kwargs_copy ,
812818 ** kwargs ,
813819 )
814820
0 commit comments