33from pyroapi import handlers , infer , pyro , pyro_backend , register_backend
44from pyroapi .testing import MODELS
55
6+ PACKAGE_NAME = {
7+ "pyro" : "pyro" ,
8+ "minipyro" : "pyro" ,
9+ "numpy" : "numpyro" ,
10+ "funsor" : "funsor" ,
11+ }
12+
613
714@pytest .mark .filterwarnings ("ignore" , category = UserWarning )
815@pytest .mark .parametrize ('model' , MODELS )
916@pytest .mark .parametrize ('backend' , ['pyro' , 'numpy' ])
1017def test_mcmc_interface (model , backend ):
18+ pytest .importorskip (PACKAGE_NAME [backend ])
1119 with pyro_backend (backend ), handlers .seed (rng_seed = 20 ):
1220 f = MODELS [model ]()
1321 model , args , kwargs = f ['model' ], f .get ('model_args' , ()), f .get ('model_kwargs' , {})
@@ -19,6 +27,7 @@ def test_mcmc_interface(model, backend):
1927
2028@pytest .mark .parametrize ('backend' , ['funsor' , 'minipyro' , 'numpy' , 'pyro' ])
2129def test_not_implemented (backend ):
30+ pytest .importorskip (PACKAGE_NAME [backend ])
2231 with pyro_backend (backend ):
2332 pyro .sample # should be implemented
2433 pyro .param # should be implemented
@@ -30,6 +39,7 @@ def test_not_implemented(backend):
3039@pytest .mark .parametrize ('backend' , ['funsor' , 'minipyro' , 'numpy' , 'pyro' ])
3140@pytest .mark .xfail (reason = 'Not supported by backend.' )
3241def test_model_sample (model , backend ):
42+ pytest .importorskip (PACKAGE_NAME [backend ])
3343 with pyro_backend (backend ), handlers .seed (rng_seed = 2 ):
3444 f = MODELS [model ]()
3545 model , model_args , model_kwargs = f ['model' ], f .get ('model_args' , ()), f .get ('model_kwargs' , {})
@@ -44,6 +54,7 @@ def test_model_sample(model, backend):
4454 'pyro' ,
4555])
4656def test_trace_handler (model , backend ):
57+ pytest .importorskip (PACKAGE_NAME [backend ])
4758 with pyro_backend (backend ), handlers .seed (rng_seed = 2 ):
4859 f = MODELS [model ]()
4960 model , model_args , model_kwargs = f ['model' ], f .get ('model_args' , ()), f .get ('model_kwargs' , {})
@@ -53,6 +64,7 @@ def test_trace_handler(model, backend):
5364
5465@pytest .mark .parametrize ('model' , MODELS )
5566def test_register_backend (model ):
67+ pytest .importorskip ("pyro" )
5668 register_backend ("foo" , {
5769 "infer" : "pyro.contrib.minipyro" ,
5870 "optim" : "pyro.contrib.minipyro" ,
0 commit comments