Skip to content

Commit 10166b6

Browse files
Added from_numpyro_svi converter (#95)
* added from_numpyro_svi support updated typing for docstub fixed test_inferred_dims_univariate improved load_cached_models fixed svi docstring and SVIWrapper changed numpyro svi signature fixed from_numpyro_svi docstring added custom guide support and test added test support for custom guides * changed input for form_numpyro_svi from guide to SVI instance * improved comments, imports * added tests for SVIWrapper * added NumPyro conversion guide * stripped notebook output * removed thinning, updated docstrings and refs * updated sample_dims * fix typo in io_numpyro docstring Co-authored-by: Tomás Capretto <[email protected]> * regenerate stub files after rebase * updated conversion guide * fixed issue in numpyro conversion guide with event dim labelling --------- Co-authored-by: Tomás Capretto <[email protected]>
1 parent a634b72 commit 10166b6

File tree

8 files changed

+10803
-81
lines changed

8 files changed

+10803
-81
lines changed

docs/source/how_to/ConversionGuideNumPyro.ipynb

Lines changed: 10237 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ Thus, to install all user facing optional dependencies you should use `arviz-bas
6767
tutorial/WorkingWithDataTree
6868
tutorial/label_guide
6969
how_to/ConversionGuideEmcee
70+
how_to/ConversionGuideNumPyro
7071
ArviZ in Context <https://arviz-devs.github.io/EABM/>
7172
:::
7273

external_tests/helpers.py

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,30 @@ def _numpyro_noncentered_model(J, sigma, y=None):
109109
return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
110110

111111

112+
def _numpyro_noncentered_guide(J, sigma, y=None):
113+
import jax
114+
import numpyro
115+
import numpyro.distributions as dist
116+
117+
# Variational parameters for mu
118+
mu_loc = numpyro.param("mu_loc", 0.0)
119+
mu_scale = numpyro.param("mu_scale", 1.0, constraint=dist.constraints.positive)
120+
numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))
121+
122+
# Variational parameters for tau (positive support)
123+
tau_loc = numpyro.param("tau_loc", 1.0)
124+
tau_scale = numpyro.param("tau_scale", 0.5, constraint=dist.constraints.positive)
125+
numpyro.sample("tau", dist.LogNormal(jax.numpy.log(tau_loc), tau_scale))
126+
127+
# Variational parameters for eta
128+
eta_loc = numpyro.param("eta_loc", jax.numpy.zeros(J))
129+
eta_scale = numpyro.param("eta_scale", jax.numpy.ones(J), constraint=dist.constraints.positive)
130+
with numpyro.plate("J", J):
131+
numpyro.sample("eta", dist.Normal(eta_loc, eta_scale))
132+
133+
112134
def numpyro_schools_model(data, draws, chains):
113-
"""Centered eight schools implementation in NumPyro."""
135+
"""Non-centered eight schools implementation in NumPyro."""
114136
from jax.random import PRNGKey
115137
from numpyro.infer import MCMC, NUTS
116138

@@ -133,6 +155,35 @@ def numpyro_schools_model(data, draws, chains):
133155
return mcmc
134156

135157

158+
def numpyro_schools_model_svi(data, draws, chains):
159+
"""Non-centered eight schools implementation in NumPyro."""
160+
from jax.random import PRNGKey
161+
from numpyro.infer import SVI, Trace_ELBO, init_to_sample
162+
from numpyro.infer.autoguide import AutoNormal
163+
from numpyro.optim import Adam
164+
165+
guide = AutoNormal(_numpyro_noncentered_model, init_loc_fn=init_to_sample())
166+
svi = SVI(_numpyro_noncentered_model, guide=guide, optim=Adam(0.05), loss=Trace_ELBO())
167+
svi_result = svi.run(PRNGKey(0), 4000, **data)
168+
return {"svi": svi, "svi_result": svi_result, "model_kwargs": data}
169+
170+
171+
def numpyro_schools_model_svi_custom_guide(data, draws, chains):
172+
"""Non-centered eight schools implementation in NumPyro."""
173+
from jax.random import PRNGKey
174+
from numpyro.infer import SVI, Trace_ELBO
175+
from numpyro.optim import Adam
176+
177+
guide = _numpyro_noncentered_guide
178+
svi = SVI(_numpyro_noncentered_model, guide=guide, optim=Adam(0.05), loss=Trace_ELBO())
179+
svi_result = svi.run(PRNGKey(0), 4000, **data)
180+
return {
181+
"svi": svi,
182+
"svi_result": svi_result,
183+
"model_kwargs": data,
184+
}
185+
186+
136187
def pystan_noncentered_schools(data, draws, chains):
137188
"""Non-centered eight schools implementation for pystan."""
138189
schools_code = """
@@ -188,10 +239,12 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
188239
"""Load pystan, emcee, and pyro models from pickle."""
189240
here = os.path.dirname(os.path.abspath(__file__))
190241
supported = (
191-
# ("pystan", pystan_noncentered_schools),
192-
("emcee", emcee_schools_model),
193-
# ("pyro", pyro_noncentered_schools),
194-
("numpyro", numpyro_schools_model),
242+
# ("pystan", pystan_noncentered_schools, None),
243+
("emcee", emcee_schools_model, None),
244+
# ("pyro", pyro_noncentered_schools, None),
245+
("numpyro", numpyro_schools_model, None),
246+
("numpyro", numpyro_schools_model_svi, "numpyro_svi"),
247+
("numpyro", numpyro_schools_model_svi_custom_guide, "numpyro_svi_custom_guide"),
195248
)
196249
data_directory = os.path.join(here, "saved_models")
197250
if not os.path.isdir(data_directory):
@@ -201,7 +254,8 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
201254
if isinstance(libs, str):
202255
libs = [libs]
203256

204-
for library_name, func in supported:
257+
for library_name, func, addl_model_key in supported:
258+
model_key = addl_model_key or library_name
205259
if libs is not None and library_name not in libs:
206260
continue
207261
library = library_handle(library_name)
@@ -214,7 +268,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
214268

215269
py_version = sys.version_info
216270
fname = (
217-
f"{py_version.major}.{py_version.minor}_{library.__name__}_{library.__version__}"
271+
f"{py_version.major}.{py_version.minor}_{model_key}_{library.__version__}"
218272
f"_{sys.platform}_{draws}_{chains}.pkl.gzip"
219273
)
220274

@@ -225,11 +279,11 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
225279
_log.info("Generating and caching %s", fname)
226280
cloudpickle.dump(func(eight_schools_data, draws, chains), buff)
227281
except AttributeError as err:
228-
raise AttributeError(f"Failed caching {library_name}") from err
282+
raise AttributeError(f"Failed caching {model_key}") from err
229283

230284
with gzip.open(path, "rb") as buff:
231285
_log.info("Loading %s from cache", fname)
232-
models[library.__name__] = cloudpickle.load(buff)
286+
models[model_key] = cloudpickle.load(buff)
233287

234288
return models
235289

0 commit comments

Comments
 (0)