Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions pymc/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,13 @@ class BinaryMetropolisState(StepMethodState):
class BinaryMetropolis(ArrayStep):
"""Metropolis-Hastings optimized for binary variables.

Unlike BinaryGibbsMetropolis, this step sampler proposes an update for all variable dimensions at once.

This will perform a single logp evaluation per step, at the expense of a lower acceptance rate when
the posteriors of the binary variables are highly correlated.

The BinaryGibbsMetropolis (not this one) is the default step sampler for binary variables

Parameters
----------
vars: list
Expand Down Expand Up @@ -489,6 +496,14 @@ class BinaryGibbsMetropolisState(StepMethodState):
class BinaryGibbsMetropolis(ArrayStep):
"""A Metropolis-within-Gibbs step method optimized for binary variables.

Unlike BinaryMetropolis, this step sampler proposes a variable dimension update at a time.

This will increase acceptance rate when the posteriors of the binary variables are highly correlated,
at the expense of doing more logp evaluations per step.

This is the default step sampler for binary variables.


Parameters
----------
vars: list
Expand Down
15 changes: 7 additions & 8 deletions tests/step_methods/test_metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,22 +356,21 @@ def test_step_continuous(self, step_fn, draws):

class TestRVsAssignmentMetropolis(RVsAssignmentStepsTester):
@pytest.mark.parametrize(
"step, step_kwargs",
"step",
[
(BinaryGibbsMetropolis, {}),
(CategoricalGibbsMetropolis, {}),
BinaryMetropolis,
BinaryGibbsMetropolis,
CategoricalGibbsMetropolis,
],
)
def test_discrete_steps(self, step, step_kwargs):
def test_discrete_steps(self, step):
with pm.Model() as m:
d1 = pm.Bernoulli("d1", p=0.5)
d2 = pm.Bernoulli("d2", p=0.5)

with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
assert [m.rvs_to_values[d1]] == step([d1], **step_kwargs).vars
assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(
step([d1, d2], **step_kwargs).vars
)
assert [m.rvs_to_values[d1]] == step([d1]).vars
assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(step([d1, d2]).vars)

@pytest.mark.parametrize(
"step, step_kwargs", [(Metropolis, {}), (DEMetropolis, {}), (DEMetropolisZ, {})]
Expand Down
Loading