Skip to content
This repository was archived by the owner on Jan 6, 2026. It is now read-only.

Commit 20dd6c1

Browse files
committed
0.0.2
1 parent e132136 commit 20dd6c1

File tree

14 files changed

+390
-249
lines changed

14 files changed

+390
-249
lines changed

examples/imgs/funnel_NMC.png

216 KB
Loading

examples/imgs/funnel_NUTS.png

238 KB
Loading

examples/imgs/funnel_centered.png

179 KB
Loading
273 KB
Loading

examples/mixture_model.py

Lines changed: 0 additions & 71 deletions
This file was deleted.

examples/neals_funnel.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
4+
#
5+
# An implementation of Neal's Funnel with NUTS and NMC kernels.
6+
# Full API-compatibility of NMC with NumPyro's MCMC ABCs.
7+
#
8+
# (cfr.: http://num.pyro.ai/en/latest/examples/funnel.html)
9+
#
10+
11+
import os
12+
13+
from jax import random
14+
import jax.numpy as jnp
15+
from jax.config import config
16+
17+
import numpyro
18+
import numpyro.distributions as dist
19+
from numpyro.infer import MCMC, NUTS
20+
21+
from nmc_numpyro import NMC
22+
23+
24+
# Usual workarounds to force-enable JIT and avoid GPU OOMs
25+
config.update("jax_disable_jit", False)
26+
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".84"
27+
28+
# Set initial random seed
29+
rng_key = random.PRNGKey(0)
30+
31+
32+
# Actual Neal's Funnel model (10-D): no modification from w.r.t. original code
33+
def model(dim=10):
34+
y = numpyro.sample("y", dist.Normal(0, 3))
35+
numpyro.sample("x", dist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2)))
36+
37+
38+
# Kernel instantiation: exact same API
39+
nuts_kernel = NUTS(model)
40+
nmc_kernel = NMC(model)
41+
42+
43+
# "Run Markov Chain, run!"
44+
nuts_mcmc_runner = MCMC(nuts_kernel, num_samples=200, num_warmup=200)
45+
nmc_mcmc_runner = MCMC(nmc_kernel, num_samples=200, num_warmup=200)
46+
47+
# Compare
48+
print("<><><><><><><><><><><><><><><><><><><><><><><><>")
49+
nuts_mcmc_runner.run(rng_key)
50+
nuts_mcmc_runner.print_summary(exclude_deterministic=True)
51+
print("<><><><><><><><><><><><><><><><><><><><><><><><>")
52+
nmc_mcmc_runner.run(rng_key)
53+
nmc_mcmc_runner.print_summary(exclude_deterministic=True)
54+
print("<><><><><><><><><><><><><><><><><><><><><><><><>")

