Skip to content

Commit ec6e4c0

Browse files
committed
Fix metropolis.py type hints
1 parent 8ccf28f commit ec6e4c0

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

pymc/step_methods/metropolis.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
"MultivariateNormalProposal",
6060
]
6161

62-
from pymc.util import get_value_vars_from_user_vars
62+
from pymc.util import RandomGenerator, get_value_vars_from_user_vars
6363

6464
# Available proposal distributions for Metropolis
6565

@@ -302,7 +302,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
302302
accept_rate = self.delta_logp(q, q0d)
303303
q, accepted = metrop_select(accept_rate, q, q0d, rng=self.rng)
304304
self.accept_rate_iter = accept_rate
305-
self.accepted_iter = accepted
305+
self.accepted_iter[0] = accepted
306306
self.accepted_sum += accepted
307307

308308
self.steps_until_tune -= 1
@@ -622,14 +622,16 @@ class CategoricalGibbsMetropolis(ArrayStep):
622622

623623
_state_class = CategoricalGibbsMetropolisState
624624

625-
def __init__(self, vars, proposal="uniform", order="random", model=None, rng=None):
625+
def __init__(
626+
self, vars, proposal="uniform", order="random", model=None, rng: RandomGenerator = None
627+
):
626628
model = pm.modelcontext(model)
627629

628630
vars = get_value_vars_from_user_vars(vars, model)
629631

630632
initial_point = model.initial_point()
631633

632-
dimcats = []
634+
dimcats: list[tuple[int, int]] = []
633635
# The above variable is a list of pairs (aggregate dimension, number
634636
# of categories). For example, if vars = [x, y] with x being a 2-D
635637
# variable with M categories and y being a 3-D variable with N
@@ -665,10 +667,10 @@ def __init__(self, vars, proposal="uniform", order="random", model=None, rng=Non
665667
self.dimcats = [dimcats[j] for j in order]
666668

667669
if proposal == "uniform":
668-
self.astep = self.astep_unif
670+
self.astep = self.astep_unif # type: ignore[assignment]
669671
elif proposal == "proportional":
670672
# Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
671-
self.astep = self.astep_prop
673+
self.astep = self.astep_prop # type: ignore[assignment]
672674
else:
673675
raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")
674676

0 commit comments

Comments
 (0)