@@ -43,7 +43,9 @@ def predictions_data(self, data, predictions_params):
4343 )
4444 return predictions
4545
46- def get_inference_data (self , data , eight_schools_params , predictions_data , predictions_params ):
46+ def get_inference_data (
47+ self , data , eight_schools_params , predictions_data , predictions_params , infer_dims = False
48+ ):
4749 posterior_samples = data .obj .get_samples ()
4850 model = data .obj .sampler .model
4951 posterior_predictive = Predictive (model , posterior_samples )(
@@ -52,6 +54,11 @@ def get_inference_data(self, data, eight_schools_params, predictions_data, predi
5254 prior = Predictive (model , num_samples = 500 )(
5355 PRNGKey (2 ), eight_schools_params ["J" ], eight_schools_params ["sigma" ]
5456 )
57+ dims = {"theta" : ["school" ], "eta" : ["school" ], "obs" : ["school" ]}
58+ pred_dims = {"theta" : ["school_pred" ], "eta" : ["school_pred" ], "obs" : ["school_pred" ]}
59+ if infer_dims :
60+ dims = pred_dims = None
61+
5562 predictions = predictions_data
5663 return from_numpyro (
5764 posterior = data .obj ,
@@ -62,8 +69,8 @@ def get_inference_data(self, data, eight_schools_params, predictions_data, predi
6269 "school" : np .arange (eight_schools_params ["J" ]),
6370 "school_pred" : np .arange (predictions_params ["J" ]),
6471 },
65- dims = { "theta" : [ "school" ], "eta" : [ "school" ], "obs" : [ "school" ]} ,
66- pred_dims = { "theta" : [ "school_pred" ], "eta" : [ "school_pred" ], "obs" : [ "school_pred" ]} ,
72+ dims = dims ,
73+ pred_dims = pred_dims ,
6774 )
6875
6976 def test_inference_data_namedtuple (self , data ):
@@ -74,6 +81,7 @@ def test_inference_data_namedtuple(self, data):
7481 data .obj .get_samples = lambda * args , ** kwargs : data_namedtuple
7582 inference_data = from_numpyro (
7683 posterior = data .obj ,
84+ dims = {}, # This mock test needs to turn off autodims like so or mock group_by_chain
7785 )
7886 assert isinstance (data .obj .get_samples (), Samples )
7987 data .obj .get_samples = _old_fn
@@ -273,3 +281,119 @@ def model():
273281 mcmc .run (PRNGKey (0 ))
274282 inference_data = from_numpyro (mcmc )
275283 assert inference_data .observed_data
284+
285+ def test_mcmc_infer_dims (self ):
286+ import numpyro
287+ import numpyro .distributions as dist
288+ from numpyro .infer import MCMC , NUTS
289+
290+ def model ():
291+ # note: group2 gets assigned dim=-1 and group1 is assigned dim=-2
292+ with numpyro .plate ("group2" , 5 ), numpyro .plate ("group1" , 10 ):
293+ _ = numpyro .sample ("param" , dist .Normal (0 , 1 ))
294+
295+ mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
296+ mcmc .run (PRNGKey (0 ))
297+ inference_data = from_numpyro (
298+ mcmc , coords = {"group1" : np .arange (10 ), "group2" : np .arange (5 )}
299+ )
300+ assert inference_data .posterior .param .dims == ("chain" , "draw" , "group1" , "group2" )
301+ assert all (dim in inference_data .posterior .param .coords for dim in ("group1" , "group2" ))
302+
303+ def test_mcmc_infer_unsorted_dims (self ):
304+ import numpyro
305+ import numpyro .distributions as dist
306+ from numpyro .infer import MCMC , NUTS
307+
308+ def model ():
309+ group1_plate = numpyro .plate ("group1" , 10 , dim = - 1 )
310+ group2_plate = numpyro .plate ("group2" , 5 , dim = - 2 )
311+
312+ # the plate contexts are entered in a different order than the pre-defined dims
313+ # we should make sure this still works because the trace has all of the info it needs
314+ with group2_plate , group1_plate :
315+ _ = numpyro .sample ("param" , dist .Normal (0 , 1 ))
316+
317+ mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
318+ mcmc .run (PRNGKey (0 ))
319+ inference_data = from_numpyro (
320+ mcmc , coords = {"group1" : np .arange (10 ), "group2" : np .arange (5 )}
321+ )
322+ assert inference_data .posterior .param .dims == ("chain" , "draw" , "group2" , "group1" )
323+ assert all (dim in inference_data .posterior .param .coords for dim in ("group1" , "group2" ))
324+
325+ def test_mcmc_infer_dims_no_coords (self ):
326+ import numpyro
327+ import numpyro .distributions as dist
328+ from numpyro .infer import MCMC , NUTS
329+
330+ def model ():
331+ with numpyro .plate ("group" , 5 ):
332+ _ = numpyro .sample ("param" , dist .Normal (0 , 1 ))
333+
334+ mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
335+ mcmc .run (PRNGKey (0 ))
336+ inference_data = from_numpyro (mcmc )
337+ assert inference_data .posterior .param .dims == ("chain" , "draw" , "group" )
338+
339+ def test_mcmc_event_dims (self ):
340+ import numpyro
341+ import numpyro .distributions as dist
342+ from numpyro .infer import MCMC , NUTS
343+
344+ def model ():
345+ _ = numpyro .sample (
346+ "gamma" , dist .ZeroSumNormal (1 , event_shape = (10 ,)), infer = {"event_dims" : ["groups" ]}
347+ )
348+
349+ mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
350+ mcmc .run (PRNGKey (0 ))
351+ inference_data = from_numpyro (mcmc , coords = {"groups" : np .arange (10 )})
352+ assert inference_data .posterior .gamma .dims == ("chain" , "draw" , "groups" )
353+ assert "groups" in inference_data .posterior .gamma .coords
354+
355+ def test_mcmc_inferred_dims_univariate (self ):
356+ import jax .numpy as jnp
357+ import numpyro
358+ import numpyro .distributions as dist
359+ from numpyro .infer import MCMC , NUTS
360+
361+ def model ():
362+ alpha = numpyro .sample ("alpha" , dist .Normal (0 , 1 ))
363+ sigma = numpyro .sample ("sigma" , dist .HalfNormal (1 ))
364+ with numpyro .plate ("obs_idx" , 3 ):
365+ # mu is plated by obs_idx, but isnt broadcasted to the plate shape
366+ # the expected behavior is that this should cause a failure
367+ mu = numpyro .deterministic ("mu" , alpha )
368+ return numpyro .sample ("y" , dist .Normal (mu , sigma ), obs = jnp .array ([- 1 , 0 , 1 ]))
369+
370+ mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
371+ mcmc .run (PRNGKey (0 ))
372+ with pytest .raises (ValueError ):
373+ from_numpyro (mcmc , coords = {"obs_idx" : np .arange (3 )})
374+
375+ def test_mcmc_extra_event_dims (self ):
376+ import numpyro
377+ import numpyro .distributions as dist
378+ from numpyro .infer import MCMC , NUTS
379+
380+ def model ():
381+ gamma = numpyro .sample ("gamma" , dist .ZeroSumNormal (1 , event_shape = (10 ,)))
382+ _ = numpyro .deterministic ("gamma_plus1" , gamma + 1 )
383+
384+ mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
385+ mcmc .run (PRNGKey (0 ))
386+ inference_data = from_numpyro (
387+ mcmc , coords = {"groups" : np .arange (10 )}, extra_event_dims = {"gamma_plus1" : ["groups" ]}
388+ )
389+ assert inference_data .posterior .gamma_plus1 .dims == ("chain" , "draw" , "groups" )
390+ assert "groups" in inference_data .posterior .gamma_plus1 .coords
391+
392+ def test_mcmc_predictions_infer_dims (
393+ self , data , eight_schools_params , predictions_data , predictions_params
394+ ):
395+ inference_data = self .get_inference_data (
396+ data , eight_schools_params , predictions_data , predictions_params , infer_dims = True
397+ )
398+ assert inference_data .predictions .obs .dims == ("chain" , "draw" , "J" )
399+ assert "J" in inference_data .predictions .obs .coords
0 commit comments