Skip to content

Commit 0660efa

Browse files
twieckiricardoV94
andcommitted
Allow external nuts sampler directly from sample
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 4ed38b0 commit 0660efa

File tree

4 files changed

+193
-10
lines changed

4 files changed

+193
-10
lines changed

.github/workflows/tests.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,14 +305,14 @@ jobs:
305305
env_vars: TEST_SUBSET
306306
name: ${{ matrix.os }} ${{ matrix.floatx }}
307307
fail_ci_if_error: false
308-
jax:
308+
external_samplers:
309309
strategy:
310310
matrix:
311311
os: [ubuntu-20.04]
312312
floatx: [float64]
313313
python-version: ["3.9"]
314314
test-subset:
315-
- pymc/tests/sampling/test_jax.py
315+
- pymc/tests/sampling/test_jax.py pymc/tests/sampling/test_mcmc_external.py
316316
fail-fast: false
317317
runs-on: ${{ matrix.os }}
318318
env:
@@ -360,9 +360,10 @@ jobs:
360360
conda activate pymc-test
361361
pip install -e .
362362
python --version
363-
- name: Install jax specific dependencies
363+
- name: Install external samplers
364364
run: |
365365
conda activate pymc-test
366+
conda install -c conda-forge nutpie
366367
pip install "numpyro>=0.8.0"
367368
pip install git+https://github.com/blackjax-devs/[email protected]
368369
- name: Run tests

pymc/sampling/jax.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,12 @@
4848
from pymc import Model, modelcontext
4949
from pymc.backends.arviz import find_constants, find_observations
5050
from pymc.logprob.utils import CheckParameterValue
51-
from pymc.util import RandomSeed, _get_seeds_per_chain, get_default_varnames
51+
from pymc.util import (
52+
RandomSeed,
53+
RandomState,
54+
_get_seeds_per_chain,
55+
get_default_varnames,
56+
)
5257

