@@ -46,7 +46,9 @@ def predictions_data(self, data, predictions_params):
4646 )
4747 return predictions
4848
49- def get_inference_data (self , data , eight_schools_params , predictions_data , predictions_params ):
49+ def get_inference_data (
50+ self , data , eight_schools_params , predictions_data , predictions_params , infer_dims = False
51+ ):
5052 posterior_samples = data .obj .get_samples ()
5153 model = data .obj .sampler .model
5254 posterior_predictive = Predictive (model , posterior_samples )(
@@ -55,6 +57,12 @@ def get_inference_data(self, data, eight_schools_params, predictions_data, predi
5557 prior = Predictive (model , num_samples = 500 )(
5658 PRNGKey (2 ), eight_schools_params ["J" ], eight_schools_params ["sigma" ]
5759 )
60+ dims = {"theta" : ["school" ], "eta" : ["school" ], "obs" : ["school" ]}
61+ pred_dims = {"theta" : ["school_pred" ], "eta" : ["school_pred" ], "obs" : ["school_pred" ]}
62+ if infer_dims :
63+ dims = None
64+ pred_dims = None
65+
5866 predictions = predictions_data
5967 return from_numpyro (
6068 posterior = data .obj ,
@@ -65,8 +73,8 @@ def get_inference_data(self, data, eight_schools_params, predictions_data, predi
6573 "school" : np .arange (eight_schools_params ["J" ]),
6674 "school_pred" : np .arange (predictions_params ["J" ]),
6775 },
68- dims = { "theta" : [ "school" ], "eta" : [ "school" ], "obs" : [ "school" ]} ,
69- pred_dims = { "theta" : [ "school_pred" ], "eta" : [ "school_pred" ], "obs" : [ "school_pred" ]} ,
76+ dims = dims ,
77+ pred_dims = pred_dims ,
7078 )
7179
7280 def test_inference_data_namedtuple (self , data ):
@@ -77,6 +85,7 @@ def test_inference_data_namedtuple(self, data):
7785 data .obj .get_samples = lambda * args , ** kwargs : data_namedtuple
7886 inference_data = from_numpyro (
7987 posterior = data .obj ,
88+ dims = {}, # This mock test needs to turn off autodims like so or mock group_by_chain
8089 )
8190 assert isinstance (data .obj .get_samples (), Samples )
8291 data .obj .get_samples = _old_fn
@@ -282,3 +291,121 @@ def model():
282291 mcmc .run (PRNGKey (0 ))
283292 inference_data = from_numpyro (mcmc )
284293 assert inference_data .observed_data
294+
295+ def test_mcmc_infer_dims (self ):
296+ import numpyro
297+ import numpyro .distributions as dist
298+ from numpyro .infer import MCMC , NUTS
299+
300+ def model ():
301+ # note: group2 gets assigned dim=-1 and group1 is assigned dim=-2
302+ with numpyro .plate ("group2" , 5 ), numpyro .plate ("group1" , 10 ):
303+ _ = numpyro .sample ("param" , dist .Normal (0 , 1 ))
304+
305+ mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
306+ mcmc .run (PRNGKey (0 ))
307+ inference_data = from_numpyro (
308+ mcmc , coords = {"group1" : np .arange (10 ), "group2" : np .arange (5 )}
309+ )
310+ assert inference_data .posterior .param .dims == ("chain" , "draw" , "group1" , "group2" )
311+ assert all (dim in inference_data .posterior .param .coords for dim in ("group1" , "group2" ))
312+
313+ def test_mcmc_infer_unsorted_dims (self ):
314+ import numpyro
315+ import numpyro .distributions as dist
316+ from numpyro .infer import MCMC , NUTS
317+
318+ def model ():
319+ group1_plate = numpyro .plate ("group1" , 10 , dim = - 1 )
320+ group2_plate = numpyro .plate ("group2" , 5 , dim = - 2 )
321+
322+ # the plate contexts are entered in a different order than the pre-defined dims
323+ # we should make sure this still works because the trace has all of the info it needs
324+ with group2_plate , group1_plate :
325+ _ = numpyro .sample ("param" , dist .Normal (0 , 1 ))
326+
327+ mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
328+ mcmc .run (PRNGKey (0 ))
329+ inference_data = from_numpyro (
330+ mcmc , coords = {"group1" : np .arange (10 ), "group2" : np .arange (5 )}
331+ )
332+ assert inference_data .posterior .param .dims == ("chain" , "draw" , "group2" , "group1" )
333+ assert all (dim in inference_data .posterior .param .coords for dim in ("group1" , "group2" ))
334+
335+ def test_mcmc_infer_dims_no_coords (self ):
336+ import numpyro
337+ import numpyro .distributions as dist
338+ from numpyro .infer import MCMC , NUTS
339+
340+ def model ():
341+ with numpyro .plate ("group" , 5 ):
342+ _ = numpyro .sample ("param" , dist .Normal (0 , 1 ))
343+
344+ mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
345+ mcmc .run (PRNGKey (0 ))
346+ inference_data = from_numpyro (mcmc )
347+ assert inference_data .posterior .param .dims == ("chain" , "draw" , "group" )
348+
349+ def test_mcmc_event_dims (self ):
350+ import numpyro
351+ import numpyro .distributions as dist
352+ from numpyro .infer import MCMC , NUTS
353+
354+ def model ():
355+ _ = numpyro .sample (
356+ "gamma" , dist .ZeroSumNormal (1 , event_shape = (10 ,)), infer = {"event_dims" : ["groups" ]}
357+ )
358+
359+ mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
360+ mcmc .run (PRNGKey (0 ))
361+ inference_data = from_numpyro (mcmc , coords = {"groups" : np .arange (10 )})
362+ assert inference_data .posterior .gamma .dims == ("chain" , "draw" , "groups" )
363+ assert "groups" in inference_data .posterior .gamma .coords
364+
365+ @pytest .mark .xfail
366+ def test_mcmc_inferred_dims_univariate (self ):
367+ import numpyro
368+ import numpyro .distributions as dist
369+ from numpyro .infer import MCMC , NUTS
370+ import jax .numpy as jnp
371+
372+ def model ():
373+ alpha = numpyro .sample ("alpha" , dist .Normal (0 , 1 ))
374+ sigma = numpyro .sample ("sigma" , dist .HalfNormal (1 ))
375+ with numpyro .plate ("obs_idx" , 3 ):
376+ # mu is plated by obs_idx, but isnt broadcasted to the plate shape
377+ # the expected behavior is that this should cause a failure
378+ mu = numpyro .deterministic ("mu" , alpha )
379+ return numpyro .sample ("y" , dist .Normal (mu , sigma ), obs = jnp .array ([- 1 , 0 , 1 ]))
380+
381+ mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
382+ mcmc .run (PRNGKey (0 ))
383+ inference_data = from_numpyro (mcmc , coords = {"obs_idx" : np .arange (3 )})
384+ assert inference_data .posterior .mu .dims == ("chain" , "draw" , "obs_idx" )
385+ assert "obs_idx" in inference_data .posterior .mu .coords
386+
387+ def test_mcmc_extra_event_dims (self ):
388+ import numpyro
389+ import numpyro .distributions as dist
390+ from numpyro .infer import MCMC , NUTS
391+
392+ def model ():
393+ gamma = numpyro .sample ("gamma" , dist .ZeroSumNormal (1 , event_shape = (10 ,)))
394+ _ = numpyro .deterministic ("gamma_plus1" , gamma + 1 )
395+
396+ mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
397+ mcmc .run (PRNGKey (0 ))
398+ inference_data = from_numpyro (
399+ mcmc , coords = {"groups" : np .arange (10 )}, extra_event_dims = {"gamma_plus1" : ["groups" ]}
400+ )
401+ assert inference_data .posterior .gamma_plus1 .dims == ("chain" , "draw" , "groups" )
402+ assert "groups" in inference_data .posterior .gamma_plus1 .coords
403+
404+ def test_mcmc_predictions_infer_dims (
405+ self , data , eight_schools_params , predictions_data , predictions_params
406+ ):
407+ inference_data = self .get_inference_data (
408+ data , eight_schools_params , predictions_data , predictions_params , infer_dims = True
409+ )
410+ assert inference_data .predictions .obs .dims == ("chain" , "draw" , "J" )
411+ assert "J" in inference_data .predictions .obs .coords
0 commit comments