@@ -109,8 +109,30 @@ def _numpyro_noncentered_model(J, sigma, y=None):
109109 return numpyro .sample ("obs" , dist .Normal (theta , sigma ), obs = y )
110110
111111
112+ def _numpyro_noncentered_guide (J , sigma , y = None ):
113+ import jax
114+ import numpyro
115+ import numpyro .distributions as dist
116+
117+ # Variational parameters for mu
118+ mu_loc = numpyro .param ("mu_loc" , 0.0 )
119+ mu_scale = numpyro .param ("mu_scale" , 1.0 , constraint = dist .constraints .positive )
120+ numpyro .sample ("mu" , dist .Normal (mu_loc , mu_scale ))
121+
122+ # Variational parameters for tau (positive support)
123+ tau_loc = numpyro .param ("tau_loc" , 1.0 )
124+ tau_scale = numpyro .param ("tau_scale" , 0.5 , constraint = dist .constraints .positive )
125+ numpyro .sample ("tau" , dist .LogNormal (jax .numpy .log (tau_loc ), tau_scale ))
126+
127+ # Variational parameters for eta
128+ eta_loc = numpyro .param ("eta_loc" , jax .numpy .zeros (J ))
129+ eta_scale = numpyro .param ("eta_scale" , jax .numpy .ones (J ), constraint = dist .constraints .positive )
130+ with numpyro .plate ("J" , J ):
131+ numpyro .sample ("eta" , dist .Normal (eta_loc , eta_scale ))
132+
133+
112134def numpyro_schools_model (data , draws , chains ):
113- """Centered eight schools implementation in NumPyro."""
135+ """Non-centered eight schools implementation in NumPyro."""
114136 from jax .random import PRNGKey
115137 from numpyro .infer import MCMC , NUTS
116138
@@ -133,6 +155,35 @@ def numpyro_schools_model(data, draws, chains):
133155 return mcmc
134156
135157
158+ def numpyro_schools_model_svi (data , draws , chains ):
159+ """Non-centered eight schools implementation in NumPyro."""
160+ from jax .random import PRNGKey
161+ from numpyro .infer import SVI , Trace_ELBO , init_to_sample
162+ from numpyro .infer .autoguide import AutoNormal
163+ from numpyro .optim import Adam
164+
165+ guide = AutoNormal (_numpyro_noncentered_model , init_loc_fn = init_to_sample ())
166+ svi = SVI (_numpyro_noncentered_model , guide = guide , optim = Adam (0.05 ), loss = Trace_ELBO ())
167+ svi_result = svi .run (PRNGKey (0 ), 4000 , ** data )
168+ return {"svi" : svi , "svi_result" : svi_result , "model_kwargs" : data }
169+
170+
171+ def numpyro_schools_model_svi_custom_guide (data , draws , chains ):
172+ """Non-centered eight schools implementation in NumPyro."""
173+ from jax .random import PRNGKey
174+ from numpyro .infer import SVI , Trace_ELBO
175+ from numpyro .optim import Adam
176+
177+ guide = _numpyro_noncentered_guide
178+ svi = SVI (_numpyro_noncentered_model , guide = guide , optim = Adam (0.05 ), loss = Trace_ELBO ())
179+ svi_result = svi .run (PRNGKey (0 ), 4000 , ** data )
180+ return {
181+ "svi" : svi ,
182+ "svi_result" : svi_result ,
183+ "model_kwargs" : data ,
184+ }
185+
186+
136187def pystan_noncentered_schools (data , draws , chains ):
137188 """Non-centered eight schools implementation for pystan."""
138189 schools_code = """
@@ -188,10 +239,12 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
188239 """Load pystan, emcee, and pyro models from pickle."""
189240 here = os .path .dirname (os .path .abspath (__file__ ))
190241 supported = (
191- # ("pystan", pystan_noncentered_schools),
192- ("emcee" , emcee_schools_model ),
193- # ("pyro", pyro_noncentered_schools),
194- ("numpyro" , numpyro_schools_model ),
242+ # ("pystan", pystan_noncentered_schools, None),
243+ ("emcee" , emcee_schools_model , None ),
244+ # ("pyro", pyro_noncentered_schools, None),
245+ ("numpyro" , numpyro_schools_model , None ),
246+ ("numpyro" , numpyro_schools_model_svi , "numpyro_svi" ),
247+ ("numpyro" , numpyro_schools_model_svi_custom_guide , "numpyro_svi_custom_guide" ),
195248 )
196249 data_directory = os .path .join (here , "saved_models" )
197250 if not os .path .isdir (data_directory ):
@@ -201,7 +254,8 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
201254 if isinstance (libs , str ):
202255 libs = [libs ]
203256
204- for library_name , func in supported :
257+ for library_name , func , addl_model_key in supported :
258+ model_key = addl_model_key or library_name
205259 if libs is not None and library_name not in libs :
206260 continue
207261 library = library_handle (library_name )
@@ -214,7 +268,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
214268
215269 py_version = sys .version_info
216270 fname = (
217- f"{ py_version .major } .{ py_version .minor } _{ library . __name__ } _{ library .__version__ } "
271+ f"{ py_version .major } .{ py_version .minor } _{ model_key } _{ library .__version__ } "
218272 f"_{ sys .platform } _{ draws } _{ chains } .pkl.gzip"
219273 )
220274
@@ -225,11 +279,11 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
225279 _log .info ("Generating and caching %s" , fname )
226280 cloudpickle .dump (func (eight_schools_data , draws , chains ), buff )
227281 except AttributeError as err :
228- raise AttributeError (f"Failed caching { library_name } " ) from err
282+ raise AttributeError (f"Failed caching { model_key } " ) from err
229283
230284 with gzip .open (path , "rb" ) as buff :
231285 _log .info ("Loading %s from cache" , fname )
232- models [library . __name__ ] = cloudpickle .load (buff )
286+ models [model_key ] = cloudpickle .load (buff )
233287
234288 return models
235289
0 commit comments