Skip to content

Commit 0d103ba

Browse files
authored
Exploring nested sampling (#807)
* subclass Reparam * add some examples * add diagnostic * update notebook with resampling * restrict the domain * compare with mcmc results * update a working multimodal result * temp save * add example and documentation * fix typo * fix typos and add to toc * add contrib.rst * cleanup TruncatedNormal/Cauchy and Uniform * make some changes for new jaxns api, not working yet * fix some typos * fix typo at Uniform logprob * make truncated distribution tests pass * fix lint * fix failing tests * update for the new api * add icdf method to tensorflow * fix reparam batch logic * fix event dim of projectednormalreparam * add missing icdf methods * enum is working * add gaussian shell example * add tests for gaussian shells * run make license * add jaxns to the dependency * add EXPERIMENTAL warning * fix docs * temporary save * add tests for nested sampling * adjust precision * fix tests * make format * add jaxns to docs requirement * pump jaxns version * use jaxns default arguments * increase threshold of discretehmcgibbs * add a better docstring for diagnostics * increase threshold for hmcgibss in x64 * monkeypatch jaxns x64 dtype * make sure that float32 can be used * perform optimization * avoid singularity at 0 * add note about enable x64 * add print summary * address comments during pair review * revert change at tfp * restrict tfp version
1 parent 5bcea01 commit 0d103ba

File tree

11 files changed

+620
-0
lines changed

11 files changed

+620
-0
lines changed

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ flax
33
funsor
44
jax>=0.1.65
55
jaxlib>=0.1.45
6+
jaxns==0.0.7
67
optax==0.0.6
78
nbsphinx>=0.8.5
89
sphinx-gallery
131 KB
Loading

docs/source/api.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,12 @@ Inference
3737
optimizers
3838
diagnostics
3939
utilities
40+
41+
Contributed Code
42+
----------------
43+
44+
.. toctree::
45+
:glob:
46+
:maxdepth: 1
47+
48+
contrib

docs/source/contrib.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Nested Sampling
2+
~~~~~~~~~~~~~~~
3+
4+
.. autoclass:: numpyro.contrib.nested_sampling.NestedSampler
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:
8+
:member-order: bysource

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ NumPyro documentation
3737
examples/annotation
3838
examples/hmm_enum
3939
examples/capture_recapture
40+
examples/gaussian_shells
4041
tutorials/discrete_imputation
4142

4243
.. nbgallery::

examples/gaussian_shells.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright Contributors to the Pyro project.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""
5+
Example: Nested Sampling for Gaussian Shells
6+
============================================
7+
8+
This example illustrates the usage of the contrib class NestedSampler,
9+
which is a wrapper of `jaxns` library ([1]) to be used for NumPyro models.
10+
11+
Here we will replicate the Gaussian Shells demo at [2] and compare against
12+
NUTS sampler.
13+
14+
**References:**
15+
16+
1. jaxns library: https://github.com/Joshuaalbert/jaxns
17+
2. dynesty's Gaussian Shells demo:
18+
https://github.com/joshspeagle/dynesty/blob/master/demos/Examples%20--%20Gaussian%20Shells.ipynb
19+
"""
20+
21+
import argparse
22+
23+
import matplotlib.pyplot as plt
24+
25+
from jax import random
26+
import jax.numpy as jnp
27+
28+
import numpyro
29+
from numpyro.contrib.nested_sampling import NestedSampler
30+
import numpyro.distributions as dist
31+
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs
32+
33+
34+
class GaussianShell(dist.Distribution):
35+
support = dist.constraints.real_vector
36+
37+
def __init__(self, loc, radius, width):
38+
self.loc, self.radius, self.width = loc, radius, width
39+
super().__init__(batch_shape=loc.shape[:-1], event_shape=loc.shape[-1:])
40+
41+
def sample(self, key, sample_shape=()):
42+
return jnp.zeros(
43+
sample_shape + self.shape()
44+
) # a dummy sample to initialize the samplers
45+
46+
def log_prob(self, value):
47+
normalizer = (-0.5) * (jnp.log(2.0 * jnp.pi) + 2.0 * jnp.log(self.width))
48+
d = jnp.linalg.norm(value - self.loc, axis=-1)
49+
return normalizer - 0.5 * ((d - self.radius) / self.width) ** 2
50+
51+
52+
def model(center1, center2, radius, width, enum=False):
53+
z = numpyro.sample(
54+
"z", dist.Bernoulli(0.5), infer={"enumerate": "parallel"} if enum else {}
55+
)
56+
x = numpyro.sample("x", dist.Uniform(-6.0, 6.0).expand([2]).to_event(1))
57+
center = jnp.stack([center1, center2])[z]
58+
numpyro.sample("shell", GaussianShell(center, radius, width), obs=x)
59+
60+
61+
def run_inference(args, data):
62+
print("=== Performing Nested Sampling ===")
63+
ns = NestedSampler(model)
64+
ns.run(random.PRNGKey(0), **data, enum=args.enum)
65+
ns.print_summary()
66+
# samples obtained from nested sampler are weighted, so
67+
# we need to provide random key to resample from those weighted samples
68+
ns_samples = ns.get_samples(random.PRNGKey(1), num_samples=args.num_samples)
69+
70+
print("\n=== Performing MCMC Sampling ===")
71+
if args.enum:
72+
mcmc = MCMC(
73+
NUTS(model), num_warmup=args.num_warmup, num_samples=args.num_samples
74+
)
75+
else:
76+
mcmc = MCMC(
77+
DiscreteHMCGibbs(NUTS(model)),
78+
num_warmup=args.num_warmup,
79+
num_samples=args.num_samples,
80+
)
81+
mcmc.run(random.PRNGKey(2), **data)
82+
mcmc.print_summary()
83+
mcmc_samples = mcmc.get_samples()
84+
85+
return ns_samples["x"], mcmc_samples["x"]
86+
87+
88+
def main(args):
89+
data = dict(
90+
radius=2.0,
91+
width=0.1,
92+
center1=jnp.array([-3.5, 0.0]),
93+
center2=jnp.array([3.5, 0.0]),
94+
)
95+
ns_samples, mcmc_samples = run_inference(args, data)
96+
97+
# plotting
98+
fig, (ax1, ax2) = plt.subplots(
99+
2, 1, sharex=True, figsize=(8, 8), constrained_layout=True
100+
)
101+
102+
ax1.plot(mcmc_samples[:, 0], mcmc_samples[:, 1], "ro", alpha=0.2)
103+
ax1.set(
104+
xlim=(-6, 6),
105+
ylim=(-2.5, 2.5),
106+
ylabel="x[1]",
107+
title="Gaussian-shell samples using NUTS",
108+
)
109+
110+
ax2.plot(ns_samples[:, 0], ns_samples[:, 1], "ro", alpha=0.2)
111+
ax2.set(
112+
xlim=(-6, 6),
113+
ylim=(-2.5, 2.5),
114+
xlabel="x[0]",
115+
ylabel="x[1]",
116+
title="Gaussian-shell samples using Nested Sampler",
117+
)
118+
119+
plt.savefig("gaussian_shells_plot.pdf")
120+
121+
122+
if __name__ == "__main__":
123+
assert numpyro.__version__.startswith("0.6.0")
124+
parser = argparse.ArgumentParser(description="Nested sampler for Gaussian shells")
125+
parser.add_argument("-n", "--num-samples", nargs="?", default=10000, type=int)
126+
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
127+
parser.add_argument(
128+
"--enum",
129+
action="store_true",
130+
default=False,
131+
help="whether to enumerate over the discrete latent variable",
132+
)
133+
parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
134+
args = parser.parse_args()
135+
136+
numpyro.set_platform(args.device)
137+
138+
main(args)

0 commit comments

Comments
 (0)