diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index d825c8857..7e1d32a4d 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -383,6 +383,13 @@ class BinaryMetropolisState(StepMethodState): class BinaryMetropolis(ArrayStep): """Metropolis-Hastings optimized for binary variables. + Unlike BinaryGibbsMetropolis, this step sampler proposes an update for all variable dimensions at once. + + This will perform a single logp evaluation per step, at the expense of a lower acceptance rate when + the posteriors of the binary variables are highly correlated. + + The BinaryGibbsMetropolis (not this one) is the default step sampler for binary variables + Parameters ---------- vars: list @@ -489,6 +496,14 @@ class BinaryGibbsMetropolisState(StepMethodState): class BinaryGibbsMetropolis(ArrayStep): """A Metropolis-within-Gibbs step method optimized for binary variables. + Unlike BinaryMetropolis, this step sampler proposes a variable dimension update at a time. + + This will increase acceptance rate when the posteriors of the binary variables are highly correlated, + at the expense of doing more logp evaluations per step. + + This is the default step sampler for binary variables. + + Parameters ---------- vars: list diff --git a/tests/step_methods/test_metropolis.py b/tests/step_methods/test_metropolis.py index a73538a61..63262759c 100644 --- a/tests/step_methods/test_metropolis.py +++ b/tests/step_methods/test_metropolis.py @@ -356,22 +356,21 @@ def test_step_continuous(self, step_fn, draws): class TestRVsAssignmentMetropolis(RVsAssignmentStepsTester): @pytest.mark.parametrize( - "step, step_kwargs", + "step", [ - (BinaryGibbsMetropolis, {}), - (CategoricalGibbsMetropolis, {}), + BinaryMetropolis, + BinaryGibbsMetropolis, + CategoricalGibbsMetropolis, ], ) - def test_discrete_steps(self, step, step_kwargs): + def test_discrete_steps(self, step): with pm.Model() as m: d1 = pm.Bernoulli("d1", p=0.5) d2 = pm.Bernoulli("d2", p=0.5) with pytensor.config.change_flags(mode=fast_unstable_sampling_mode): - assert [m.rvs_to_values[d1]] == step([d1], **step_kwargs).vars - assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set( - step([d1, d2], **step_kwargs).vars - ) + assert [m.rvs_to_values[d1]] == step([d1]).vars + assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(step([d1, d2]).vars) @pytest.mark.parametrize( "step, step_kwargs", [(Metropolis, {}), (DEMetropolis, {}), (DEMetropolisZ, {})]