examples/neals_funnel_showdown.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
4+
# Imports
5+
import os
6+
import matplotlib.pyplot as plt
7+
from jax import random
8+
import jax.numpy as jnp
9+
from jax.config import config
10+
import numpyro
11+
import numpyro.distributions as dist
12+
from numpyro.handlers import reparam
13+
from numpyro.infer import MCMC, NUTS, Predictive
14+
from nmc_numpyro import NMC
15+
from numpyro.infer.reparam import LocScaleReparam
16+
17+
18+
# The model itself
19+
def model(dim=10):
20+
y = numpyro.sample("y", dist.Normal(0, 3))
21+
numpyro.sample("x", dist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2)))
22+
23+
24+
# The automatically-reparameterized model (after Gorinova et al., 2020)
25+
reparam_model = reparam(model, config={"x": LocScaleReparam(0)})
26+
27+
28+
# Wrapper functions
29+
def run_inference(model, kernel_fx, rng_key):
30+
kernel = kernel_fx(model)
31+
mcmc = MCMC(
32+
kernel,
33+
# Edit directly here!
34+
num_warmup=1000,
35+
num_samples=25000,
36+
num_chains=1,
37+
progress_bar=True,
38+
)
39+
mcmc.run(rng_key)
40+
return mcmc.get_samples()
41+
42+
43+
def run_nuts_vanilla(rng_key):
44+
return run_inference(model, NUTS, rng_key)
45+
46+
47+
def run_nuts_reparam(rng_key):
48+
return run_inference(reparam_model, NUTS, rng_key)
49+
50+
51+
def run_nmc_vanilla(rng_key):
52+
return run_inference(model, NMC, rng_key)
53+
54+
55+
def run_nmc_reparam(rng_key):
56+
return run_inference(reparam_model, NMC, rng_key)
57+
58+
59+
# Main function
60+
61+
62+
def main():
63+
initial_rng_key = random.PRNGKey(0)
64+
initial_rng_key_p = random.PRNGKey(1)
65+
66+
# NUTS,vanilla
67+
nuv = run_nuts_vanilla(initial_rng_key)
68+
69+
# NUTS, reparameterized
70+
nur = run_nuts_reparam(initial_rng_key)
71+
nurp = Predictive(reparam_model, nur, return_sites=["x", "y"])(initial_rng_key_p)
72+
73+
# NMC, vanilla
74+
nmv = run_nmc_vanilla(initial_rng_key).z
75+
76+
# NMC, reparameterized
77+
nmr = run_nmc_reparam(initial_rng_key).z
78+
nmrp = Predictive(reparam_model, nmr, return_sites=["x", "y"])(initial_rng_key_p)
79+
80+
#
81+
# PLOTTING
82+
#
83+
84+
# NUTS vs Reparameterized NUTS
85+
fig1, (ax1, ax2) = plt.subplots(
86+
2, 1, sharex=True, figsize=(8, 8), constrained_layout=True
87+
)
88+
89+
ax1.plot(nuv["x"][:, 0], nuv["y"], "go", alpha=0.3)
90+
ax1.set(
91+
xlim=(-20, 20),
92+
ylim=(-9, 9),
93+
xlabel="x[0]",
94+
ylabel="y",
95+
title="Funnel samples: NUTS, centered parameterization",
96+
)
97+
98+
ax2.plot(nurp["x"][:, 0], nurp["y"], "go", alpha=0.3)
99+
ax2.set(
100+
xlim=(-20, 20),
101+
ylim=(-9, 9),
102+
xlabel="x[0]",
103+
ylabel="y",
104+
title="Funnel samples: NUTS, non-centered parameterization",
105+
)
106+
107+
plt.savefig("imgs/funnel_NUTS.png")
108+
109+
# NMC vs Reparameterized NMC
110+
fig2, (ax3, ax4) = plt.subplots(
111+
2, 1, sharex=True, figsize=(8, 8), constrained_layout=True
112+
)
113+
114+
ax3.plot(nmv["x"][:, 0], nmv["y"], "go", alpha=0.3)
115+
ax3.set(
116+
xlim=(-20, 20),
117+
ylim=(-9, 9),
118+
xlabel="x[0]",
119+
ylabel="y",
120+
title="Funnel samples: NMC, centered parameterization",
121+
)
122+
123+
ax4.plot(nmrp["x"][:, 0], nmrp["y"], "go", alpha=0.3)
124+
ax4.set(
125+
xlim=(-20, 20),
126+
ylim=(-9, 9),
127+
xlabel="x[0]",
128+
ylabel="y",
129+
title="Funnel samples: NMC, non-centered parameterization",
130+
)
131+
132+
plt.savefig("imgs/funnel_NMC.png")
133+
134+
# NUTS vs NMC
135+
fig3, (ax5, ax6) = plt.subplots(
136+
2, 1, sharex=True, figsize=(8, 8), constrained_layout=True
137+
)
138+
139+
ax5.plot(nuv["x"][:, 0], nuv["y"], "go", alpha=0.3)
140+
ax5.set(
141+
xlim=(-20, 20),
142+
ylim=(-9, 9),
143+
xlabel="x[0]",
144+
ylabel="y",
145+
title="Funnel samples: NUTS, centered parameterization",
146+
)
147+
148+
ax6.plot(nmv["x"][:, 0], nmv["y"], "go", alpha=0.3)
149+
ax6.set(
150+
xlim=(-20, 20),
151+
ylim=(-9, 9),
152+
xlabel="x[0]",
153+
ylabel="y",
154+
title="Funnel samples: NMC, centered parameterization",
155+
)
156+
157+
plt.savefig("imgs/funnel_centered.png")
158+
159+
# Reparameterized (NUTS vs NMC)
160+
fig4, (ax7, ax8) = plt.subplots(
161+
2, 1, sharex=True, figsize=(8, 8), constrained_layout=True
162+
)
163+
164+
ax7.plot(nurp["x"][:, 0], nurp["y"], "go", alpha=0.3)
165+
ax7.set(
166+
xlim=(-20, 20),
167+
ylim=(-9, 9),
168+
xlabel="x[0]",
169+
ylabel="y",
170+
title="Funnel samples: NUTS, non-centered parameterization",
171+
)
172+
173+
ax8.plot(nmrp["x"][:, 0], nmrp["y"], "go", alpha=0.3)
174+
ax8.set(
175+
xlim=(-20, 20),
176+
ylim=(-9, 9),
177+
xlabel="x[0]",
178+
ylabel="y",
179+
title="Funnel samples: NMC, non-centered parameterization",
180+
)
181+
182+
plt.savefig("imgs/funnel_noncentered.png")
183+
184+
185+
if __name__ == "__main__":
186+
187+
# Usual workarounds to force-enable JIT and avoid GPU OOMs
188+
config.update("jax_disable_jit", False)
189+
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".84"
190+
191+
# Edit here!
192+
numpyro.set_platform("cpu")
193+
numpyro.set_host_device_count(1)
194+
195+
main()

0 commit comments

Comments
 (0)