Skip to content

Commit ef1fe92

Browse files
fritzoneerajprad
authored andcommitted
Add register_backend() function (#11)
1 parent 5cb6b62 commit ef1fe92

File tree

4 files changed

+116
-70
lines changed

4 files changed

+116
-70
lines changed

docs/source/dispatch.rst

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,9 @@
11
Dispatch
22
========
33

4-
It's easiest to see how to use pyroapi by example:
5-
6-
.. code-block:: python
7-
8-
from pyroapi import distributions as dist
9-
from pyroapi import infer, ops, optim, pyro, pyro_backend
10-
11-
# These model and guide are backend-agnostic.
12-
def model():
13-
locs = pyro.param("locs", ops.tensor([0.2, 0.3, 0.5]))
14-
p = ops.tensor([0.2, 0.3, 0.5])
15-
with pyro.plate("plate", len(data), dim=-1):
16-
x = pyro.sample("x", dist.Categorical(p))
17-
pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)
18-
19-
def guide():
20-
p = pyro.param("p", ops.tensor([0.5, 0.3, 0.2]))
21-
with pyro.plate("plate", len(data), dim=-1):
22-
pyro.sample("x", dist.Categorical(p))
23-
24-
# We can now set a backend at inference time.
25-
with pyro_backend("numpyro"):
26-
elbo = infer.Trace_ELBO(ignore_jit_warnings=True)
27-
adam = optim.Adam({"lr": 1e-6})
28-
inference = infer.SVI(model, guide, adam, elbo)
29-
for step in range(10):
30-
loss = inference.step(*args, **kwargs)
31-
print(f"step {step} loss = {loss}")
32-
334
.. automodule:: pyroapi.dispatch
345
.. autofunction:: pyroapi.dispatch.pyro_backend
6+
.. autofunction:: pyroapi.dispatch.register_backend
357

368
Generic Modules
379
---------------

pyroapi/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pyroapi.dispatch import distributions, handlers, infer, ops, optim, pyro, pyro_backend
1+
from pyroapi.dispatch import distributions, handlers, infer, ops, optim, pyro, pyro_backend, register_backend
22

33
__all__ = [
44
'distributions',
@@ -8,4 +8,5 @@
88
'optim',
99
'pyro',
1010
'pyro_backend',
11+
'register_backend',
1112
]

pyroapi/dispatch.py

Lines changed: 93 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,41 @@
1+
"""
2+
Dispatching allows you to dynamically set a backend using :func:`pyro_backend`
3+
and to register new backends using :func:`register_backend` . It's easiest to
4+
see how to use these by example:
5+
6+
.. code-block:: python
7+
8+
from pyroapi import distributions as dist
9+
from pyroapi import infer, ops, optim, pyro, pyro_backend
10+
11+
# These model and guide are backend-agnostic.
12+
def model():
13+
locs = pyro.param("locs", ops.tensor([0.2, 0.3, 0.5]))
14+
p = ops.tensor([0.2, 0.3, 0.5])
15+
with pyro.plate("plate", len(data), dim=-1):
16+
x = pyro.sample("x", dist.Categorical(p))
17+
pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)
18+
19+
def guide():
20+
p = pyro.param("p", ops.tensor([0.5, 0.3, 0.2]))
21+
with pyro.plate("plate", len(data), dim=-1):
22+
pyro.sample("x", dist.Categorical(p))
23+
24+
# We can now set a backend at inference time.
25+
with pyro_backend("numpyro"):
26+
elbo = infer.Trace_ELBO(ignore_jit_warnings=True)
27+
adam = optim.Adam({"lr": 1e-6})
28+
inference = infer.SVI(model, guide, adam, elbo)
29+
for step in range(10):
30+
loss = inference.step(*args, **kwargs)
31+
print(f"step {step} loss = {loss}")
32+
33+
"""
134
import importlib
235
from contextlib import contextmanager
336

437
DEFAULT_RNG_SEED = 1
38+
_ALIASES = {}
539

640

741
class GenericModule(object):
@@ -38,9 +72,10 @@ def pyro_backend(*aliases, **new_backends):
3872
"""
3973
Context manager to set a custom backend for Pyro models.
4074
41-
Backends can be specified either by name (for standard backends)
42-
or by providing a dict mapping module name to backend module name.
43-
Standard backends include: pyro, minipyro, funsor, and numpy.
75+
Backends can be specified either by name (for standard backends or backends
76+
registered through :func:`register_backend` ) or by providing a dict
77+
mapping module name to backend module name. Standard backends include:
78+
pyro, minipyro, funsor, and numpy.
4479
"""
4580
if aliases:
4681
assert len(aliases) == 1
@@ -59,40 +94,26 @@ def pyro_backend(*aliases, **new_backends):
5994
GenericModule.current_backend[name] = old_backend
6095

6196

