@@ -109,6 +109,32 @@ def _numpyro_noncentered_model(J, sigma, y=None):
109109 return numpyro .sample ("obs" , dist .Normal (theta , sigma ), obs = y )
110110
111111
112+ def _numpyro_noncentered_guide (J , sigma , y = None ):
113+ import jax
114+ import numpyro
115+ import numpyro .distributions as dist
116+
117+ # Variational parameters for mu
118+ mu_loc = numpyro .param ("mu_loc" , 0.0 )
119+ mu_scale = numpyro .param ("mu_scale" , 1.0 , constraint = dist .constraints .positive )
120+ mu = numpyro .sample ("mu" , dist .Normal (mu_loc , mu_scale ))
121+
122+ # Variational parameters for tau (positive support)
123+ tau_loc = numpyro .param ("tau_loc" , 1.0 )
124+ tau_scale = numpyro .param ("tau_scale" , 0.5 , constraint = dist .constraints .positive )
125+ tau = numpyro .sample ("tau" , dist .LogNormal (jax .numpy .log (tau_loc ), tau_scale ))
126+
127+ # Variational parameters for eta
128+ eta_loc = numpyro .param ("eta_loc" , jax .numpy .zeros (J ))
129+ eta_scale = numpyro .param ("eta_scale" , jax .numpy .ones (J ), constraint = dist .constraints .positive )
130+ with numpyro .plate ("J" , J ):
131+ eta = numpyro .sample ("eta" , dist .Normal (eta_loc , eta_scale ))
132+
133+ # theta is deterministic; obs is handled in the model
134+ theta = mu + tau * eta
135+ return theta
136+
137+
112138def numpyro_schools_model (data , draws , chains ):
113139 """Centered eight schools implementation in NumPyro."""
114140 from jax .random import PRNGKey
@@ -133,6 +159,36 @@ def numpyro_schools_model(data, draws, chains):
133159 return mcmc
134160
135161
162+ def numpyro_schools_model_svi (data , draws , chains ):
163+ """Centered eight schools implementation in NumPyro."""
164+ from jax .random import PRNGKey
165+ from numpyro .infer import SVI , Trace_ELBO , init_to_sample
166+ from numpyro .infer .autoguide import AutoNormal
167+ from numpyro .optim import Adam
168+
169+ guide = AutoNormal (_numpyro_noncentered_model , init_loc_fn = init_to_sample ())
170+ svi = SVI (_numpyro_noncentered_model , guide = guide , optim = Adam (0.05 ), loss = Trace_ELBO ())
171+ svi_result = svi .run (PRNGKey (0 ), 4000 , ** data )
172+ return {"guide" : guide , "svi_result" : svi_result , "model_kwargs" : data }
173+
174+
175+ def numpyro_schools_model_svi_custom_guide (data , draws , chains ):
176+ """Centered eight schools implementation in NumPyro."""
177+ from jax .random import PRNGKey
178+ from numpyro .infer import SVI , Trace_ELBO
179+ from numpyro .optim import Adam
180+
181+ guide = _numpyro_noncentered_guide
182+ svi = SVI (_numpyro_noncentered_model , guide = guide , optim = Adam (0.05 ), loss = Trace_ELBO ())
183+ svi_result = svi .run (PRNGKey (0 ), 4000 , ** data )
184+ return {
185+ "guide" : guide ,
186+ "svi_result" : svi_result ,
187+ "model_kwargs" : data ,
188+ "model" : _numpyro_noncentered_model ,
189+ }
190+
191+
136192def pystan_noncentered_schools (data , draws , chains ):
137193 """Non-centered eight schools implementation for pystan."""
138194 schools_code = """
@@ -188,10 +244,12 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
188244 """Load pystan, emcee, and pyro models from pickle."""
189245 here = os .path .dirname (os .path .abspath (__file__ ))
190246 supported = (
191- # ("pystan", pystan_noncentered_schools),
192- ("emcee" , emcee_schools_model ),
193- # ("pyro", pyro_noncentered_schools),
194- ("numpyro" , numpyro_schools_model ),
247+ # ("pystan", pystan_noncentered_schools, None),
248+ ("emcee" , emcee_schools_model , None ),
249+ # ("pyro", pyro_noncentered_schools, None),
250+ ("numpyro" , numpyro_schools_model , None ),
251+ ("numpyro" , numpyro_schools_model_svi , "numpyro_svi" ),
252+ ("numpyro" , numpyro_schools_model_svi_custom_guide , "numpyro_svi_custom_guide" ),
195253 )
196254 data_directory = os .path .join (here , "saved_models" )
197255 if not os .path .isdir (data_directory ):
@@ -201,7 +259,8 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
201259 if isinstance (libs , str ):
202260 libs = [libs ]
203261
204- for library_name , func in supported :
262+ for library_name , func , addl_model_key in supported :
263+ model_key = addl_model_key or library_name
205264 if libs is not None and library_name not in libs :
206265 continue
207266 library = library_handle (library_name )
@@ -214,7 +273,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
214273
215274 py_version = sys .version_info
216275 fname = (
217- f"{ py_version .major } .{ py_version .minor } _{ library . __name__ } _{ library .__version__ } "
276+ f"{ py_version .major } .{ py_version .minor } _{ model_key } _{ library .__version__ } "
218277 f"_{ sys .platform } _{ draws } _{ chains } .pkl.gzip"
219278 )
220279
@@ -225,11 +284,11 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
225284 _log .info ("Generating and caching %s" , fname )
226285 cloudpickle .dump (func (eight_schools_data , draws , chains ), buff )
227286 except AttributeError as err :
228- raise AttributeError (f"Failed caching { library_name } " ) from err
287+ raise AttributeError (f"Failed caching { model_key } " ) from err
229288
230289 with gzip .open (path , "rb" ) as buff :
231290 _log .info ("Loading %s from cache" , fname )
232- models [library . __name__ ] = cloudpickle .load (buff )
291+ models [model_key ] = cloudpickle .load (buff )
233292
234293 return models
235294
0 commit comments