Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
"sample_numpyro_nuts",
)

JaxNutsSampler = Literal["numpyro", "blackjax"]


@jax_funcify.register(Assert)
@jax_funcify.register(CheckParameterValue)
Expand Down Expand Up @@ -524,7 +526,7 @@ def sample_jax_nuts(
postprocessing_chunks=None,
idata_kwargs: dict | None = None,
compute_convergence_checks: bool = True,
nuts_sampler: Literal["numpyro", "blackjax"],
nuts_sampler: JaxNutsSampler,
) -> az.InferenceData:
"""
Draw samples from the posterior using a jax NUTS method.
Expand Down
55 changes: 41 additions & 14 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,13 @@
import contextlib
import logging
import pickle
import re
import sys
import time
import warnings

from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import (
Any,
Literal,
TypeAlias,
cast,
overload,
)
from typing import Any, Literal, TypeAlias, cast, get_args, overload

import numpy as np
import pytensor.gradient as tg
Expand Down Expand Up @@ -90,6 +85,14 @@

Step: TypeAlias = BlockedStep | CompoundStep

ExternalNutsSampler = Literal["nutpie", "numpyro", "blackjax"]
NutsSampler = Literal["pymc"] | ExternalNutsSampler
NutpieBackend = Literal["numba", "jax"]


NUTPIE_BACKENDS = get_args(NutpieBackend)
NUTPIE_DEFAULT_BACKEND = cast(NutpieBackend, "numba")


class SamplingIteratorCallback(Protocol):
"""Signature of the callable that may be passed to `pm.sample(callable=...)`."""
Expand Down Expand Up @@ -291,7 +294,7 @@ def all_continuous(vars):


def _sample_external_nuts(
sampler: Literal["nutpie", "numpyro", "blackjax"],
sampler: ExternalNutsSampler,
draws: int,
tune: int,
chains: int,
Expand All @@ -309,7 +312,7 @@ def _sample_external_nuts(
if nuts_sampler_kwargs is None:
nuts_sampler_kwargs = {}

if sampler == "nutpie":
if sampler.startswith("nutpie"):
try:
import nutpie
except ImportError as err:
Expand Down Expand Up @@ -340,6 +343,24 @@ def _sample_external_nuts(
var_names=var_names,
**compile_kwargs,
)

def extract_backend(string: str) -> NutpieBackend:
match = re.search(r"(?<=\[)[^\]]+(?=\])", string)
if string == "nutpie":
return NUTPIE_DEFAULT_BACKEND
elif match is None:
raise ValueError(
f"Could not parse nutpie backend. Found {string!r}, expected format 'nutpie[backend]'"
)

result = cast(NutpieBackend, match.group(0))
if result not in NUTPIE_BACKENDS:
raise ValueError(f'Expected one of {NUTPIE_BACKENDS}; found "{result}"')
return result

backend = extract_backend(sampler)
compiled_model = nutpie.compile_pymc_model(model, backend=backend)

t_start = time.time()
idata = nutpie.sample(
compiled_model,
Expand Down Expand Up @@ -388,6 +409,10 @@ def _sample_external_nuts(
elif sampler in ("numpyro", "blackjax"):
import pymc.sampling.jax as pymc_jax

from pymc.sampling.jax import JaxNutsSampler

sampler = cast(JaxNutsSampler, sampler)

idata = pymc_jax.sample_jax_nuts(
draws=draws,
tune=tune,
Expand Down Expand Up @@ -423,7 +448,7 @@ def sample(
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
nuts_sampler: NutsSampler = "pymc",
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
Expand Down Expand Up @@ -455,7 +480,7 @@ def sample(
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
nuts_sampler: NutsSampler = "pymc",
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
Expand Down Expand Up @@ -487,7 +512,7 @@ def sample(
progressbar_theme: Theme | None = None,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
nuts_sampler: NutsSampler = "pymc",
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
Expand Down Expand Up @@ -552,8 +577,10 @@ def sample(
method will be used, if appropriate to the model.
var_names : list of str, optional
Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
nuts_sampler : str
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
nuts_sampler : str, default "pymc"
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. In addition, the compilation
backend for the chosen sampler can be set using square brackets, if available. For example, "nutpie[jax]" will
use the JAX backend for the nutpie sampler. Currently, "nutpie[jax]" and "nutpie[numba]" are allowed.
This requires the chosen sampler to be installed.
All samplers, except "pymc", require the full model to be continuous.
blas_cores: int or "auto" or None, default = "auto"
Expand Down
118 changes: 85 additions & 33 deletions tests/sampling/test_mcmc_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,19 @@
import pytest
import xarray as xr

from pymc import Data, Deterministic, HalfNormal, Model, Normal, sample
from pymc import Data, Deterministic, HalfNormal, Model, Normal, modelcontext, sample


@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
def test_external_nuts_sampler(recwarn, nuts_sampler):
if nuts_sampler != "pymc":
pytest.importorskip(nuts_sampler)

with Model():
x = Normal("x", 100, 5)
y = Data("y", [1, 2, 3, 4])
Data("z", [100, 190, 310, 405])

Normal("L", mu=x, sigma=0.1, observed=y)

kwargs = {
"nuts_sampler": nuts_sampler,
"random_seed": 123,
"chains": 2,
"tune": 500,
"draws": 500,
"progressbar": False,
"initvals": {"x": 0.0},
}
def check_external_sampler_output(warns, idata1, idata2, sample_kwargs):
nuts_sampler = sample_kwargs["nuts_sampler"]
reference_kwargs = sample_kwargs.copy()
reference_kwargs["nuts_sampler"] = "pymc"

idata1 = sample(**kwargs)
idata2 = sample(**kwargs)

reference_kwargs = kwargs.copy()
reference_kwargs["nuts_sampler"] = "pymc"
with modelcontext(None):
idata_reference = sample(**reference_kwargs)

warns = {
(warn.category, warn.message.args[0])
for warn in recwarn
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning)
}
expected = set()
if nuts_sampler == "nutpie":
if nuts_sampler.startswith("nutpie"):
expected.add(
(
UserWarning,
Expand All @@ -75,6 +49,84 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys()


@pytest.fixture
def pymc_model():
with Model() as m:
x = Normal("x", 100, 5)
y = Data("y", [1, 2, 3, 4])
Data("z", [100, 190, 310, 405])

Normal("L", mu=x, sigma=0.1, observed=y)

return m


@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "nutpie[jax]", "blackjax", "numpyro"])
def test_external_nuts_sampler(pymc_model, recwarn, nuts_sampler):
if nuts_sampler != "pymc":
pytest.importorskip(nuts_sampler)

sample_kwargs = {
"nuts_sampler": nuts_sampler,
"random_seed": 123,
"chains": 2,
"tune": 500,
"draws": 500,
"progressbar": False,
"initvals": {"x": 0.0},
}

with pymc_model:
idata1 = sample(**sample_kwargs)
idata2 = sample(**sample_kwargs)

warns = {
(warn.category, warn.message.args[0])
for warn in recwarn
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning)
}

check_external_sampler_output(warns, idata1, idata2, sample_kwargs)


@pytest.mark.parametrize("backend", ["numba", "jax"], ids=["numba", "jax"])
def test_numba_backend_options(pymc_model, recwarn, backend):
pytest.importorskip("nutpie")
pytest.importorskip(backend)

sample_kwargs = {
"nuts_sampler": f"nutpie[{backend}]",
"random_seed": 123,
"chains": 2,
"tune": 500,
"draws": 500,
"progressbar": False,
"initvals": {"x": 0.0},
}

with pymc_model:
idata1 = sample(**sample_kwargs)
idata2 = sample(**sample_kwargs)

warns = {
(warn.category, warn.message.args[0])
for warn in recwarn
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning)
}

check_external_sampler_output(warns, idata1, idata2, sample_kwargs)


def test_invalid_nutpie_backend_raises(pymc_model):
pytest.importorskip("nutpie")
with pytest.raises(
ValueError,
match='Could not parse nutpie backend. Expected one of "numba" or "jax"; found "invalid"',
):
with pymc_model:
sample(nuts_sampler="nutpie[invalid]", random_seed=123, chains=2, tune=500, draws=500)


def test_step_args():
pytest.importorskip("numpyro")

Expand Down
Loading