1414PRNGKey = jax .random .PRNGKey
1515numpyro = importorskip ("numpyro" )
1616Predictive = numpyro .infer .Predictive
17- numpyro .set_host_device_count (2 )
18- dist = numpyro .distributions
19- AutoNormal = numpyro .infer .autoguide .AutoNormal
20- AutoDelta = numpyro .infer .autoguide .AutoDelta
2117autoguide = numpyro .infer .autoguide
18+ numpyro .set_host_device_count (2 )
2219
2320
2421class TestDataNumPyro :
@@ -309,12 +306,13 @@ def model():
309306 "svi,guide_fn" ,
310307 [
311308 (False , None ), # MCMC, guide ignored
312- (True , AutoDelta ), # SVI with AutoDelta
313- (True , AutoNormal ), # SVI with AutoNormal
309+ (True , autoguide . AutoDelta ), # SVI with AutoDelta
310+ (True , autoguide . AutoNormal ), # SVI with AutoNormal
314311 (True , "custom" ), # SVI with custom guide
315312 ],
316313 )
317314 def test_infer_dims (self , svi , guide_fn ):
315+ import jax .numpy as jnp
318316 import numpyro
319317 import numpyro .distributions as dist
320318
@@ -324,9 +322,9 @@ def model():
324322 _ = numpyro .sample ("param" , dist .Normal (0 , 1 ))
325323
326324 def guide ():
327- loc = numpyro .param ("param_loc" , jax . numpy .zeros ((10 , 5 )))
325+ loc = numpyro .param ("param_loc" , jnp .zeros ((10 , 5 )))
328326 scale = numpyro .param (
329- "param_scale" , jax . numpy .ones ((10 , 5 )), constraint = dist .constraints .positive
327+ "param_scale" , jnp .ones ((10 , 5 )), constraint = dist .constraints .positive
330328 )
331329 with numpyro .plate ("group2" , 5 ), numpyro .plate ("group1" , 10 ):
332330 numpyro .sample ("param" , dist .Normal (loc , scale ))
@@ -348,12 +346,13 @@ def guide():
348346 "svi,guide_fn" ,
349347 [
350348 (False , None ), # MCMC, guide ignored
351- (True , AutoDelta ), # SVI with AutoDelta
352- (True , AutoNormal ), # SVI with AutoNormal
349+ (True , autoguide . AutoDelta ), # SVI with AutoDelta
350+ (True , autoguide . AutoNormal ), # SVI with AutoNormal
353351 (True , "custom" ), # SVI with custom guide
354352 ],
355353 )
356354 def test_infer_unsorted_dims (self , svi , guide_fn ):
355+ import jax .numpy as jnp
357356 import numpyro
358357 import numpyro .distributions as dist
359358
@@ -367,9 +366,9 @@ def model():
367366 _ = numpyro .sample ("param" , dist .Normal (0 , 1 ))
368367
369368 def guide ():
370- loc = numpyro .param ("param_loc" , jax . numpy .zeros ((5 , 10 )))
369+ loc = numpyro .param ("param_loc" , jnp .zeros ((5 , 10 )))
371370 scale = numpyro .param (
372- "param_scale" , jax . numpy .ones ((5 , 10 )), constraint = dist .constraints .positive
371+ "param_scale" , jnp .ones ((5 , 10 )), constraint = dist .constraints .positive
373372 )
374373 group1_plate = numpyro .plate ("group1" , 10 , dim = - 1 )
375374 group2_plate = numpyro .plate ("group2" , 5 , dim = - 2 )
@@ -393,12 +392,13 @@ def guide():
393392 "svi,guide_fn" ,
394393 [
395394 (False , None ), # MCMC, guide ignored
396- (True , AutoDelta ), # SVI with AutoDelta
397- (True , AutoNormal ), # SVI with AutoNormal
395+ (True , autoguide . AutoDelta ), # SVI with AutoDelta
396+ (True , autoguide . AutoNormal ), # SVI with AutoNormal
398397 (True , "custom" ), # SVI with custom guide
399398 ],
400399 )
401400 def test_infer_dims_no_coords (self , svi , guide_fn ):
401+ import jax .numpy as jnp
402402 import numpyro
403403 import numpyro .distributions as dist
404404
@@ -407,10 +407,8 @@ def model():
407407 _ = numpyro .sample ("param" , dist .Normal (0 , 1 ))
408408
409409 def guide ():
410- loc = numpyro .param ("param_loc" , jax .numpy .zeros (5 ))
411- scale = numpyro .param (
412- "param_scale" , jax .numpy .ones (5 ), constraint = dist .constraints .positive
413- )
410+ loc = numpyro .param ("param_loc" , jnp .zeros (5 ))
411+ scale = numpyro .param ("param_scale" , jnp .ones (5 ), constraint = dist .constraints .positive )
414412 with numpyro .plate ("group" , 5 ):
415413 numpyro .sample ("param" , dist .Normal (loc , scale ))
416414
@@ -428,8 +426,8 @@ def guide():
428426 "svi,guide_fn" ,
429427 [
430428 (False , None ), # MCMC, guide ignored
431- (True , AutoDelta ), # SVI with AutoDelta
432- (True , AutoNormal ), # SVI with AutoNormal
429+ (True , autoguide . AutoDelta ), # SVI with AutoDelta
430+ (True , autoguide . AutoNormal ), # SVI with AutoNormal
433431 (True , "custom" ), # SVI with custom guide
434432 ],
435433 )
@@ -465,8 +463,8 @@ def guide():
465463 "svi,guide_fn" ,
466464 [
467465 (False , None ), # MCMC, guide ignored
468- (True , AutoDelta ), # SVI with AutoDelta
469- (True , AutoNormal ), # SVI with AutoNormal
466+ (True , autoguide . AutoDelta ), # SVI with AutoDelta
467+ (True , autoguide . AutoNormal ), # SVI with AutoNormal
470468 (True , "custom" ), # SVI with custom guide
471469 ],
472470 )
@@ -510,8 +508,8 @@ def guide():
510508 "svi,guide_fn" ,
511509 [
512510 (False , None ), # MCMC, guide ignored
513- (True , AutoDelta ), # SVI with AutoDelta
514- (True , AutoNormal ), # SVI with AutoNormal
511+ (True , autoguide . AutoDelta ), # SVI with AutoDelta
512+ (True , autoguide . AutoNormal ), # SVI with AutoNormal
515513 (True , "custom" ), # SVI with custom guide
516514 ],
517515 )
@@ -554,7 +552,7 @@ def test_predictions_infer_dims(
554552 assert inference_data .predictions .obs .dims == (sample_dims + ("J" ,))
555553 assert "J" in inference_data .predictions .obs .coords
556554
557- def _run_inference (self , model , svi , guide_fn = autoguide . AutoNormal ):
555+ def _run_inference (self , model , svi , guide_fn ):
558556 from numpyro .infer import MCMC , NUTS , SVI , Trace_ELBO
559557 from numpyro .optim import Adam
560558
0 commit comments