Skip to content

Commit f1e586b

Browse files
committed
More strict/explicit signature in step samplers
1 parent d2761a3 commit f1e586b

File tree

5 files changed

+56
-17
lines changed

5 files changed

+56
-17
lines changed

pymc/step_methods/arraystep.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,9 @@ class GradientSharedStep(ArrayStepShared):
174174
def __init__(
175175
self,
176176
vars,
177+
*,
177178
model=None,
178-
blocked=True,
179+
blocked: bool = True,
179180
dtype=None,
180181
logp_dlogp_func=None,
181182
rng: RandomGenerator = None,

pymc/step_methods/hmc/base_hmc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,12 @@ class BaseHMC(GradientSharedStep):
8282
def __init__(
8383
self,
8484
vars=None,
85+
*,
8586
scaling=None,
8687
step_scale=0.25,
8788
is_cov=False,
8889
model=None,
89-
blocked=True,
90+
blocked: bool = True,
9091
potential=None,
9192
dtype=None,
9293
Emax=1000,

pymc/step_methods/metropolis.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ class Metropolis(ArrayStepShared):
151151
def __init__(
152152
self,
153153
vars=None,
154+
*,
154155
S=None,
155156
proposal_dist=None,
156157
scaling=1.0,
@@ -159,7 +160,7 @@ def __init__(
159160
model=None,
160161
mode=None,
161162
rng=None,
162-
**kwargs,
163+
blocked: bool = False,
163164
):
164165
"""Create an instance of a Metropolis stepper.
165166
@@ -251,7 +252,7 @@ def __init__(
251252

252253
shared = pm.make_shared_replacements(initial_values, vars, model)
253254
self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
254-
super().__init__(vars, shared, rng=rng)
255+
super().__init__(vars, shared, blocked=blocked, rng=rng)
255256

256257
def reset_tuning(self):
257258
"""Reset the tuned sampler parameters to their initial values."""
@@ -418,7 +419,17 @@ class BinaryMetropolis(ArrayStep):
418419

419420
_state_class = BinaryMetropolisState
420421

421-
def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None, rng=None):
422+
def __init__(
423+
self,
424+
vars,
425+
*,
426+
scaling=1.0,
427+
tune=True,
428+
tune_interval=100,
429+
model=None,
430+
rng=None,
431+
blocked: bool = True,
432+
):
422433
model = pm.modelcontext(model)
423434

424435
self.scaling = scaling
@@ -432,7 +443,7 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None,
432443
if not all(v.dtype in pm.discrete_types for v in vars):
433444
raise ValueError("All variables must be Bernoulli for BinaryMetropolis")
434445

435-
super().__init__(vars, [model.compile_logp()], rng=rng)
446+
super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
436447

437448
def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
438449
logp = args[0]
@@ -530,7 +541,16 @@ class BinaryGibbsMetropolis(ArrayStep):
530541

531542
_state_class = BinaryGibbsMetropolisState
532543

533-
def __init__(self, vars, order="random", transit_p=0.8, model=None, rng=None):
544+
def __init__(
545+
self,
546+
vars,
547+
*,
548+
order="random",
549+
transit_p=0.8,
550+
model=None,
551+
rng=None,
552+
blocked: bool = True,
553+
):
534554
model = pm.modelcontext(model)
535555

536556
# Doesn't actually tune, but it's required to emit a sampler stat
@@ -556,7 +576,7 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None, rng=None):
556576
if not all(v.dtype in pm.discrete_types for v in vars):
557577
raise ValueError("All variables must be binary for BinaryGibbsMetropolis")
558578

559-
super().__init__(vars, [model.compile_logp()], rng=rng)
579+
super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
560580

561581
def reset_tuning(self):
562582
# There are no tuning parameters in this step method.
@@ -638,7 +658,14 @@ class CategoricalGibbsMetropolis(ArrayStep):
638658
_state_class = CategoricalGibbsMetropolisState
639659

640660
def __init__(
641-
self, vars, proposal="uniform", order="random", model=None, rng: RandomGenerator = None
661+
self,
662+
vars,
663+
*,
664+
proposal="uniform",
665+
order="random",
666+
model=None,
667+
rng: RandomGenerator = None,
668+
blocked: bool = True,
642669
):
643670
model = pm.modelcontext(model)
644671

@@ -693,7 +720,7 @@ def __init__(
693720
# that indicates whether a draw was done in a tuning phase.
694721
self.tune = True
695722

696-
super().__init__(vars, [model.compile_logp()], rng=rng)
723+
super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
697724

698725
def reset_tuning(self):
699726
# There are no tuning parameters in this step method.
@@ -858,6 +885,7 @@ class DEMetropolis(PopulationArrayStepShared):
858885
def __init__(
859886
self,
860887
vars=None,
888+
*,
861889
S=None,
862890
proposal_dist=None,
863891
lamb=None,
@@ -867,7 +895,7 @@ def __init__(
867895
model=None,
868896
mode=None,
869897
rng=None,
870-
**kwargs,
898+
blocked: bool = True,
871899
):
872900
model = pm.modelcontext(model)
873901
initial_values = model.initial_point()
@@ -902,7 +930,7 @@ def __init__(
902930

903931
shared = pm.make_shared_replacements(initial_values, vars, model)
904932
self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
905-
super().__init__(vars, shared, rng=rng)
933+
super().__init__(vars, shared, blocked=blocked, rng=rng)
906934

907935
def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
908936
point_map_info = q0.point_map_info
@@ -1025,6 +1053,7 @@ class DEMetropolisZ(ArrayStepShared):
10251053
def __init__(
10261054
self,
10271055
vars=None,
1056+
*,
10281057
S=None,
10291058
proposal_dist=None,
10301059
lamb=None,
@@ -1035,7 +1064,7 @@ def __init__(
10351064
model=None,
10361065
mode=None,
10371066
rng=None,
1038-
**kwargs,
1067+
blocked: bool = True,
10391068
):
10401069
model = pm.modelcontext(model)
10411070
initial_values = model.initial_point()
@@ -1082,7 +1111,7 @@ def __init__(
10821111

10831112
shared = pm.make_shared_replacements(initial_values, vars, model)
10841113
self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared)
1085-
super().__init__(vars, shared, rng=rng)
1114+
super().__init__(vars, shared, blocked=blocked, rng=rng)
10861115

10871116
def reset_tuning(self):
10881117
"""Reset the tuned sampler parameters and history to their initial values."""

pymc/step_methods/slicer.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,15 @@ class Slice(ArrayStepShared):
7676
_state_class = SliceState
7777

7878
def __init__(
79-
self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, rng=None, **kwargs
79+
self,
80+
vars=None,
81+
*,
82+
w=1.0,
83+
tune=True,
84+
model=None,
85+
iter_limit=np.inf,
86+
rng=None,
87+
blocked: bool = False, # Could be true since tuning is independent across dims?
8088
):
8189
model = modelcontext(model)
8290
self.w = np.asarray(w).copy()
@@ -97,7 +105,7 @@ def __init__(
97105
self.logp = compile_pymc([raveled_inp], logp)
98106
self.logp.trust_input = True
99107

100-
super().__init__(vars, shared, rng=rng)
108+
super().__init__(vars, shared, blocked=blocked, rng=rng)
101109

102110
def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]:
103111
# The arguments are determined by the list passed via `super().__init__(..., fs, ...)`

tests/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def arbitrary_det(value):
7878

7979
def simple_init():
8080
start, model, moments = simple_model()
81-
step = Metropolis(model.value_vars, np.diag([1.0]), model=model)
81+
step = Metropolis(model.value_vars, S=np.diag([1.0]), model=model)
8282
return model, start, step, moments
8383

8484

0 commit comments

Comments (0)