Skip to content

Commit 2cd8360

Browse files
Allow access to different nutpie backends via pip-style syntax
1 parent 5352798 commit 2cd8360

File tree

2 files changed

+114
-44
lines changed

2 files changed

+114
-44
lines changed

pymc/sampling/mcmc.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,13 @@
1717
import contextlib
1818
import logging
1919
import pickle
20+
import re
2021
import sys
2122
import time
2223
import warnings
2324

2425
from collections.abc import Callable, Iterator, Mapping, Sequence
25-
from typing import (
26-
Any,
27-
Literal,
28-
TypeAlias,
29-
overload,
30-
)
26+
from typing import Any, Literal, TypeAlias, cast, get_args, overload
3127

3228
import numpy as np
3329
import pytensor.gradient as tg
@@ -86,6 +82,13 @@
8682

8783
Step: TypeAlias = BlockedStep | CompoundStep
8884

85+
ExternalNutsSampler = ["nutpie", "numpyro", "blackjax"]
86+
NutsSampler = Literal["pymc"] | ExternalNutsSampler
87+
88+
NutpieBackend = Literal["numba", "jax"]
89+
NUTPIE_BACKENDS = get_args(NutpieBackend)
90+
NUTPIE_DEFAULT_BACKEND = cast(NutpieBackend, "numba")
91+
8992

9093
class SamplingIteratorCallback(Protocol):
9194
"""Signature of the callable that may be passed to `pm.sample(callable=...)`."""
@@ -262,7 +265,7 @@ def all_continuous(vars):
262265

263266

