@@ -151,6 +151,7 @@ class Metropolis(ArrayStepShared):
151
151
def __init__ (
152
152
self ,
153
153
vars = None ,
154
+ * ,
154
155
S = None ,
155
156
proposal_dist = None ,
156
157
scaling = 1.0 ,
@@ -159,7 +160,7 @@ def __init__(
159
160
model = None ,
160
161
mode = None ,
161
162
rng = None ,
162
- ** kwargs ,
163
+ blocked : bool = False ,
163
164
):
164
165
"""Create an instance of a Metropolis stepper.
165
166
@@ -251,7 +252,7 @@ def __init__(
251
252
252
253
shared = pm .make_shared_replacements (initial_values , vars , model )
253
254
self .delta_logp = delta_logp (initial_values , model .logp (), vars , shared )
254
- super ().__init__ (vars , shared , rng = rng )
255
+ super ().__init__ (vars , shared , blocked = blocked , rng = rng )
255
256
256
257
def reset_tuning (self ):
257
258
"""Reset the tuned sampler parameters to their initial values."""
@@ -418,7 +419,17 @@ class BinaryMetropolis(ArrayStep):
418
419
419
420
_state_class = BinaryMetropolisState
420
421
421
- def __init__ (self , vars , scaling = 1.0 , tune = True , tune_interval = 100 , model = None , rng = None ):
422
+ def __init__ (
423
+ self ,
424
+ vars ,
425
+ * ,
426
+ scaling = 1.0 ,
427
+ tune = True ,
428
+ tune_interval = 100 ,
429
+ model = None ,
430
+ rng = None ,
431
+ blocked : bool = True ,
432
+ ):
422
433
model = pm .modelcontext (model )
423
434
424
435
self .scaling = scaling
@@ -432,7 +443,7 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None,
432
443
if not all (v .dtype in pm .discrete_types for v in vars ):
433
444
raise ValueError ("All variables must be Bernoulli for BinaryMetropolis" )
434
445
435
- super ().__init__ (vars , [model .compile_logp ()], rng = rng )
446
+ super ().__init__ (vars , [model .compile_logp ()], blocked = blocked , rng = rng )
436
447
437
448
def astep (self , apoint : RaveledVars , * args ) -> tuple [RaveledVars , StatsType ]:
438
449
logp = args [0 ]
@@ -530,7 +541,16 @@ class BinaryGibbsMetropolis(ArrayStep):
530
541
531
542
_state_class = BinaryGibbsMetropolisState
532
543
533
- def __init__ (self , vars , order = "random" , transit_p = 0.8 , model = None , rng = None ):
544
+ def __init__ (
545
+ self ,
546
+ vars ,
547
+ * ,
548
+ order = "random" ,
549
+ transit_p = 0.8 ,
550
+ model = None ,
551
+ rng = None ,
552
+ blocked : bool = True ,
553
+ ):
534
554
model = pm .modelcontext (model )
535
555
536
556
# Doesn't actually tune, but it's required to emit a sampler stat
@@ -556,7 +576,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None, rng=None):
556
576
if not all (v .dtype in pm .discrete_types for v in vars ):
557
577
raise ValueError ("All variables must be binary for BinaryGibbsMetropolis" )
558
578
559
- super ().__init__ (vars , [model .compile_logp ()], rng = rng )
579
+ super ().__init__ (vars , [model .compile_logp ()], blocked = blocked , rng = rng )
560
580
561
581
def reset_tuning (self ):
562
582
# There are no tuning parameters in this step method.
@@ -638,7 +658,14 @@ class CategoricalGibbsMetropolis(ArrayStep):
638
658
_state_class = CategoricalGibbsMetropolisState
639
659
640
660
def __init__ (
641
- self , vars , proposal = "uniform" , order = "random" , model = None , rng : RandomGenerator = None
661
+ self ,
662
+ vars ,
663
+ * ,
664
+ proposal = "uniform" ,
665
+ order = "random" ,
666
+ model = None ,
667
+ rng : RandomGenerator = None ,
668
+ blocked : bool = True ,
642
669
):
643
670
model = pm .modelcontext (model )
644
671
@@ -693,7 +720,7 @@ def __init__(
693
720
# that indicates whether a draw was done in a tuning phase.
694
721
self .tune = True
695
722
696
- super ().__init__ (vars , [model .compile_logp ()], rng = rng )
723
+ super ().__init__ (vars , [model .compile_logp ()], blocked = blocked , rng = rng )
697
724
698
725
def reset_tuning (self ):
699
726
# There are no tuning parameters in this step method.
@@ -858,6 +885,7 @@ class DEMetropolis(PopulationArrayStepShared):
858
885
def __init__ (
859
886
self ,
860
887
vars = None ,
888
+ * ,
861
889
S = None ,
862
890
proposal_dist = None ,
863
891
lamb = None ,
@@ -867,7 +895,7 @@ def __init__(
867
895
model = None ,
868
896
mode = None ,
869
897
rng = None ,
870
- ** kwargs ,
898
+ blocked : bool = True ,
871
899
):
872
900
model = pm .modelcontext (model )
873
901
initial_values = model .initial_point ()
@@ -902,7 +930,7 @@ def __init__(
902
930
903
931
shared = pm .make_shared_replacements (initial_values , vars , model )
904
932
self .delta_logp = delta_logp (initial_values , model .logp (), vars , shared )
905
- super ().__init__ (vars , shared , rng = rng )
933
+ super ().__init__ (vars , shared , blocked = blocked , rng = rng )
906
934
907
935
def astep (self , q0 : RaveledVars ) -> tuple [RaveledVars , StatsType ]:
908
936
point_map_info = q0 .point_map_info
@@ -1025,6 +1053,7 @@ class DEMetropolisZ(ArrayStepShared):
1025
1053
def __init__ (
1026
1054
self ,
1027
1055
vars = None ,
1056
+ * ,
1028
1057
S = None ,
1029
1058
proposal_dist = None ,
1030
1059
lamb = None ,
@@ -1035,7 +1064,7 @@ def __init__(
1035
1064
model = None ,
1036
1065
mode = None ,
1037
1066
rng = None ,
1038
- ** kwargs ,
1067
+ blocked : bool = True ,
1039
1068
):
1040
1069
model = pm .modelcontext (model )
1041
1070
initial_values = model .initial_point ()
@@ -1082,7 +1111,7 @@ def __init__(
1082
1111
1083
1112
shared = pm .make_shared_replacements (initial_values , vars , model )
1084
1113
self .delta_logp = delta_logp (initial_values , model .logp (), vars , shared )
1085
- super ().__init__ (vars , shared , rng = rng )
1114
+ super ().__init__ (vars , shared , blocked = blocked , rng = rng )
1086
1115
1087
1116
def reset_tuning (self ):
1088
1117
"""Reset the tuned sampler parameters and history to their initial values."""
0 commit comments