From 2cd8360e181200041701f378cf83678cf9f92352 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Tue, 10 Sep 2024 18:25:52 +0800 Subject: [PATCH 1/4] Allow access to different nutpie backends via pip-style syntax --- pymc/sampling/mcmc.py | 42 +++++++--- tests/sampling/test_mcmc_external.py | 116 +++++++++++++++++++-------- 2 files changed, 114 insertions(+), 44 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4ee79607b..bcd66faaf 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -17,17 +17,13 @@ import contextlib import logging import pickle +import re import sys import time import warnings from collections.abc import Callable, Iterator, Mapping, Sequence -from typing import ( - Any, - Literal, - TypeAlias, - overload, -) +from typing import Any, Literal, TypeAlias, cast, get_args, overload import numpy as np import pytensor.gradient as tg @@ -86,6 +82,13 @@ Step: TypeAlias = BlockedStep | CompoundStep +ExternalNutsSampler = ["nutpie", "numpyro", "blackjax"] +NutsSampler = Literal["pymc"] | ExternalNutsSampler + +NutpieBackend = Literal["numba", "jax"] +NUTPIE_BACKENDS = get_args(NutpieBackend) +NUTPIE_DEFAULT_BACKEND = cast(NutpieBackend, "numba") + class SamplingIteratorCallback(Protocol): """Signature of the callable that may be passed to `pm.sample(callable=...)`.""" @@ -262,7 +265,7 @@ def all_continuous(vars): def _sample_external_nuts( - sampler: Literal["nutpie", "numpyro", "blackjax"], + sampler: ExternalNutsSampler, draws: int, tune: int, chains: int, @@ -280,7 +283,7 @@ def _sample_external_nuts( if nuts_sampler_kwargs is None: nuts_sampler_kwargs = {} - if sampler == "nutpie": + if sampler.startswith("nutpie"): try: import nutpie except ImportError as err: @@ -313,6 +316,23 @@ def _sample_external_nuts( model, **compile_kwargs, ) + + def extract_backend(string: str) -> NutpieBackend: + match = re.search(r"(?<=\[)[^\]]+(?=\])", string) + if match is None: + return NUTPIE_DEFAULT_BACKEND + result = cast(NutpieBackend, match.group(0)) + if result not in NUTPIE_BACKENDS: + last_option = f"{NUTPIE_BACKENDS[-1]}" + expected = ( + ", ".join([f'"{x}"' for x in NUTPIE_BACKENDS[:-1]]) + f' or "{last_option}"' + ) + raise ValueError(f'Expected one of {expected}; found "{result}"') + return result + + backend = extract_backend(sampler) + compiled_model = nutpie.compile_pymc_model(model, backend=backend) + t_start = time.time() idata = nutpie.sample( compiled_model, @@ -396,7 +416,7 @@ def sample( progressbar_theme: Theme | None = default_progress_theme, step=None, var_names: Sequence[str] | None = None, - nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", + nuts_sampler: NutsSampler = "pymc", initvals: StartDict | Sequence[StartDict | None] | None = None, init: str = "auto", jitter_max_retries: int = 10, @@ -427,7 +447,7 @@ def sample( progressbar_theme: Theme | None = default_progress_theme, step=None, var_names: Sequence[str] | None = None, - nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", + nuts_sampler: NutsSampler = "pymc", initvals: StartDict | Sequence[StartDict | None] | None = None, init: str = "auto", jitter_max_retries: int = 10, @@ -458,7 +478,7 @@ def sample( progressbar_theme: Theme | None = default_progress_theme, step=None, var_names: Sequence[str] | None = None, - nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", + nuts_sampler: NutsSampler = "pymc", initvals: StartDict | Sequence[StartDict | None] | None = None, init: str = "auto", jitter_max_retries: int = 10, diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index 3305d018f..8bf59c653 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -16,45 +16,19 @@ import numpy.testing as npt import pytest -from pymc import Data, Model, Normal, sample +from pymc import Data, Model, Normal, modelcontext, sample -@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) -def test_external_nuts_sampler(recwarn, nuts_sampler): - if nuts_sampler != "pymc": - pytest.importorskip(nuts_sampler) - - with Model(): - x = Normal("x", 100, 5) - y = Data("y", [1, 2, 3, 4]) - Data("z", [100, 190, 310, 405]) - - Normal("L", mu=x, sigma=0.1, observed=y) - - kwargs = { - "nuts_sampler": nuts_sampler, - "random_seed": 123, - "chains": 2, - "tune": 500, - "draws": 500, - "progressbar": False, - "initvals": {"x": 0.0}, - } - - idata1 = sample(**kwargs) - idata2 = sample(**kwargs) +def check_external_sampler_output(warns, idata1, idata2, sample_kwargs): + nuts_sampler = sample_kwargs["nuts_sampler"] + reference_kwargs = sample_kwargs.copy() + reference_kwargs["nuts_sampler"] = "pymc" - reference_kwargs = kwargs.copy() - reference_kwargs["nuts_sampler"] = "pymc" + with modelcontext(None): idata_reference = sample(**reference_kwargs) - warns = { - (warn.category, warn.message.args[0]) - for warn in recwarn - if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning) - } expected = set() - if nuts_sampler == "nutpie": + if nuts_sampler.startswith("nutpie"): expected.add( ( UserWarning, @@ -74,7 +48,83 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys() +@pytest.fixture +def pymc_model(): + with Model() as m: + x = Normal("x", 100, 5) + y = Data("y", [1, 2, 3, 4]) + Data("z", [100, 190, 310, 405]) + + Normal("L", mu=x, sigma=0.1, observed=y) + + return m + + +@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) +def test_external_nuts_sampler(pymc_model, recwarn, nuts_sampler): + if nuts_sampler != "pymc": + pytest.importorskip(nuts_sampler) + + sample_kwargs = dict( + nuts_sampler=nuts_sampler, + random_seed=123, + chains=2, + tune=500, + draws=500, + progressbar=False, + initvals={"x": 0.0}, + ) + + with pymc_model: + idata1 = sample(**sample_kwargs) + idata2 = sample(**sample_kwargs) + + warns = { + (warn.category, warn.message.args[0]) + for warn in recwarn + if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning) + } + + check_external_sampler_output(warns, idata1, idata2, sample_kwargs) + + +@pytest.mark.parametrize("backend", ["numba", "jax"], ids=["numba", "jax"]) +def test_numba_backend_options(pymc_model, recwarn, backend): + pytest.importorskip("nutpie") + pytest.importorskip(backend) + + sample_kwargs = dict( + nuts_sampler=f"nutpie[{backend}]", + random_seed=123, + chains=2, + tune=500, + draws=500, + progressbar=False, + initvals={"x": 0.0}, + ) + + with pymc_model: + idata1 = sample(**sample_kwargs) + idata2 = sample(**sample_kwargs) + + warns = { + (warn.category, warn.message.args[0]) + for warn in recwarn + if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning) + } + + check_external_sampler_output(warns, idata1, idata2, sample_kwargs) + + +def test_invalid_nutpie_backend_raises(pymc_model): + with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'): + with pymc_model: + sample(nuts_sampler="nutpie[invalid]", random_seed=123, chains=2, tune=500, draws=500) + + def test_step_args(): + pytest.importorskip("numpyro") + with Model() as model: a = Normal("a") idata = sample( From 6239b3d642197213af291869ffde8c22d334b116 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Tue, 10 Sep 2024 18:30:22 +0800 Subject: [PATCH 2/4] Update pm.sample docstring --- pymc/sampling/mcmc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index bcd66faaf..39d02496c 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -537,8 +537,10 @@ def sample( method will be used, if appropriate to the model. var_names : list of str, optional Names of variables to be stored in the trace. Defaults to all free variables and deterministics. - nuts_sampler : str - Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. + nuts_sampler : str, default "pymc" + Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. In addition, the compilation + backend for the chosen sampler can be set using square brackets, if available. For example, "nutpie[jax]" will + use the JAX backend for the nutpie sampler. Currently, "nutpie[jax]" and "nutpie[numba]" are allowed. This requires the chosen sampler to be installed. All samplers, except "pymc", require the full model to be continuous. blas_cores: int or "auto" or None, default = "auto" From ca04890a554029cfc7819667959a0d900131f04c Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Tue, 10 Sep 2024 18:54:44 +0800 Subject: [PATCH 3/4] Make mypy happy --- pymc/sampling/jax.py | 4 +++- pymc/sampling/mcmc.py | 9 +++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 43e1baa87..c8111a2cd 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -68,6 +68,8 @@ "sample_numpyro_nuts", ) +JaxNutsSampler = Literal["numpyro", "blackjax"] + @jax_funcify.register(Assert) @jax_funcify.register(CheckParameterValue) @@ -486,7 +488,7 @@ def sample_jax_nuts( postprocessing_chunks=None, idata_kwargs: dict | None = None, compute_convergence_checks: bool = True, - nuts_sampler: Literal["numpyro", "blackjax"], + nuts_sampler: JaxNutsSampler, ) -> az.InferenceData: """ Draw samples from the posterior using a jax NUTS method. diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 39d02496c..32b96c198 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -82,10 +82,11 @@ Step: TypeAlias = BlockedStep | CompoundStep -ExternalNutsSampler = ["nutpie", "numpyro", "blackjax"] +ExternalNutsSampler = Literal["nutpie", "numpyro", "blackjax"] NutsSampler = Literal["pymc"] | ExternalNutsSampler - NutpieBackend = Literal["numba", "jax"] + + NUTPIE_BACKENDS = get_args(NutpieBackend) NUTPIE_DEFAULT_BACKEND = cast(NutpieBackend, "numba") @@ -381,6 +382,10 @@ def extract_backend(string: str) -> NutpieBackend: elif sampler in ("numpyro", "blackjax"): import pymc.sampling.jax as pymc_jax + from pymc.sampling.jax import JaxNutsSampler + + sampler = cast(JaxNutsSampler, sampler) + idata = pymc_jax.sample_jax_nuts( draws=draws, tune=tune, From a5b3241351c51ca13a04f9874f5e3c9cc0aac660 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Tue, 10 Sep 2024 19:17:46 +0800 Subject: [PATCH 4/4] `importorskip` nutpie backend test --- tests/sampling/test_mcmc_external.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index 8bf59c653..5f9f0fb50 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -117,6 +117,7 @@ def test_numba_backend_options(pymc_model, recwarn, backend): def test_invalid_nutpie_backend_raises(pymc_model): + pytest.importorskip("nutpie") with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'): with pymc_model: sample(nuts_sampler="nutpie[invalid]", random_seed=123, chains=2, tune=500, draws=500)