264267
def _sample_external_nuts(
265-
sampler: Literal["nutpie", "numpyro", "blackjax"],
268+
sampler: ExternalNutsSampler,
266269
draws: int,
267270
tune: int,
268271
chains: int,
@@ -280,7 +283,7 @@ def _sample_external_nuts(
280283
if nuts_sampler_kwargs is None:
281284
nuts_sampler_kwargs = {}
282285

283-
if sampler == "nutpie":
286+
if sampler.startswith("nutpie"):
284287
try:
285288
import nutpie
286289
except ImportError as err:
@@ -313,6 +316,23 @@ def _sample_external_nuts(
313316
model,
314317
**compile_kwargs,
315318
)
319+
320+
def extract_backend(string: str) -> NutpieBackend:
321+
match = re.search(r"(?<=\[)[^\]]+(?=\])", string)
322+
if match is None:
323+
return NUTPIE_DEFAULT_BACKEND
324+
result = cast(NutpieBackend, match.group(0))
325+
if result not in NUTPIE_BACKENDS:
326+
last_option = f"{NUTPIE_BACKENDS[-1]}"
327+
expected = (
328+
", ".join([f'"{x}"' for x in NUTPIE_BACKENDS[:-1]]) + f' or "{last_option}"'
329+
)
330+
raise ValueError(f'Expected one of {expected}; found "{result}"')
331+
return result
332+
333+
backend = extract_backend(sampler)
334+
compiled_model = nutpie.compile_pymc_model(model, backend=backend)
335+
316336
t_start = time.time()
317337
idata = nutpie.sample(
318338
compiled_model,
@@ -396,7 +416,7 @@ def sample(
396416
progressbar_theme: Theme | None = default_progress_theme,
397417
step=None,
398418
var_names: Sequence[str] | None = None,
399-
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
419+
nuts_sampler: NutsSampler = "pymc",
400420
initvals: StartDict | Sequence[StartDict | None] | None = None,
401421
init: str = "auto",
402422
jitter_max_retries: int = 10,
@@ -427,7 +447,7 @@ def sample(
427447
progressbar_theme: Theme | None = default_progress_theme,
428448
step=None,
429449
var_names: Sequence[str] | None = None,
430-
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
450+
nuts_sampler: NutsSampler = "pymc",
431451
initvals: StartDict | Sequence[StartDict | None] | None = None,
432452
init: str = "auto",
433453
jitter_max_retries: int = 10,
@@ -458,7 +478,7 @@ def sample(
458478
progressbar_theme: Theme | None = default_progress_theme,
459479
step=None,
460480
var_names: Sequence[str] | None = None,
461-
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
481+
nuts_sampler: NutsSampler = "pymc",
462482
initvals: StartDict | Sequence[StartDict | None] | None = None,
463483
init: str = "auto",
464484
jitter_max_retries: int = 10,

tests/sampling/test_mcmc_external.py

Lines changed: 83 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,45 +16,19 @@
1616
import numpy.testing as npt
1717
import pytest
1818

19-
from pymc import Data, Model, Normal, sample
19+
from pymc import Data, Model, Normal, modelcontext, sample
2020

2121

22-
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
23-
def test_external_nuts_sampler(recwarn, nuts_sampler):
24-
if nuts_sampler != "pymc":
25-
pytest.importorskip(nuts_sampler)
26-
27-
with Model():
28-
x = Normal("x", 100, 5)
29-
y = Data("y", [1, 2, 3, 4])
30-
Data("z", [100, 190, 310, 405])
31-
32-
Normal("L", mu=x, sigma=0.1, observed=y)
33-
34-
kwargs = {
35-
"nuts_sampler": nuts_sampler,
36-
"random_seed": 123,
37-
"chains": 2,
38-
"tune": 500,
39-
"draws": 500,
40-
"progressbar": False,
41-
"initvals": {"x": 0.0},
42-
}
43-
44-
idata1 = sample(**kwargs)
45-
idata2 = sample(**kwargs)
22+
def check_external_sampler_output(warns, idata1, idata2, sample_kwargs):
23+
nuts_sampler = sample_kwargs["nuts_sampler"]
24+
reference_kwargs = sample_kwargs.copy()
25+
reference_kwargs["nuts_sampler"] = "pymc"
4626

47-
reference_kwargs = kwargs.copy()
48-
reference_kwargs["nuts_sampler"] = "pymc"
27+
with modelcontext(None):
4928
idata_reference = sample(**reference_kwargs)
5029

51-
warns = {
52-
(warn.category, warn.message.args[0])
53-
for warn in recwarn
54-
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning)
55-
}
5630
expected = set()
57-
if nuts_sampler == "nutpie":
31+
if nuts_sampler.startswith("nutpie"):
5832
expected.add(
5933
(
6034
UserWarning,
@@ -74,7 +48,83 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
7448
assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys()
7549

7650

51+
@pytest.fixture
52+
def pymc_model():
53+
with Model() as m:
54+
x = Normal("x", 100, 5)
55+
y = Data("y", [1, 2, 3, 4])
56+
Data("z", [100, 190, 310, 405])
57+
58+
Normal("L", mu=x, sigma=0.1, observed=y)
59+
60+
return m
61+
62+
63+
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
64+
def test_external_nuts_sampler(pymc_model, recwarn, nuts_sampler):
65+
if nuts_sampler != "pymc":
66+
pytest.importorskip(nuts_sampler)
67+
68+
sample_kwargs = dict(
69+
nuts_sampler=nuts_sampler,
70+
random_seed=123,
71+
chains=2,
72+
tune=500,
73+
draws=500,
74+
progressbar=False,
75+
initvals={"x": 0.0},
76+
)
77+
78+
with pymc_model:
79+
idata1 = sample(**sample_kwargs)
80+
idata2 = sample(**sample_kwargs)
81+
82+
warns = {
83+
(warn.category, warn.message.args[0])
84+
for warn in recwarn
85+
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning)
86+
}
87+
88+
check_external_sampler_output(warns, idata1, idata2, sample_kwargs)
89+
90+
91+
@pytest.mark.parametrize("backend", ["numba", "jax"], ids=["numba", "jax"])
92+
def test_numba_backend_options(pymc_model, recwarn, backend):
93+
pytest.importorskip("nutpie")
94+
pytest.importorskip(backend)
95+
96+
sample_kwargs = dict(
97+
nuts_sampler=f"nutpie[{backend}]",
98+
random_seed=123,
99+
chains=2,
100+
tune=500,
101+
draws=500,
102+
progressbar=False,
103+
initvals={"x": 0.0},
104+
)
105+
106+
with pymc_model:
107+
idata1 = sample(**sample_kwargs)
108+
idata2 = sample(**sample_kwargs)
109+
110+
warns = {
111+
(warn.category, warn.message.args[0])
112+
for warn in recwarn
113+
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning)
114+
}
115+
116+
check_external_sampler_output(warns, idata1, idata2, sample_kwargs)
117+
118+
119+
def test_invalid_nutpie_backend_raises(pymc_model):
120+
with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'):
121+
with pymc_model:
122+
sample(nuts_sampler="nutpie[invalid]", random_seed=123, chains=2, tune=500, draws=500)
123+
124+
77125
def test_step_args():
126+
pytest.importorskip("numpyro")
127+
78128
with Model() as model:
79129
a = Normal("a")
80130
idata = sample(

0 commit comments

Comments
 (0)