5358
__all__ = (
5459
"get_jaxified_graph",
@@ -308,7 +313,7 @@ def sample_blackjax_nuts(
308313
tune: int = 1000,
309314
chains: int = 4,
310315
target_accept: float = 0.8,
311-
random_seed: Optional[RandomSeed] = None,
316+
random_seed: Optional[RandomState] = None,
312317
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
313318
model: Optional[Model] = None,
314319
var_names: Optional[Sequence[str]] = None,
@@ -518,7 +523,7 @@ def sample_numpyro_nuts(
518523
tune: int = 1000,
519524
chains: int = 4,
520525
target_accept: float = 0.8,
521-
random_seed: Optional[RandomSeed] = None,
526+
random_seed: Optional[RandomState] = None,
522527
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
523528
model: Optional[Model] = None,
524529
var_names: Optional[Sequence[str]] = None,

pymc/sampling/mcmc.py

Lines changed: 115 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,94 @@ def all_continuous(vars):
221221
return True
222222

223223

224+
def _sample_external_nuts(
225+
sampler: str,
226+
draws: int,
227+
tune: int,
228+
chains: int,
229+
target_accept: float,
230+
random_seed: Union[RandomState, None],
231+
initvals: Union[StartDict, Sequence[Optional[StartDict]], None],
232+
model: Model,
233+
progressbar: bool,
234+
idata_kwargs: Optional[Dict],
235+
**kwargs,
236+
):
237+
warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
238+
239+
if sampler == "nutpie":
240+
try:
241+
import nutpie
242+
except ImportError as err:
243+
raise ImportError(
244+
"nutpie not found. Install it with conda install -c conda-forge nutpie"
245+
) from err
246+
247+
if initvals is not None:
248+
warnings.warn(
249+
"`initvals` are currently not passed to nutpie sampler. "
250+
"Use `init_mean` kwarg following nutpie specification instead.",
251+
UserWarning,
252+
)
253+
254+
if idata_kwargs is not None:
255+
warnings.warn(
256+
"`idata_kwargs` are currently ignored by the nutpie sampler",
257+
UserWarning,
258+
)
259+
260+
compiled_model = nutpie.compile_pymc_model(model)
261+
idata = nutpie.sample(
262+
compiled_model,
263+
draws=draws,
264+
tune=tune,
265+
chains=chains,
266+
target_accept=target_accept,
267+
seed=_get_seeds_per_chain(random_seed, 1)[0],
268+
progress_bar=progressbar,
269+
**kwargs,
270+
)
271+
return idata
272+
273+
elif sampler == "numpyro":
274+
import pymc.sampling.jax as pymc_jax
275+
276+
idata = pymc_jax.sample_numpyro_nuts(
277+
draws=draws,
278+
tune=tune,
279+
chains=chains,
280+
target_accept=target_accept,
281+
random_seed=random_seed,
282+
initvals=initvals,
283+
model=model,
284+
progressbar=progressbar,
285+
idata_kwargs=idata_kwargs,
286+
**kwargs,
287+
)
288+
return idata
289+
290+
elif sampler == "blackjax":
291+
import pymc.sampling.jax as pymc_jax
292+
293+
idata = pymc_jax.sample_blackjax_nuts(
294+
draws=draws,
295+
tune=tune,
296+
chains=chains,
297+
target_accept=target_accept,
298+
random_seed=random_seed,
299+
initvals=initvals,
300+
model=model,
301+
idata_kwargs=idata_kwargs,
302+
**kwargs,
303+
)
304+
return idata
305+
306+
else:
307+
raise ValueError(
308+
f"Sampler {sampler} not found. Choose one of ['nutpie', 'numpyro', 'blackjax', 'pymc']."
309+
)
310+
311+
224312
def sample(
225313
draws: int = 1000,
226314
step=None,
@@ -239,6 +327,7 @@ def sample(
239327
callback=None,
240328
jitter_max_retries: int = 10,
241329
*,
330+
nuts_sampler: str = "pymc",
242331
return_inferencedata: bool = True,
243332
keep_warning_stat: bool = False,
244333
idata_kwargs: dict = None,
@@ -257,6 +346,7 @@ def sample(
257346
init : str
258347
Initialization method to use for auto-assigned NUTS samplers. See `pm.init_nuts` for a list
259348
of all options. This argument is ignored when manually passing the NUTS step method.
349+
Only applicable to the pymc nuts sampler.
260350
step : function or iterable of functions
261351
A step function or collection of functions. If there are variables without step methods,
262352
step methods for those variables will be assigned automatically. By default the NUTS step
@@ -306,6 +396,10 @@ def sample(
306396
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform
307397
jitter that yields a finite probability. This applies to ``jitter+adapt_diag`` and
308398
``jitter+adapt_full`` init methods.
399+
nuts_sampler : str
400+
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
401+
This requires the chosen sampler to be installed.
402+
All samplers, except "pymc", require the full model to be continuous.
309403
return_inferencedata : bool
310404
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a
311405
`MultiTrace` (False). Defaults to `True`.
@@ -401,7 +495,7 @@ def sample(
401495
if "nuts" in kwargs:
402496
kwargs["nuts"]["target_accept"] = kwargs.pop("target_accept")
403497
else:
404-
kwargs = {"nuts": {"target_accept": kwargs.pop("target_accept")}}
498+
kwargs["nuts"] = {"target_accept": kwargs.pop("target_accept")}
405499
if isinstance(trace, list):
406500
raise DeprecationWarning(
407501
"We have removed support for partial traces because it simplified things."
@@ -441,8 +535,6 @@ def sample(
441535
msg = "Only %s samples in chain." % draws
442536
_log.warning(msg)
443537

444-
draws += tune
445-
446538
auto_nuts_init = True
447539
if step is not None:
448540
if isinstance(step, CompoundStep):
@@ -455,6 +547,25 @@ def sample(
455547
initial_points = None
456548
step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
457549

550+
if nuts_sampler != "pymc":
551+
if not isinstance(step, NUTS):
552+
raise ValueError(
553+
"Model can not be sampled with NUTS alone. Your model is probably not continuous."
554+
)
555+
return _sample_external_nuts(
556+
sampler=nuts_sampler,
557+
draws=draws,
558+
tune=tune,
559+
chains=chains,
560+
target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
561+
random_seed=random_seed,
562+
initvals=initvals,
563+
model=model,
564+
progressbar=progressbar,
565+
idata_kwargs=idata_kwargs,
566+
**kwargs,
567+
)
568+
458569
if isinstance(step, list):
459570
step = CompoundStep(step)
460571
elif isinstance(step, NUTS) and auto_nuts_init:
@@ -503,7 +614,7 @@ def sample(
503614
)
504615

505616
sample_args = {
506-
"draws": draws,
617+
"draws": draws + tune, # FIXME: Why is tune added to draws?
507618
"step": step,
508619
"start": initial_points,
509620
"traces": traces,
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pytest
17+
18+
from pymc import Model, Normal, sample
19+
20+
pytest.importorskip("nutpie")
21+
pytest.importorskip("blackjax")
22+
pytest.importorskip("numpyro")
23+
24+
# turns all warnings into errors for this module
25+
pytestmark = pytest.mark.filterwarnings("error")
26+
27+
28+
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
29+
def test_external_nuts_sampler(recwarn, nuts_sampler):
30+
with Model():
31+
Normal("x")
32+
33+
kwargs = dict(
34+
nuts_sampler=nuts_sampler,
35+
random_seed=123,
36+
chains=2,
37+
tune=500,
38+
draws=500,
39+
progressbar=False,
40+
initvals={"x": 0.0},
41+
)
42+
43+
idata1 = sample(**kwargs)
44+
idata2 = sample(**kwargs)
45+
46+
warns = {
47+
(warn.category, warn.message.args[0])
48+
for warn in recwarn
49+
if warn.category is not FutureWarning
50+
}
51+
expected = set()
52+
if nuts_sampler != "pymc":
53+
expected.add((UserWarning, "Use of external NUTS sampler is still experimental"))
54+
if nuts_sampler == "nutpie":
55+
expected.add(
56+
(
57+
UserWarning,
58+
"`initvals` are currently not passed to nutpie sampler. "
59+
"Use `init_mean` kwarg following nutpie specification instead.",
60+
)
61+
)
62+
assert warns == expected
63+
64+
assert idata1.posterior.chain.size == 2
65+
assert idata1.posterior.draw.size == 500
66+
np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x)

0 commit comments

Comments
 (0)