Skip to content

Commit d18bcfb

Browse files
Fixing nested sampling (#1871)
* fixing nested sampling * first attempt at fixing CI * import jaxns only if double precision is enabled * ignore test_nested_sampling in ci * fix doctest * implement feedbacks
1 parent aa860f7 commit d18bcfb

File tree

7 files changed

+31
-13
lines changed

7 files changed

+31
-13
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ jobs:
104104
- name: Test with pytest
105105
run: |
106106
pytest -vs --durations=20 test/infer/test_mcmc.py
107-
pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py
107+
pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py
108108
pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py
109109
- name: Test x64
110110
run: |
@@ -118,6 +118,9 @@ jobs:
118118
- name: Test custom prng
119119
run: |
120120
JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py
121+
- name: Test nested sampling
122+
run: |
123+
JAX_ENABLE_X64=1 pytest -vs test/contrib/test_nested_sampling.py
121124
122125
123126
examples:

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ funsor
44
ipython
55
jax
66
jaxlib
7-
jaxns==2.4.8
7+
jaxns==2.6.3
88
Jinja2
99
matplotlib
1010
multipledispatch

docs/source/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434

3535
hmc(None, None)
3636

37+
autodoc_mock_imports = ["jaxns"]
38+
3739
# -- Project information -----------------------------------------------------
3840

3941
project = "NumPyro"

examples/gaussian_shells.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def main(args):
121121

122122
if __name__ == "__main__":
123123
assert numpyro.__version__.startswith("0.15.3")
124+
124125
parser = argparse.ArgumentParser(description="Nested sampler for Gaussian shells")
125126
parser.add_argument("-n", "--num-samples", nargs="?", default=10000, type=int)
126127
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
@@ -133,6 +134,7 @@ def main(args):
133134
parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
134135
args = parser.parse_args()
135136

137+
numpyro.enable_x64()
136138
numpyro.set_platform(args.device)
137139

138140
main(args)

numpyro/contrib/nested_sampling.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33

44
from functools import singledispatch
55

6+
import jax
67
from jax import random, tree
78
import jax.numpy as jnp
89

910
try:
11+
import jaxns # noqa: F401
1012
from jaxns import (
11-
DefaultNestedSampler,
1213
Model,
1314
Prior,
1415
TerminationCondition,
@@ -17,11 +18,13 @@
1718
resample,
1819
summary,
1920
)
21+
from jaxns.public import DefaultNestedSampler
2022
from jaxns.utils import NestedSamplerResults
2123

2224
except ImportError as e:
2325
raise ImportError(
24-
"To use this module, please install `jaxns` package. It can be"
26+
f"{e} \n "
27+
f"To use this module, please install `jaxns>2.5` package. It can be"
2528
" installed with `pip install jaxns` with python>=3.8"
2629
) from e
2730

@@ -142,9 +145,7 @@ class NestedSampler:
142145
:param dict termination_kwargs: keyword arguments to terminate the sampler. Please
143146
refer to the upstream :meth:`jaxns.NestedSampler.__call__` method.
144147
145-
**Example**
146-
147-
.. doctest::
148+
Example::
148149
149150
>>> from jax import random
150151
>>> import jax.numpy as jnp
@@ -258,7 +259,7 @@ def prior_model():
258259

259260
default_constructor_kwargs = dict(
260261
num_live_points=model.U_ndims * 25,
261-
num_parallel_workers=1,
262+
devices=jax.devices(),
262263
max_samples=1e4,
263264
)
264265
default_termination_kwargs = dict(dlogZ=1e-4)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
"flax",
6262
"funsor>=0.4.1",
6363
"graphviz",
64-
"jaxns==2.4.8",
64+
"jaxns==2.6.3",
6565
"matplotlib",
6666
"optax>=0.0.6",
6767
"pylab-sdk", # jaxns dependency

test/contrib/test_nested_sampling.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Copyright Contributors to the Pyro project.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import os
5+
46
import numpy as np
57
from numpy.testing import assert_allclose
68
import pytest
@@ -11,15 +13,23 @@
1113
import numpyro
1214

1315
try:
14-
from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam
16+
if os.environ.get("JAX_ENABLE_X64"):
17+
from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam
18+
1519
except ImportError:
1620
pytestmark = pytest.mark.skip(reason="jaxns is not installed")
21+
1722
import numpyro.distributions as dist
1823
from numpyro.distributions.transforms import AffineTransform, ExpTransform
1924

20-
pytestmark = pytest.mark.filterwarnings(
21-
"ignore:jax.tree_.+ is deprecated:FutureWarning"
22-
)
25+
pytestmark = [
26+
pytest.mark.filterwarnings("ignore:jax.tree_.+ is deprecated:FutureWarning"),
27+
pytest.mark.filterwarnings("ignore:JAX x64"),
28+
pytest.mark.skipif(
29+
not os.environ.get("JAX_ENABLE_X64"),
30+
reason="test suite for jaxns requires double precision",
31+
),
32+
]
2333

2434

2535
# Test helper to extract a few central moments from samples.

0 commit comments

Comments
 (0)