@@ -264,6 +264,7 @@ def _sample_external_nuts(
264
264
random_seed : Union [RandomState , None ],
265
265
initvals : Union [StartDict , Sequence [Optional [StartDict ]], None ],
266
266
model : Model ,
267
+ var_names : Optional [Sequence [str ]],
267
268
progressbar : bool ,
268
269
idata_kwargs : Optional [dict ],
269
270
nuts_sampler_kwargs : Optional [dict ],
@@ -292,6 +293,11 @@ def _sample_external_nuts(
292
293
"`idata_kwargs` are currently ignored by the nutpie sampler" ,
293
294
UserWarning ,
294
295
)
296
+ if var_names is not None :
297
+ warnings .warn (
298
+ "`var_names` are currently ignored by the nutpie sampler" ,
299
+ UserWarning ,
300
+ )
295
301
compiled_model = nutpie .compile_pymc_model (model )
296
302
t_start = time .time ()
297
303
idata = nutpie .sample (
@@ -348,6 +354,7 @@ def _sample_external_nuts(
348
354
random_seed = random_seed ,
349
355
initvals = initvals ,
350
356
model = model ,
357
+ var_names = var_names ,
351
358
progressbar = progressbar ,
352
359
nuts_sampler = sampler ,
353
360
idata_kwargs = idata_kwargs ,
@@ -371,6 +378,7 @@ def sample(
371
378
random_seed : RandomState = None ,
372
379
progressbar : bool = True ,
373
380
step = None ,
381
+ var_names : Optional [Sequence [str ]] = None ,
374
382
nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
375
383
initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]] = None ,
376
384
init : str = "auto" ,
@@ -399,6 +407,7 @@ def sample(
399
407
random_seed : RandomState = None ,
400
408
progressbar : bool = True ,
401
409
step = None ,
410
+ var_names : Optional [Sequence [str ]] = None ,
402
411
nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
403
412
initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]] = None ,
404
413
init : str = "auto" ,
@@ -427,6 +436,7 @@ def sample(
427
436
random_seed : RandomState = None ,
428
437
progressbar : bool = True ,
429
438
step = None ,
439
+ var_names : Optional [Sequence [str ]] = None ,
430
440
nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
431
441
initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]] = None ,
432
442
init : str = "auto" ,
@@ -478,6 +488,8 @@ def sample(
478
488
A step function or collection of functions. If there are variables without step methods,
479
489
step methods for those variables will be assigned automatically. By default the NUTS step
480
490
method will be used, if appropriate to the model.
491
+ var_names : list of str, optional
492
+ Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
481
493
nuts_sampler : str
482
494
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
483
495
This requires the chosen sampler to be installed.
@@ -680,6 +692,7 @@ def sample(
680
692
random_seed = random_seed ,
681
693
initvals = initvals ,
682
694
model = model ,
695
+ var_names = var_names ,
683
696
progressbar = progressbar ,
684
697
idata_kwargs = idata_kwargs ,
685
698
nuts_sampler_kwargs = nuts_sampler_kwargs ,
@@ -722,12 +735,19 @@ def sample(
722
735
model .check_start_vals (ip )
723
736
_check_start_shape (model , ip )
724
737
738
+ if var_names is not None :
739
+ trace_vars = [v for v in model .unobserved_RVs if v .name in var_names ]
740
+ assert len (trace_vars ) == len (var_names ), "Not all var_names were found in the model"
741
+ else :
742
+ trace_vars = None
743
+
725
744
# Create trace backends for each chain
726
745
run , traces = init_traces (
727
746
backend = trace ,
728
747
chains = chains ,
729
748
expected_length = draws + tune ,
730
749
step = step ,
750
+ trace_vars = trace_vars ,
731
751
initial_point = ip ,
732
752
model = model ,
733
753
)
@@ -739,6 +759,7 @@ def sample(
739
759
"traces" : traces ,
740
760
"chains" : chains ,
741
761
"tune" : tune ,
762
+ "var_names" : var_names ,
742
763
"progressbar" : progressbar ,
743
764
"model" : model ,
744
765
"cores" : cores ,
0 commit comments