Skip to content

Commit 4b0d4b0

Browse files
fritzoneerajprad
authored andcommitted
Bump to version 0.1.1 for release (#14)
* Remove f strings for Python 3.5 support * Add importorskip in to tests, in case a backend is missing * Bump to version 0.1.1
1 parent 487bb97 commit 4b0d4b0

File tree

4 files changed

+24
-3
lines changed

4 files changed

+24
-3
lines changed

pyroapi/dispatch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def guide():
2828
inference = infer.SVI(model, guide, adam, elbo)
2929
for step in range(10):
3030
loss = inference.step(*args, **kwargs)
31-
print(f"step {step} loss = {loss}")
31+
print("step {} loss = {}".format(step, loss))
3232
3333
"""
3434
import importlib
@@ -64,7 +64,8 @@ def __getattribute__(self, name):
6464
try:
6565
return getattr(module, name)
6666
except AttributeError:
67-
raise NotImplementedError(f'This Pyro backend does not implement {module_name}.{name}')
67+
raise NotImplementedError('This Pyro backend does not implement {}.{}'
68+
.format(module_name, name))
6869

6970

7071
@contextmanager

pyroapi/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.0'
1+
__version__ = '0.1.1'

test/test_dispatch.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,19 @@
33
from pyroapi import handlers, infer, pyro, pyro_backend, register_backend
44
from pyroapi.testing import MODELS
55

6+
PACKAGE_NAME = {
7+
"pyro": "pyro",
8+
"minipyro": "pyro",
9+
"numpy": "numpyro",
10+
"funsor": "funsor",
11+
}
12+
613

714
@pytest.mark.filterwarnings("ignore", category=UserWarning)
815
@pytest.mark.parametrize('model', MODELS)
916
@pytest.mark.parametrize('backend', ['pyro', 'numpy'])
1017
def test_mcmc_interface(model, backend):
18+
pytest.importorskip(PACKAGE_NAME[backend])
1119
with pyro_backend(backend), handlers.seed(rng_seed=20):
1220
f = MODELS[model]()
1321
model, args, kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
@@ -19,6 +27,7 @@ def test_mcmc_interface(model, backend):
1927

2028
@pytest.mark.parametrize('backend', ['funsor', 'minipyro', 'numpy', 'pyro'])
2129
def test_not_implemented(backend):
30+
pytest.importorskip(PACKAGE_NAME[backend])
2231
with pyro_backend(backend):
2332
pyro.sample # should be implemented
2433
pyro.param # should be implemented
@@ -30,6 +39,7 @@ def test_not_implemented(backend):
3039
@pytest.mark.parametrize('backend', ['funsor', 'minipyro', 'numpy', 'pyro'])
3140
@pytest.mark.xfail(reason='Not supported by backend.')
3241
def test_model_sample(model, backend):
42+
pytest.importorskip(PACKAGE_NAME[backend])
3343
with pyro_backend(backend), handlers.seed(rng_seed=2):
3444
f = MODELS[model]()
3545
model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
@@ -44,6 +54,7 @@ def test_model_sample(model, backend):
4454
'pyro',
4555
])
4656
def test_trace_handler(model, backend):
57+
pytest.importorskip(PACKAGE_NAME[backend])
4758
with pyro_backend(backend), handlers.seed(rng_seed=2):
4859
f = MODELS[model]()
4960
model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
@@ -53,6 +64,7 @@ def test_trace_handler(model, backend):
5364

5465
@pytest.mark.parametrize('model', MODELS)
5566
def test_register_backend(model):
67+
pytest.importorskip("pyro")
5668
register_backend("foo", {
5769
"infer": "pyro.contrib.minipyro",
5870
"optim": "pyro.contrib.minipyro",

test/test_tests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,16 @@
55

66
pytestmark = pytest.mark.filterwarnings("ignore::numpyro.compat.util.UnsupportedAPIWarning")
77

8+
PACKAGE_NAME = {
9+
"pyro": "pyro",
10+
"minipyro": "pyro",
11+
"numpy": "numpyro",
12+
"funsor": "funsor",
13+
}
14+
815

916
@pytest.fixture(params=["pyro", "minipyro", "numpy", "funsor"])
1017
def backend(request):
18+
pytest.importorskip(PACKAGE_NAME[request.param])
1119
with pyro_backend(request.param):
1220
yield

0 commit comments

Comments
 (0)