Skip to content

Commit dae0282

Browse files
committed
added from_numpyro_svi support
updated typing for docstub fixed test_inferred_dims_univariate improved load_cached_models fixed svi docstring and SVIWrapper changed numpyro svi signature fixed from_numpyro_svi docstring added custom guide support and test added test support for custom guides
1 parent e3dccbf commit dae0282

File tree

6 files changed

+538
-77
lines changed

6 files changed

+538
-77
lines changed

external_tests/helpers.py

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,32 @@ def _numpyro_noncentered_model(J, sigma, y=None):
109109
return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
110110

111111

112+
def _numpyro_noncentered_guide(J, sigma, y=None):
113+
import jax
114+
import numpyro
115+
import numpyro.distributions as dist
116+
117+
# Variational parameters for mu
118+
mu_loc = numpyro.param("mu_loc", 0.0)
119+
mu_scale = numpyro.param("mu_scale", 1.0, constraint=dist.constraints.positive)
120+
mu = numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))
121+
122+
# Variational parameters for tau (positive support)
123+
tau_loc = numpyro.param("tau_loc", 1.0)
124+
tau_scale = numpyro.param("tau_scale", 0.5, constraint=dist.constraints.positive)
125+
tau = numpyro.sample("tau", dist.LogNormal(jax.numpy.log(tau_loc), tau_scale))
126+
127+
# Variational parameters for eta
128+
eta_loc = numpyro.param("eta_loc", jax.numpy.zeros(J))
129+
eta_scale = numpyro.param("eta_scale", jax.numpy.ones(J), constraint=dist.constraints.positive)
130+
with numpyro.plate("J", J):
131+
eta = numpyro.sample("eta", dist.Normal(eta_loc, eta_scale))
132+
133+
# theta is deterministic; obs is handled in the model
134+
theta = mu + tau * eta
135+
return theta
136+
137+
112138
def numpyro_schools_model(data, draws, chains):
113139
"""Centered eight schools implementation in NumPyro."""
114140
from jax.random import PRNGKey
@@ -133,6 +159,36 @@ def numpyro_schools_model(data, draws, chains):
133159
return mcmc
134160

135161

162+
def numpyro_schools_model_svi(data, draws, chains):
163+
"""Centered eight schools implementation in NumPyro."""
164+
from jax.random import PRNGKey
165+
from numpyro.infer import SVI, Trace_ELBO, init_to_sample
166+
from numpyro.infer.autoguide import AutoNormal
167+
from numpyro.optim import Adam
168+
169+
guide = AutoNormal(_numpyro_noncentered_model, init_loc_fn=init_to_sample())
170+
svi = SVI(_numpyro_noncentered_model, guide=guide, optim=Adam(0.05), loss=Trace_ELBO())
171+
svi_result = svi.run(PRNGKey(0), 4000, **data)
172+
return {"guide": guide, "svi_result": svi_result, "model_kwargs": data}
173+
174+
175+
def numpyro_schools_model_svi_custom_guide(data, draws, chains):
176+
"""Centered eight schools implementation in NumPyro."""
177+
from jax.random import PRNGKey
178+
from numpyro.infer import SVI, Trace_ELBO
179+
from numpyro.optim import Adam
180+
181+
guide = _numpyro_noncentered_guide
182+
svi = SVI(_numpyro_noncentered_model, guide=guide, optim=Adam(0.05), loss=Trace_ELBO())
183+
svi_result = svi.run(PRNGKey(0), 4000, **data)
184+
return {
185+
"guide": guide,
186+
"svi_result": svi_result,
187+
"model_kwargs": data,
188+
"model": _numpyro_noncentered_model,
189+
}
190+
191+
136192
def pystan_noncentered_schools(data, draws, chains):
137193
"""Non-centered eight schools implementation for pystan."""
138194
schools_code = """
@@ -188,10 +244,12 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
188244
"""Load pystan, emcee, and pyro models from pickle."""
189245
here = os.path.dirname(os.path.abspath(__file__))
190246
supported = (
191-
# ("pystan", pystan_noncentered_schools),
192-
("emcee", emcee_schools_model),
193-
# ("pyro", pyro_noncentered_schools),
194-
("numpyro", numpyro_schools_model),
247+
# ("pystan", pystan_noncentered_schools, None),
248+
("emcee", emcee_schools_model, None),
249+
# ("pyro", pyro_noncentered_schools, None),
250+
("numpyro", numpyro_schools_model, None),
251+
("numpyro", numpyro_schools_model_svi, "numpyro_svi"),
252+
("numpyro", numpyro_schools_model_svi_custom_guide, "numpyro_svi_custom_guide"),
195253
)
196254
data_directory = os.path.join(here, "saved_models")
197255
if not os.path.isdir(data_directory):
@@ -201,7 +259,8 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
201259
if isinstance(libs, str):
202260
libs = [libs]
203261

204-
for library_name, func in supported:
262+
for library_name, func, addl_model_key in supported:
263+
model_key = addl_model_key or library_name
205264
if libs is not None and library_name not in libs:
206265
continue
207266
library = library_handle(library_name)
@@ -214,7 +273,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
214273

215274
py_version = sys.version_info
216275
fname = (
217-
f"{py_version.major}.{py_version.minor}_{library.__name__}_{library.__version__}"
276+
f"{py_version.major}.{py_version.minor}_{model_key}_{library.__version__}"
218277
f"_{sys.platform}_{draws}_{chains}.pkl.gzip"
219278
)
220279

@@ -225,11 +284,11 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
225284
_log.info("Generating and caching %s", fname)
226285
cloudpickle.dump(func(eight_schools_data, draws, chains), buff)
227286
except AttributeError as err:
228-
raise AttributeError(f"Failed caching {library_name}") from err
287+
raise AttributeError(f"Failed caching {model_key}") from err
229288

230289
with gzip.open(path, "rb") as buff:
231290
_log.info("Loading %s from cache", fname)
232-
models[library.__name__] = cloudpickle.load(buff)
291+
models[model_key] = cloudpickle.load(buff)
233292

234293
return models
235294

0 commit comments

Comments
 (0)