Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions numpyro/infer/mclmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple

import jax

from numpyro.infer.mcmc import MCMCKernel
from numpyro.infer.util import initialize_model
from numpyro.util import identity

# blackjax is an optional dependency. Import failures are recorded in
# ``_BLACKJAX_AVAILABLE`` rather than raised, so this module stays importable
# and the MCLMC constructor can report the missing package itself.
try:
    import blackjax
    from blackjax.mcmc.integrators import IntegratorState
    from blackjax.util import pytree_size
except ImportError:
    # Placeholders so the names always exist at module level.
    blackjax = None
    IntegratorState = None
    pytree_size = None
    _BLACKJAX_AVAILABLE = False
else:
    _BLACKJAX_AVAILABLE = True

# State threaded through successive ``MCLMC.sample`` calls: the four integrator
# fields (position, momentum, log density and its gradient) plus the PRNG key
# to use for the next transition.
FullState = namedtuple(
    "FullState",
    [
        "position",
        "momentum",
        "logdensity",
        "logdensity_grad",
        "rng_key",
    ],
)


class MCLMC(MCMCKernel):
    """
    Microcanonical Langevin Monte Carlo (MCLMC) kernel.

    MCLMC is a gradient-based MCMC algorithm that uses Hamiltonian dynamics
    on an extended state space. It requires the `blackjax` package.

    The step size, the trajectory length ``L`` and the inverse mass matrix are
    tuned once inside :meth:`init` and stored on the kernel as ``adapt_state``;
    :meth:`sample` then reuses those tuned values for every transition.

    **References:**

    1. *Microcanonical Hamiltonian Monte Carlo*,
       Jakob Robnik, G. Bruno De Luca, Eva Silverstein, Uroš Seljak
       https://arxiv.org/abs/2212.08549

    .. note:: The model must have at least 2 latent dimensions for MCLMC to work
        (this is a limitation of the blackjax implementation).

    :param model: Python callable containing Pyro :mod:`~numpyro.primitives`.
    :param float desired_energy_var: Target energy variance for step size and
        trajectory length tuning. Smaller values lead to more conservative
        step sizes. Defaults to 5e-4.
    :param bool diagonal_preconditioning: Whether to use diagonal preconditioning
        for the mass matrix. Defaults to True.
    :raises ImportError: if `blackjax` is not installed.
    :raises ValueError: if ``model`` is None.
    """

    def __init__(
        self,
        model=None,
        desired_energy_var=5e-4,
        diagonal_preconditioning=True,
    ):
        if not _BLACKJAX_AVAILABLE:
            raise ImportError(
                "MCLMC requires the 'blackjax' package. "
                "Please install it with: pip install blackjax"
            )
        if model is None:
            raise ValueError("Model must be specified for MCLMC")
        self._model = model
        self._diagonal_preconditioning = diagonal_preconditioning
        self._desired_energy_var = desired_energy_var
        self._init_fn = None
        self._sample_fn = None
        self._postprocess_fn = None
        # Populated by `init`; declared here so the attributes always exist.
        self._potential_fn_gen = None
        self.logdensity_fn = None
        self.adapt_state = None
        self.dim = None

    @property
    def model(self):
        """The model callable this kernel was constructed with."""
        return self._model

    @property
    def sample_field(self):
        """Name of the :class:`FullState` field that holds the latent sample."""
        return "position"

    @property
    def default_fields(self):
        """Fields of the state collected by default (only the sample field)."""
        return (self.sample_field,)

    def get_diagnostics_str(self, state):
        """
        Return a diagnostics string for the progress bar.

        :param state: Current kernel state (unused; the tuned parameters live
            on the kernel itself).
        :return: A string reporting the tuned step size and ``L``, or an empty
            string if the kernel has not been initialized yet.
        """
        if self.adapt_state is None:
            return ""
        return "step_size={:.2e}, L={:.2e}".format(
            self.adapt_state.step_size, self.adapt_state.L
        )

    def postprocess_fn(self, args, kwargs):
        """
        Get a function that transforms unconstrained values at sample sites to values
        constrained to the site's support, in addition to returning deterministic
        sites in the model.

        :param args: Arguments to the model.
        :param kwargs: Keyword arguments to the model.
        :return: The postprocessing callable, or the identity function if the
            kernel has not been initialized.
        """
        if self._postprocess_fn is None:
            return identity
        return self._postprocess_fn(*args, **kwargs)

    def _make_logdensity_fn(self, model_args, model_kwargs):
        """Build the log density (negated potential) for the given model arguments."""
        potential_fn = self._potential_fn_gen(*model_args, **model_kwargs)

        def logdensity_fn(position):
            return -potential_fn(position)

        return logdensity_fn

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        """
        Initialize the MCLMC kernel and tune its parameters.

        :param rng_key: Random number generator key
        :param num_warmup: Number of warmup steps; split evenly across the
            three tuning stages.
        :param init_params: Initial unconstrained parameters. If None, the
            values found by :func:`~numpyro.infer.util.initialize_model` are used.
        :param model_args: Model arguments
        :param model_kwargs: Model keyword arguments
        :return: Initial state
        :raises ValueError: if the model has fewer than 2 latent dimensions.
        """
        init_model_key, init_state_key, run_key, rng_key_tune = jax.random.split(
            rng_key, 4
        )

        model_init, potential_fn_gen, postprocess_fn, _ = initialize_model(
            init_model_key,
            self._model,
            model_args=model_args,
            model_kwargs=model_kwargs,
            dynamic_args=True,
        )
        self._potential_fn_gen = potential_fn_gen
        self._postprocess_fn = postprocess_fn

        logdensity_fn = self._make_logdensity_fn(model_args, model_kwargs)
        self.logdensity_fn = logdensity_fn

        if init_params is None:
            initial_position = model_init.z
        else:
            # Honor caller-supplied initial values instead of discarding them.
            # Accept either the raw unconstrained values or an object that
            # exposes them as `.z`.
            initial_position = getattr(init_params, "z", init_params)

        self.dim = pytree_size(initial_position)
        if self.dim < 2:
            # Enforce the documented blackjax limitation with a clear error.
            raise ValueError(
                "MCLMC requires a model with at least 2 latent dimensions, "
                "but got {}.".format(self.dim)
            )

        sampler_state = blackjax.mcmc.mclmc.init(
            position=initial_position,
            logdensity_fn=logdensity_fn,
            rng_key=init_state_key,
        )

        def kernel(inverse_mass_matrix):
            return blackjax.mcmc.mclmc.build_kernel(
                logdensity_fn=logdensity_fn,
                integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
                inverse_mass_matrix=inverse_mass_matrix,
            )

        # `num_steps` only serves as the base the `frac_tune*` fractions are
        # taken of: each fraction is chosen so that its stage gets
        # `num_warmup / 3` steps.
        num_tuning_steps = 100
        stage_fraction = num_warmup / (3 * num_tuning_steps)
        (
            blackjax_state_after_tuning,
            blackjax_mclmc_sampler_params,
            _,
        ) = blackjax.mclmc_find_L_and_step_size(
            mclmc_kernel=kernel,
            num_steps=num_tuning_steps,
            state=sampler_state,
            rng_key=rng_key_tune,
            diagonal_preconditioning=self._diagonal_preconditioning,
            frac_tune3=stage_fraction,
            frac_tune2=stage_fraction,
            frac_tune1=stage_fraction,
            desired_energy_var=self._desired_energy_var,
        )

        self.adapt_state = blackjax_mclmc_sampler_params

        return FullState(
            blackjax_state_after_tuning.position,
            blackjax_state_after_tuning.momentum,
            blackjax_state_after_tuning.logdensity,
            blackjax_state_after_tuning.logdensity_grad,
            run_key,
        )

    def sample(self, state, model_args, model_kwargs):
        """
        Run MCLMC from the given state and return the resulting state.

        :param state: Current state
        :param model_args: Model arguments
        :param model_kwargs: Model keyword arguments
        :return: Next state after running MCLMC
        """
        mclmc_state = IntegratorState(
            state.position, state.momentum, state.logdensity, state.logdensity_grad
        )
        rng_key, rng_key_sample = jax.random.split(state.rng_key, 2)

        # Build the log density from the arguments passed to this call rather
        # than the ones captured at `init` time, so dynamically supplied model
        # arguments are respected.
        kernel = blackjax.mcmc.mclmc.build_kernel(
            logdensity_fn=self._make_logdensity_fn(model_args, model_kwargs),
            integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
            inverse_mass_matrix=self.adapt_state.inverse_mass_matrix,
        )

        new_state, info = kernel(
            rng_key=rng_key_sample,
            state=mclmc_state,
            step_size=self.adapt_state.step_size,
            L=self.adapt_state.L,
        )

        return FullState(
            new_state.position,
            new_state.momentum,
            new_state.logdensity,
            new_state.logdensity_grad,
            rng_key,
        )

    def __getstate__(self):
        """Return the instance dict for pickling with ``_postprocess_fn`` cleared."""
        state = self.__dict__.copy()
        state["_postprocess_fn"] = None
        return state
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"scikit-learn",
"scipy>=1.9",
"ty>=0.0.4",
"blackjax>=1.3",
],
"dev": [
"dm-haiku>=0.0.14",
Expand Down
155 changes: 155 additions & 0 deletions test/infer/test_mclmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from numpy.testing import assert_allclose
import pytest

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC
from numpyro.infer.mclmc import MCLMC


def test_mclmc_model_required():
    """Constructing MCLMC with ``model=None`` must raise ``ValueError``."""
    expected_message = "Model must be specified"
    with pytest.raises(ValueError, match=expected_message):
        MCLMC(model=None)


def test_mclmc_blackjax_not_installed(monkeypatch):
    """MCLMC must raise an informative ImportError when blackjax is unavailable."""
    import numpyro.infer.mclmc as mclmc_module

    def dummy_model():
        numpyro.sample("x", dist.Normal(0, 1))

    # Simulate a missing blackjax install by patching the module-level flag.
    monkeypatch.setattr(mclmc_module, "_BLACKJAX_AVAILABLE", False)

    expected_message = "MCLMC requires the 'blackjax' package"
    with pytest.raises(ImportError, match=expected_message):
        MCLMC(model=dummy_model)


def test_mclmc_normal():
    """Recover the mean and standard deviation of a 2D diagonal normal.

    Two dimensions are used because MCLMC requires at least 2 latent
    dimensions (blackjax limitation).
    """
    num_warmup, num_samples = 1000, 2000
    true_mean = jnp.array([1.0, 2.0])
    true_std = jnp.array([0.5, 1.0])

    def model():
        numpyro.sample("x", dist.Normal(true_mean, true_std).to_event(1))

    mcmc = MCMC(
        MCLMC(model=model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=1,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(0))
    samples = mcmc.get_samples()

    assert "x" in samples
    draws = samples["x"]
    assert draws.shape == (num_samples, 2)
    # Loose tolerances: this checks the sampler is roughly right, not exact.
    assert_allclose(jnp.mean(draws, axis=0), true_mean, atol=0.1)
    assert_allclose(jnp.std(draws, axis=0), true_std, atol=0.2)


def test_mclmc_gaussian_2d():
    """Sample two scalar latents whose sum is observed, with explicit kernel options."""
    num_warmup, num_samples = 1000, 1000

    def model():
        x = numpyro.sample("x", dist.Normal(0.0, 1.0))
        y = numpyro.sample("y", dist.Normal(0.0, 1.0))
        numpyro.sample("obs", dist.Normal(x + y, 0.5), obs=jnp.array(0.0))

    mcmc = MCMC(
        MCLMC(
            model=model,
            diagonal_preconditioning=True,
            desired_energy_var=5e-4,
        ),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=1,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(0))
    samples = mcmc.get_samples()

    for site in ("x", "y"):
        assert site in samples
    for site in ("x", "y"):
        assert samples[site].shape == (num_samples,)

    # The sum x + y is observed at 0, so the posterior means should sum to ~0.
    mean_sum = jnp.mean(samples["x"]) + jnp.mean(samples["y"])
    assert_allclose(mean_sum, 0.0, atol=0.2)


def test_mclmc_logistic_regression():
    """Recover logistic-regression coefficients from synthetic data.

    The design matrix and labels are captured by the model through a closure
    rather than being passed as arguments to ``mcmc.run``.
    """
    N, dim = 1000, 3
    num_warmup, num_samples = 1000, 2000

    data_key, label_key, run_key = random.split(random.PRNGKey(0), 3)
    data = random.normal(data_key, (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=true_logits).sample(label_key)

    # `data` and `labels` come from the enclosing scope (closure pattern).
    def model():
        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    mcmc = MCMC(
        MCLMC(model=model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=1,
        progress_bar=False,
    )
    mcmc.run(run_key)
    samples = mcmc.get_samples()

    assert "coefs" in samples
    coef_draws = samples["coefs"]
    assert coef_draws.shape == (num_samples, dim)
    assert_allclose(jnp.mean(coef_draws, 0), true_coefs, atol=0.5)


def test_mclmc_sample_shape():
    """Samples for scalar, vector and matrix sites keep their per-site shapes."""
    num_warmup, num_samples = 500, 500

    def model():
        numpyro.sample("a", dist.Normal(0, 1))
        numpyro.sample("b", dist.Normal(0, 1).expand([3]))
        numpyro.sample("c", dist.Normal(0, 1).expand([2, 4]))

    mcmc = MCMC(
        MCLMC(model=model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=1,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(0))
    samples = mcmc.get_samples()

    # Each site's draws are stacked along a leading axis of length num_samples.
    expected_site_shapes = {"a": (), "b": (3,), "c": (2, 4)}
    for site, site_shape in expected_site_shapes.items():
        assert samples[site].shape == (num_samples,) + site_shape
Loading