@@ -649,25 +649,50 @@ def guide():
649649 svi .update (svi_state )
650650
651651
652+ @pytest .mark .xfail (reason = "missing pattern in Funsor" )
653+ def test_collapse_diag_normal_plate_normal ():
654+ d = 3
655+ data = np .ones ((5 , d ))
656+
657+ def model ():
658+ x = numpyro .sample ("x" , dist .Normal (0 , 1 ))
659+ with handlers .collapse ():
660+ with handlers .plate ("data" , len (data )):
661+ y = numpyro .sample ("y" , dist .Normal (x , 1. ).expand ([d ]).to_event (1 ))
662+ numpyro .sample ("z" , dist .Normal (y , 1. ).to_event (1 ), obs = data )
663+
664+ def guide ():
665+ loc = numpyro .param ("loc" , 0. )
666+ scale = numpyro .param ("scale" , 1. , constraint = constraints .positive )
667+ numpyro .sample ("x" , dist .Normal (loc , scale ))
668+
669+ svi = SVI (model , guide , numpyro .optim .Adam (1 ), Trace_ELBO ())
670+ svi_state = svi .init (random .PRNGKey (0 ))
671+ svi .update (svi_state )
672+
673+
674+ @pytest .mark .xfail (reason = "missing pattern in Funsor" )
652675def test_collapse_normal_mvn_mvn ():
653676 T , d , S = 5 , 2 , 3
654677 data = jnp .ones ((T , S ))
655678
656679 def model ():
657- x = numpyro .sample ("x" , dist .Exponential ( 1 ))
680+ x = numpyro .sample ("x" , dist .Normal ( 0 , 1 ))
658681 with handlers .collapse ():
659682 with numpyro .plate ("d" , d , dim = - 1 ):
660- beta0 = numpyro .sample ("beta0" , dist .Normal (0. , 1. ).expand ([d , S ]).to_event (1 ))
683+ beta0 = numpyro .sample ("beta0" , dist .Normal (x , 1. ).expand ([d , S ]).to_event (1 ))
661684 beta = numpyro .sample (
662685 "beta" , dist .MultivariateNormal (beta0 , scale_tril = jnp .eye (S )))
663686
687+ # this fails because beta shape is (3,) while it should be (2, 3)
664688 mean = jnp .ones ((T , d )) @ beta
665689 with numpyro .plate ("data" , T , dim = - 1 ):
666690 numpyro .sample ("obs" , dist .MultivariateNormal (mean , scale_tril = jnp .eye (S )), obs = data )
667691
668692 def guide ():
669- rate = numpyro .param ("rate" , 1. , constraint = constraints .positive )
670- numpyro .sample ("x" , dist .Exponential (rate ))
693+ loc = numpyro .param ("loc" , 0. )
694+ scale = numpyro .param ("scale" , 1. , constraint = constraints .positive )
695+ numpyro .sample ("x" , dist .Normal (loc , scale ))
671696
672697 svi = SVI (model , guide , numpyro .optim .Adam (1 ), Trace_ELBO ())
673698 svi_state = svi .init (random .PRNGKey (0 ))
0 commit comments