-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Allow access to different nutpie backends via pip-style syntax #7498
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -17,17 +17,13 @@ | |||||||||
import contextlib | ||||||||||
import logging | ||||||||||
import pickle | ||||||||||
import re | ||||||||||
import sys | ||||||||||
import time | ||||||||||
import warnings | ||||||||||
|
||||||||||
from collections.abc import Callable, Iterator, Mapping, Sequence | ||||||||||
from typing import ( | ||||||||||
Any, | ||||||||||
Literal, | ||||||||||
TypeAlias, | ||||||||||
overload, | ||||||||||
) | ||||||||||
from typing import Any, Literal, TypeAlias, cast, get_args, overload | ||||||||||
|
||||||||||
import numpy as np | ||||||||||
import pytensor.gradient as tg | ||||||||||
|
@@ -86,6 +82,14 @@ | |||||||||
|
||||||||||
Step: TypeAlias = BlockedStep | CompoundStep | ||||||||||
|
||||||||||
ExternalNutsSampler = Literal["nutpie", "numpyro", "blackjax"] | ||||||||||
NutsSampler = Literal["pymc"] | ExternalNutsSampler | ||||||||||
NutpieBackend = Literal["numba", "jax"] | ||||||||||
|
||||||||||
|
||||||||||
NUTPIE_BACKENDS = get_args(NutpieBackend) | ||||||||||
NUTPIE_DEFAULT_BACKEND = cast(NutpieBackend, "numba") | ||||||||||
|
||||||||||
|
||||||||||
class SamplingIteratorCallback(Protocol): | ||||||||||
"""Signature of the callable that may be passed to `pm.sample(callable=...)`.""" | ||||||||||
|
@@ -262,7 +266,7 @@ | |||||||||
|
||||||||||
|
||||||||||
def _sample_external_nuts( | ||||||||||
sampler: Literal["nutpie", "numpyro", "blackjax"], | ||||||||||
sampler: ExternalNutsSampler, | ||||||||||
draws: int, | ||||||||||
tune: int, | ||||||||||
chains: int, | ||||||||||
|
@@ -280,7 +284,7 @@ | |||||||||
if nuts_sampler_kwargs is None: | ||||||||||
nuts_sampler_kwargs = {} | ||||||||||
|
||||||||||
if sampler == "nutpie": | ||||||||||
if sampler.startswith("nutpie"): | ||||||||||
try: | ||||||||||
import nutpie | ||||||||||
except ImportError as err: | ||||||||||
|
@@ -313,6 +317,23 @@ | |||||||||
model, | ||||||||||
**compile_kwargs, | ||||||||||
) | ||||||||||
|
||||||||||
def extract_backend(string: str) -> NutpieBackend: | ||||||||||
match = re.search(r"(?<=\[)[^\]]+(?=\])", string) | ||||||||||
if match is None: | ||||||||||
return NUTPIE_DEFAULT_BACKEND | ||||||||||
result = cast(NutpieBackend, match.group(0)) | ||||||||||
if result not in NUTPIE_BACKENDS: | ||||||||||
last_option = f"{NUTPIE_BACKENDS[-1]}" | ||||||||||
expected = ( | ||||||||||
", ".join([f'"{x}"' for x in NUTPIE_BACKENDS[:-1]]) + f' or "{last_option}"' | ||||||||||
) | ||||||||||
raise ValueError(f'Expected one of {expected}; found "{result}"') | ||||||||||
|
raise ValueError(f'Expected one of {expected}; found "{result}"') | |
raise ValueError( | |
'Could not parse nutpie backend. Expected one of {expected}; found "{result}"' | |
) |
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -16,45 +16,19 @@ | |||||||||||
import numpy.testing as npt | ||||||||||||
import pytest | ||||||||||||
|
||||||||||||
from pymc import Data, Model, Normal, sample | ||||||||||||
from pymc import Data, Model, Normal, modelcontext, sample | ||||||||||||
|
||||||||||||
|
||||||||||||
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) | ||||||||||||
def test_external_nuts_sampler(recwarn, nuts_sampler): | ||||||||||||
if nuts_sampler != "pymc": | ||||||||||||
pytest.importorskip(nuts_sampler) | ||||||||||||
|
||||||||||||
with Model(): | ||||||||||||
x = Normal("x", 100, 5) | ||||||||||||
y = Data("y", [1, 2, 3, 4]) | ||||||||||||
Data("z", [100, 190, 310, 405]) | ||||||||||||
|
||||||||||||
Normal("L", mu=x, sigma=0.1, observed=y) | ||||||||||||
|
||||||||||||
kwargs = { | ||||||||||||
"nuts_sampler": nuts_sampler, | ||||||||||||
"random_seed": 123, | ||||||||||||
"chains": 2, | ||||||||||||
"tune": 500, | ||||||||||||
"draws": 500, | ||||||||||||
"progressbar": False, | ||||||||||||
"initvals": {"x": 0.0}, | ||||||||||||
} | ||||||||||||
|
||||||||||||
idata1 = sample(**kwargs) | ||||||||||||
idata2 = sample(**kwargs) | ||||||||||||
def check_external_sampler_output(warns, idata1, idata2, sample_kwargs): | ||||||||||||
nuts_sampler = sample_kwargs["nuts_sampler"] | ||||||||||||
reference_kwargs = sample_kwargs.copy() | ||||||||||||
reference_kwargs["nuts_sampler"] = "pymc" | ||||||||||||
|
||||||||||||
reference_kwargs = kwargs.copy() | ||||||||||||
reference_kwargs["nuts_sampler"] = "pymc" | ||||||||||||
with modelcontext(None): | ||||||||||||
idata_reference = sample(**reference_kwargs) | ||||||||||||
|
||||||||||||
warns = { | ||||||||||||
(warn.category, warn.message.args[0]) | ||||||||||||
for warn in recwarn | ||||||||||||
if warn.category not in (FutureWarning, DeprecationWarning, RuntimeWarning) | ||||||||||||
} | ||||||||||||
expected = set() | ||||||||||||
if nuts_sampler == "nutpie": | ||||||||||||
if nuts_sampler.startswith("nutpie"): | ||||||||||||
expected.add( | ||||||||||||
( | ||||||||||||
UserWarning, | ||||||||||||
|
@@ -74,7 +48,84 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): | |||||||||||
assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys() | ||||||||||||
|
||||||||||||
|
||||||||||||
@pytest.fixture | ||||||||||||
def pymc_model(): | ||||||||||||
with Model() as m: | ||||||||||||
x = Normal("x", 100, 5) | ||||||||||||
y = Data("y", [1, 2, 3, 4]) | ||||||||||||
Data("z", [100, 190, 310, 405]) | ||||||||||||
|
||||||||||||
Normal("L", mu=x, sigma=0.1, observed=y) | ||||||||||||
|
||||||||||||
return m | ||||||||||||
|
||||||||||||
|
||||||||||||
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) | ||||||||||||
|
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) | |
@pytest.mark.parametrize( | |
"nuts_sampler", | |
["pymc", "nutpie", "nutpie[jax]", "blackjax", "numpyro"], | |
) |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'): | |
with pytest.raises( | |
ValueError, | |
match='Could not parse nutpie backend. Expected one of "numba" or "jax"; found "invalid"', | |
): |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with pytest.raises(ValueError, match="Could not parse nutpie backend. Found 'nutpie[bad'"): | |
with pymc_model: | |
sample(nuts_sampler="nutpie[bad", random_seed=123, chains=2, tune=500, draws=500) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could also get a
None
match if the string is misformatted. For example,nutpie[jax
would return aNone
match. I suggest that you test exact equality to set the default option, and if you getNone
then raise aValueError
.