1+ """
2+ Dispatching allows you to dynamically set a backend using :func:`pyro_backend`
3+ and to register new backends using :func:`register_backend` . It's easiest to
4+ see how to use these by example:
5+
6+ .. code-block:: python
7+
8+ from pyroapi import distributions as dist
9+ from pyroapi import infer, ops, optim, pyro, pyro_backend
10+
11+ # These model and guide are backend-agnostic.
12+ def model():
13+ locs = pyro.param("locs", ops.tensor([0.2, 0.3, 0.5]))
14+ p = ops.tensor([0.2, 0.3, 0.5])
15+ with pyro.plate("plate", len(data), dim=-1):
16+ x = pyro.sample("x", dist.Categorical(p))
17+ pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)
18+
19+ def guide():
20+ p = pyro.param("p", ops.tensor([0.5, 0.3, 0.2]))
21+ with pyro.plate("plate", len(data), dim=-1):
22+ pyro.sample("x", dist.Categorical(p))
23+
24+ # We can now set a backend at inference time.
25+ with pyro_backend("numpyro"):
26+ elbo = infer.Trace_ELBO(ignore_jit_warnings=True)
27+ adam = optim.Adam({"lr": 1e-6})
28+ inference = infer.SVI(model, guide, adam, elbo)
29+ for step in range(10):
30+ loss = inference.step(*args, **kwargs)
31+ print(f"step {step} loss = {loss}")
32+
33+ """
134import importlib
235from contextlib import contextmanager
336
437DEFAULT_RNG_SEED = 1
38+ _ALIASES = {}
539
640
741class GenericModule (object ):
@@ -38,9 +72,10 @@ def pyro_backend(*aliases, **new_backends):
3872 """
3973 Context manager to set a custom backend for Pyro models.
4074
41- Backends can be specified either by name (for standard backends)
42- or by providing a dict mapping module name to backend module name.
43- Standard backends include: pyro, minipyro, funsor, and numpy.
75+ Backends can be specified either by name (for standard backends or backends
76+ registered through :func:`register_backend` ) or by providing a dict
77+ mapping module name to backend module name. Standard backends include:
78+ pyro, minipyro, funsor, and numpy.
4479 """
4580 if aliases :
4681 assert len (aliases ) == 1
@@ -59,40 +94,26 @@ def pyro_backend(*aliases, **new_backends):
5994 GenericModule .current_backend [name ] = old_backend
6095
6196
62- _ALIASES = {
63- 'pyro' : {
64- 'distributions' : 'pyro.distributions' ,
65- 'handlers' : 'pyro.poutine' ,
66- 'infer' : 'pyro.infer' ,
67- 'ops' : 'torch' ,
68- 'optim' : 'pyro.optim' ,
69- 'pyro' : 'pyro' ,
70- },
71- 'minipyro' : {
72- 'distributions' : 'pyro.distributions' ,
73- 'handlers' : 'pyro.poutine' ,
74- 'infer' : 'pyro.contrib.minipyro' ,
75- 'ops' : 'torch' ,
76- 'optim' : 'pyro.contrib.minipyro' ,
77- 'pyro' : 'pyro.contrib.minipyro' ,
78- },
79- 'funsor' : {
80- 'distributions' : 'funsor.distributions' ,
81- 'handlers' : 'funsor.minipyro' ,
82- 'infer' : 'funsor.minipyro' ,
83- 'ops' : 'funsor.compat.ops' ,
84- 'optim' : 'funsor.minipyro' ,
85- 'pyro' : 'funsor.minipyro' ,
86- },
87- 'numpy' : {
88- 'distributions' : 'numpyro.compat.distributions' ,
89- 'handlers' : 'numpyro.compat.handlers' ,
90- 'infer' : 'numpyro.compat.infer' ,
91- 'ops' : 'numpyro.compat.ops' ,
92- 'optim' : 'numpyro.compat.optim' ,
93- 'pyro' : 'numpyro.compat.pyro' ,
94- },
95- }
97+ def register_backend (alias , new_backends ):
98+ """
99+ Register a new backend alias. For example::
100+
101+ register_backend("minipyro", {
102+ "infer": "pyro.contrib.minipyro",
103+ "optim": "pyro.contrib.minipyro",
104+ "pyro": "pyro.contrib.minipyro",
105+ })
106+
107+ :param str alias: The name of the new backend.
108+ :param dict new_backends: A dict mapping standard module name (str) to new
109+ module name (str). This needs to include only nonstandard backends
110+ (e.g. if your backend uses torch ops, you need not override ``ops``)
111+ """
112+ assert isinstance (new_backends , dict )
113+ assert all (isinstance (key , str ) for key in new_backends .keys ())
114+ assert all (isinstance (value , str ) for value in new_backends .values ())
115+ _ALIASES [alias ] = new_backends .copy ()
116+
96117
97118# These modules can be overridden.
98119pyro = GenericModule ('pyro' , 'pyro' )
@@ -101,3 +122,38 @@ def pyro_backend(*aliases, **new_backends):
101122infer = GenericModule ('infer' , 'pyro.infer' )
102123optim = GenericModule ('optim' , 'pyro.optim' )
103124ops = GenericModule ('ops' , 'torch' )
125+
126+
127+ # These are standard backends.
128+ register_backend ('pyro' , {
129+ 'distributions' : 'pyro.distributions' ,
130+ 'handlers' : 'pyro.poutine' ,
131+ 'infer' : 'pyro.infer' ,
132+ 'ops' : 'torch' ,
133+ 'optim' : 'pyro.optim' ,
134+ 'pyro' : 'pyro' ,
135+ })
136+ register_backend ('minipyro' , {
137+ 'distributions' : 'pyro.distributions' ,
138+ 'handlers' : 'pyro.poutine' ,
139+ 'infer' : 'pyro.contrib.minipyro' ,
140+ 'ops' : 'torch' ,
141+ 'optim' : 'pyro.contrib.minipyro' ,
142+ 'pyro' : 'pyro.contrib.minipyro' ,
143+ })
144+ register_backend ('funsor' , {
145+ 'distributions' : 'funsor.distributions' ,
146+ 'handlers' : 'funsor.minipyro' ,
147+ 'infer' : 'funsor.minipyro' ,
148+ 'ops' : 'funsor.compat.ops' ,
149+ 'optim' : 'funsor.minipyro' ,
150+ 'pyro' : 'funsor.minipyro' ,
151+ })
152+ register_backend ('numpy' , {
153+ 'distributions' : 'numpyro.compat.distributions' ,
154+ 'handlers' : 'numpyro.compat.handlers' ,
155+ 'infer' : 'numpyro.compat.infer' ,
156+ 'ops' : 'numpyro.compat.ops' ,
157+ 'optim' : 'numpyro.compat.optim' ,
158+ 'pyro' : 'numpyro.compat.pyro' ,
159+ })
0 commit comments