@@ -309,11 +309,9 @@ def _sample_external_nuts(
309309 nuts_sampler_kwargs : dict | None ,
310310 ** kwargs ,
311311):
312- import copy
313-
314- nuts_sampler_kwargs_copy = copy .deepcopy (nuts_sampler_kwargs )
315- if nuts_sampler_kwargs_copy is None :
316- nuts_sampler_kwargs_copy = {}
312+ nuts_sampler_kwargs = nuts_sampler_kwargs .copy ()
313+ if nuts_sampler_kwargs is None :
314+ nuts_sampler_kwargs = {}
317315
318316 if sampler == "nutpie" :
319317 try :
@@ -342,8 +340,8 @@ def _sample_external_nuts(
342340 )
343341 compile_kwargs = {}
344342 for kwarg in ("backend" , "gradient_backend" ):
345- if kwarg in nuts_sampler_kwargs_copy :
346- compile_kwargs [kwarg ] = nuts_sampler_kwargs_copy .pop (kwarg )
343+ if kwarg in nuts_sampler_kwargs :
344+ compile_kwargs [kwarg ] = nuts_sampler_kwargs .pop (kwarg )
347345 compiled_model = nutpie .compile_pymc_model (
348346 model ,
349347 ** compile_kwargs ,
@@ -357,7 +355,7 @@ def _sample_external_nuts(
357355 target_accept = target_accept ,
358356 seed = _get_seeds_per_chain (random_seed , 1 )[0 ],
359357 progress_bar = progressbar ,
360- ** nuts_sampler_kwargs_copy ,
358+ ** nuts_sampler_kwargs ,
361359 )
362360 t_sample = time .time () - t_start
363361 # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
@@ -409,7 +407,7 @@ def _sample_external_nuts(
409407 nuts_sampler = sampler ,
410408 idata_kwargs = idata_kwargs ,
411409 compute_convergence_checks = compute_convergence_checks ,
412- ** nuts_sampler_kwargs_copy ,
410+ ** nuts_sampler_kwargs ,
413411 )
414412 return idata
415413
@@ -689,9 +687,7 @@ def sample(
689687 mean sd hdi_3% hdi_97%
690688 p 0.609 0.047 0.528 0.699
691689 """
692- import copy
693-
694- nuts_sampler_kwargs_copy = copy .deepcopy (nuts_sampler_kwargs )
690+ nuts_sampler_kwargs = nuts_sampler_kwargs .copy ()
695691 if "start" in kwargs :
696692 if initvals is not None :
697693 raise ValueError ("Passing both `start` and `initvals` is not supported." )
@@ -701,8 +697,8 @@ def sample(
701697 stacklevel = 2 ,
702698 )
703699 initvals = kwargs .pop ("start" )
704- if nuts_sampler_kwargs_copy is None :
705- nuts_sampler_kwargs_copy = {}
700+ if nuts_sampler_kwargs is None :
701+ nuts_sampler_kwargs = {}
706702 if "target_accept" in kwargs :
707703 if "nuts" in kwargs and "target_accept" in kwargs ["nuts" ]:
708704 raise ValueError (
@@ -814,7 +810,7 @@ def joined_blas_limiter():
814810 progressbar = progressbar ,
815811 idata_kwargs = idata_kwargs ,
816812 compute_convergence_checks = compute_convergence_checks ,
817- nuts_sampler_kwargs = nuts_sampler_kwargs_copy ,
813+ nuts_sampler_kwargs = nuts_sampler_kwargs ,
818814 ** kwargs ,
819815 )
820816
0 commit comments