6969)
7070def test_beta_bernoulli (auto_class ):
7171 data = jnp .array ([[1.0 ] * 8 + [0.0 ] * 2 , [1.0 ] * 4 + [0.0 ] * 6 ]).T
72+ N = len (data )
7273
7374 def model (data ):
74- f = numpyro .sample ("beta" , dist .Beta (jnp .ones (2 ), jnp .ones (2 )))
75- numpyro .sample ("obs" , dist .Bernoulli (f ), obs = data )
75+ f = numpyro .sample ("beta" , dist .Beta (jnp .ones (2 ), jnp .ones (2 )).to_event ())
76+ with numpyro .plate ("N" , N ):
77+ numpyro .sample ("obs" , dist .Bernoulli (f ).to_event (1 ), obs = data )
7678
7779 adam = optim .Adam (0.01 )
7880 if auto_class == AutoDAIS :
@@ -104,12 +106,12 @@ def body_fn(i, val):
104106 # Predictive can be instantiated from posterior samples...
105107 predictive = Predictive (model , posterior_samples = posterior_samples )
106108 predictive_samples = predictive (random .PRNGKey (1 ), None )
107- assert predictive_samples ["obs" ].shape == (1000 , 2 )
109+ assert predictive_samples ["obs" ].shape == (1000 , N , 2 )
108110
109111 # ... or from the guide + params
110112 predictive = Predictive (model , guide = guide , params = params , num_samples = 1000 )
111113 predictive_samples = predictive (random .PRNGKey (1 ), None )
112- assert predictive_samples ["obs" ].shape == (1000 , 2 )
114+ assert predictive_samples ["obs" ].shape == (1000 , N , 2 )
113115
114116
115117@pytest .mark .parametrize (
@@ -135,9 +137,10 @@ def test_logistic_regression(auto_class, Elbo):
135137 labels = dist .Bernoulli (logits = logits ).sample (random .PRNGKey (1 ))
136138
137139 def model (data , labels ):
138- coefs = numpyro .sample ("coefs" , dist .Normal (jnp . zeros ( dim ), jnp . ones ( dim ) ))
140+ coefs = numpyro .sample ("coefs" , dist .Normal (0 , 1 ). expand ([ dim ]). to_event ( ))
139141 logits = numpyro .deterministic ("logits" , jnp .sum (coefs * data , axis = - 1 ))
140- return numpyro .sample ("obs" , dist .Bernoulli (logits = logits ), obs = labels )
142+ with numpyro .plate ("N" , len (data )):
143+ return numpyro .sample ("obs" , dist .Bernoulli (logits = logits ), obs = labels )
141144
142145 adam = optim .Adam (0.01 )
143146 rng_key_init = random .PRNGKey (1 )
@@ -242,7 +245,8 @@ def model(data):
242245 dist .Uniform (0 , 1 ), transforms .AffineTransform (0 , alpha )
243246 ),
244247 )
245- numpyro .sample ("obs" , dist .Normal (loc , 0.1 ), obs = data )
248+ with numpyro .plate ("N" , len (data )):
249+ numpyro .sample ("obs" , dist .Normal (loc , 0.1 ), obs = data )
246250
247251 adam = optim .Adam (0.01 )
248252 rng_key_init = random .PRNGKey (1 )
@@ -317,12 +321,14 @@ def actual_model(data):
317321 dist .Uniform (0 , 1 ), transforms .AffineTransform (0 , alpha )
318322 ),
319323 )
320- numpyro .sample ("obs" , dist .Normal (loc , 0.1 ), obs = data )
324+ with numpyro .plate ("N" , len (data )):
325+ numpyro .sample ("obs" , dist .Normal (loc , 0.1 ), obs = data )
321326
322327 def expected_model (data ):
323328 alpha = numpyro .sample ("alpha" , dist .Uniform (0 , 1 ))
324329 loc = numpyro .sample ("loc" , dist .Uniform (0 , 1 )) * alpha
325- numpyro .sample ("obs" , dist .Normal (loc , 0.1 ), obs = data )
330+ with numpyro .plate ("N" , len (data )):
331+ numpyro .sample ("obs" , dist .Normal (loc , 0.1 ), obs = data )
326332
327333 adam = optim .Adam (0.01 )
328334 rng_key_init = random .PRNGKey (1 )
@@ -355,9 +361,10 @@ def expected_model(data):
355361def test_laplace_approximation_warning ():
356362 def model (x , y ):
357363 a = numpyro .sample ("a" , dist .Normal (0 , 10 ))
358- b = numpyro .sample ("b" , dist .Normal (0 , 10 ), sample_shape = ( 3 , ))
364+ b = numpyro .sample ("b" , dist .Normal (0 , 10 ). expand ([ 3 ]). to_event ( ))
359365 mu = a + b [0 ] * x + b [1 ] * x ** 2 + b [2 ] * x ** 3
360- numpyro .sample ("y" , dist .Normal (mu , 0.001 ), obs = y )
366+ with numpyro .plate ("N" , len (x )):
367+ numpyro .sample ("y" , dist .Normal (mu , 0.001 ), obs = y )
361368
362369 x = random .normal (random .PRNGKey (0 ), (3 ,))
363370 y = 1 + 2 * x + 3 * x ** 2 + 4 * x ** 3
@@ -375,7 +382,8 @@ def model(x, y):
375382 a = numpyro .sample ("a" , dist .Normal (0 , 10 ))
376383 b = numpyro .sample ("b" , dist .Normal (0 , 10 ))
377384 mu = a + b * x
378- numpyro .sample ("y" , dist .Normal (mu , 1 ), obs = y )
385+ with numpyro .plate ("N" , len (x )):
386+ numpyro .sample ("y" , dist .Normal (mu , 1 ), obs = y )
379387
380388 x = random .normal (random .PRNGKey (0 ), (100 ,))
381389 y = 1 + 2 * x
@@ -401,7 +409,8 @@ def model(y):
401409 "sigma" , dist .ImproperUniform (dist .constraints .positive , (), ())
402410 )
403411 mu = numpyro .deterministic ("mu" , lambda1 + lambda2 )
404- numpyro .sample ("y" , dist .Normal (mu , sigma ), obs = y )
412+ with numpyro .plate ("N" , len (y )):
413+ numpyro .sample ("y" , dist .Normal (mu , sigma ), obs = y )
405414
406415 guide = AutoDiagonalNormal (model )
407416 svi = SVI (model , guide , optim .Adam (0.003 ), Trace_ELBO (), y = y )
@@ -417,7 +426,8 @@ def model(x, y):
417426 nn = numpyro .module ("nn" , Dense (1 ), (10 ,))
418427 mu = nn (x ).squeeze (- 1 )
419428 sigma = numpyro .sample ("sigma" , dist .HalfNormal (1 ))
420- numpyro .sample ("y" , dist .Normal (mu , sigma ), obs = y )
429+ with numpyro .plate ("N" , len (y )):
430+ numpyro .sample ("y" , dist .Normal (mu , sigma ), obs = y )
421431
422432 guide = AutoDiagonalNormal (model )
423433 svi = SVI (model , guide , optim .Adam (0.003 ), Trace_ELBO (), x = x , y = y )
@@ -497,7 +507,8 @@ def model(y=None):
497507 mu = numpyro .sample ("mu" , dist .Normal (0 , 5 ))
498508 sigma = numpyro .param ("sigma" , 1 , constraint = constraints .positive )
499509
500- y = numpyro .sample ("y" , dist .Normal (mu , sigma ).expand ((n ,)), obs = y )
510+ with numpyro .plate ("N" , len (y )):
511+ y = numpyro .sample ("y" , dist .Normal (mu , sigma ).expand ((n ,)), obs = y )
501512 numpyro .deterministic ("z" , (y - mu ) / sigma )
502513
503514 mu , sigma = 2 , 3
0 commit comments