62-
_ALIASES = {
63-
'pyro': {
64-
'distributions': 'pyro.distributions',
65-
'handlers': 'pyro.poutine',
66-
'infer': 'pyro.infer',
67-
'ops': 'torch',
68-
'optim': 'pyro.optim',
69-
'pyro': 'pyro',
70-
},
71-
'minipyro': {
72-
'distributions': 'pyro.distributions',
73-
'handlers': 'pyro.poutine',
74-
'infer': 'pyro.contrib.minipyro',
75-
'ops': 'torch',
76-
'optim': 'pyro.contrib.minipyro',
77-
'pyro': 'pyro.contrib.minipyro',
78-
},
79-
'funsor': {
80-
'distributions': 'funsor.distributions',
81-
'handlers': 'funsor.minipyro',
82-
'infer': 'funsor.minipyro',
83-
'ops': 'funsor.compat.ops',
84-
'optim': 'funsor.minipyro',
85-
'pyro': 'funsor.minipyro',
86-
},
87-
'numpy': {
88-
'distributions': 'numpyro.compat.distributions',
89-
'handlers': 'numpyro.compat.handlers',
90-
'infer': 'numpyro.compat.infer',
91-
'ops': 'numpyro.compat.ops',
92-
'optim': 'numpyro.compat.optim',
93-
'pyro': 'numpyro.compat.pyro',
94-
},
95-
}
97+
def register_backend(alias, new_backends):
98+
"""
99+
Register a new backend alias. For example::
100+
101+
register_backend("minipyro", {
102+
"infer": "pyro.contrib.minipyro",
103+
"optim": "pyro.contrib.minipyro",
104+
"pyro": "pyro.contrib.minipyro",
105+
})
106+
107+
:param str alias: The name of the new backend.
108+
:param dict new_backends: A dict mapping standard module name (str) to new
109+
module name (str). This needs to include only nonstandard backends
110+
(e.g. if your backend uses torch ops, you need not override ``ops``)
111+
"""
112+
assert isinstance(new_backends, dict)
113+
assert all(isinstance(key, str) for key in new_backends.keys())
114+
assert all(isinstance(value, str) for value in new_backends.values())
115+
_ALIASES[alias] = new_backends.copy()
116+
96117

97118
# These modules can be overridden.
98119
pyro = GenericModule('pyro', 'pyro')
@@ -101,3 +122,38 @@ def pyro_backend(*aliases, **new_backends):
101122
infer = GenericModule('infer', 'pyro.infer')
102123
optim = GenericModule('optim', 'pyro.optim')
103124
ops = GenericModule('ops', 'torch')
125+
126+
127+
# These are standard backends.
128+
register_backend('pyro', {
129+
'distributions': 'pyro.distributions',
130+
'handlers': 'pyro.poutine',
131+
'infer': 'pyro.infer',
132+
'ops': 'torch',
133+
'optim': 'pyro.optim',
134+
'pyro': 'pyro',
135+
})
136+
register_backend('minipyro', {
137+
'distributions': 'pyro.distributions',
138+
'handlers': 'pyro.poutine',
139+
'infer': 'pyro.contrib.minipyro',
140+
'ops': 'torch',
141+
'optim': 'pyro.contrib.minipyro',
142+
'pyro': 'pyro.contrib.minipyro',
143+
})
144+
register_backend('funsor', {
145+
'distributions': 'funsor.distributions',
146+
'handlers': 'funsor.minipyro',
147+
'infer': 'funsor.minipyro',
148+
'ops': 'funsor.compat.ops',
149+
'optim': 'funsor.minipyro',
150+
'pyro': 'funsor.minipyro',
151+
})
152+
register_backend('numpy', {
153+
'distributions': 'numpyro.compat.distributions',
154+
'handlers': 'numpyro.compat.handlers',
155+
'infer': 'numpyro.compat.infer',
156+
'ops': 'numpyro.compat.ops',
157+
'optim': 'numpyro.compat.optim',
158+
'pyro': 'numpyro.compat.pyro',
159+
})

test/test_dispatch.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from pyroapi import handlers, infer, pyro, pyro_backend
3+
from pyroapi import handlers, infer, pyro, pyro_backend, register_backend
44
from pyroapi.testing import MODELS
55

66

@@ -37,11 +37,28 @@ def test_model_sample(model, backend):
3737

3838

3939
@pytest.mark.parametrize('model', MODELS)
40-
@pytest.mark.parametrize('backend', ['funsor', 'minipyro', 'numpy', 'pyro'])
41-
@pytest.mark.xfail(reason='Not supported by backend.')
40+
@pytest.mark.parametrize('backend', [
41+
pytest.param("funsor", marks=[pytest.mark.xfail(reason="not implemented")]),
42+
'minipyro',
43+
'numpy',
44+
'pyro',
45+
])
4246
def test_trace_handler(model, backend):
4347
with pyro_backend(backend), handlers.seed(rng_seed=2):
4448
f = MODELS[model]()
4549
model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
4650
# should be implemented
4751
handlers.trace(model).get_trace(*model_args, **model_kwargs)
52+
53+
54+
@pytest.mark.parametrize('model', MODELS)
55+
def test_register_backend(model):
56+
register_backend("foo", {
57+
"infer": "pyro.contrib.minipyro",
58+
"optim": "pyro.contrib.minipyro",
59+
"pyro": "pyro.contrib.minipyro",
60+
})
61+
with pyro_backend("foo"):
62+
f = MODELS[model]()
63+
model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {})
64+
handlers.trace(model).get_trace(*model_args, **model_kwargs)

0 commit comments

Comments
 (0)