Skip to content
Open
10,218 changes: 10,218 additions & 0 deletions docs/source/how_to/ConversionGuideNumPyro.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ Thus, to install all user facing optional dependencies you should use `arviz-bas
tutorial/WorkingWithDataTree
tutorial/label_guide
how_to/ConversionGuideEmcee
how_to/ConversionGuideNumPyro
ArviZ in Context <https://arviz-devs.github.io/EABM/>
:::

Expand Down
72 changes: 63 additions & 9 deletions external_tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,30 @@ def _numpyro_noncentered_model(J, sigma, y=None):
return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)


def _numpyro_noncentered_guide(J, sigma, y=None):
import jax
import numpyro
import numpyro.distributions as dist

# Variational parameters for mu
mu_loc = numpyro.param("mu_loc", 0.0)
mu_scale = numpyro.param("mu_scale", 1.0, constraint=dist.constraints.positive)
numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))

# Variational parameters for tau (positive support)
tau_loc = numpyro.param("tau_loc", 1.0)
tau_scale = numpyro.param("tau_scale", 0.5, constraint=dist.constraints.positive)
numpyro.sample("tau", dist.LogNormal(jax.numpy.log(tau_loc), tau_scale))

# Variational parameters for eta
eta_loc = numpyro.param("eta_loc", jax.numpy.zeros(J))
eta_scale = numpyro.param("eta_scale", jax.numpy.ones(J), constraint=dist.constraints.positive)
with numpyro.plate("J", J):
numpyro.sample("eta", dist.Normal(eta_loc, eta_scale))


def numpyro_schools_model(data, draws, chains):
"""Centered eight schools implementation in NumPyro."""
"""Non-centered eight schools implementation in NumPyro."""
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS

Expand All @@ -133,6 +155,35 @@ def numpyro_schools_model(data, draws, chains):
return mcmc


def numpyro_schools_model_svi(data, draws, chains):
"""Non-centered eight schools implementation in NumPyro."""
from jax.random import PRNGKey
from numpyro.infer import SVI, Trace_ELBO, init_to_sample
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam

guide = AutoNormal(_numpyro_noncentered_model, init_loc_fn=init_to_sample())
svi = SVI(_numpyro_noncentered_model, guide=guide, optim=Adam(0.05), loss=Trace_ELBO())
svi_result = svi.run(PRNGKey(0), 4000, **data)
return {"svi": svi, "svi_result": svi_result, "model_kwargs": data}


def numpyro_schools_model_svi_custom_guide(data, draws, chains):
"""Non-centered eight schools implementation in NumPyro."""
from jax.random import PRNGKey
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

guide = _numpyro_noncentered_guide
svi = SVI(_numpyro_noncentered_model, guide=guide, optim=Adam(0.05), loss=Trace_ELBO())
svi_result = svi.run(PRNGKey(0), 4000, **data)
return {
"svi": svi,
"svi_result": svi_result,
"model_kwargs": data,
}


def pystan_noncentered_schools(data, draws, chains):
"""Non-centered eight schools implementation for pystan."""
schools_code = """
Expand Down Expand Up @@ -188,10 +239,12 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
"""Load pystan, emcee, and pyro models from pickle."""
here = os.path.dirname(os.path.abspath(__file__))
supported = (
# ("pystan", pystan_noncentered_schools),
("emcee", emcee_schools_model),
# ("pyro", pyro_noncentered_schools),
("numpyro", numpyro_schools_model),
# ("pystan", pystan_noncentered_schools, None),
("emcee", emcee_schools_model, None),
# ("pyro", pyro_noncentered_schools, None),
("numpyro", numpyro_schools_model, None),
("numpyro", numpyro_schools_model_svi, "numpyro_svi"),
("numpyro", numpyro_schools_model_svi_custom_guide, "numpyro_svi_custom_guide"),
)
data_directory = os.path.join(here, "saved_models")
if not os.path.isdir(data_directory):
Expand All @@ -201,7 +254,8 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
if isinstance(libs, str):
libs = [libs]

for library_name, func in supported:
for library_name, func, addl_model_key in supported:
model_key = addl_model_key or library_name
if libs is not None and library_name not in libs:
continue
library = library_handle(library_name)
Expand All @@ -214,7 +268,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):

py_version = sys.version_info
fname = (
f"{py_version.major}.{py_version.minor}_{library.__name__}_{library.__version__}"
f"{py_version.major}.{py_version.minor}_{model_key}_{library.__version__}"
f"_{sys.platform}_{draws}_{chains}.pkl.gzip"
)

Expand All @@ -225,11 +279,11 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
_log.info("Generating and caching %s", fname)
cloudpickle.dump(func(eight_schools_data, draws, chains), buff)
except AttributeError as err:
raise AttributeError(f"Failed caching {library_name}") from err
raise AttributeError(f"Failed caching {model_key}") from err

with gzip.open(path, "rb") as buff:
_log.info("Loading %s from cache", fname)
models[library.__name__] = cloudpickle.load(buff)
models[model_key] = cloudpickle.load(buff)

return models

Expand Down
Loading