Skip to content

Commit 1ca72ff

Browse files
authored
Add cond primitive (#1028)
* Add `cond` control flow primitive * Add test for cond * Make cond usage example more readable * Docstring fixes * Test NUTS with cond * Add return statements to example * Fix markup in docstring * Fix failing doctest * Add enumeration warning * Make MCMC + cond test more robust * Test full mean and std of mixture in cond test * Adjust mode separation in cond test * Typo in test >_<
1 parent 7d24d43 commit 1ca72ff

File tree

6 files changed

+285
-38
lines changed

6 files changed

+285
-38
lines changed

docs/source/primitives.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,7 @@ random_haiku_module
6262
scan
6363
----
6464
.. autofunction:: numpyro.contrib.control_flow.scan
65+
66+
cond
67+
----
68+
.. autofunction:: numpyro.contrib.control_flow.cond
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
# Copyright Contributors to the Pyro project.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from numpyro.contrib.control_flow.cond import cond
45
from numpyro.contrib.control_flow.scan import scan
56

6-
__all__ = [
7-
"scan",
8-
]
7+
__all__ = ["cond", "scan"]
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Copyright Contributors to the Pyro project.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from functools import partial
5+
6+
from jax import device_put, lax
7+
8+
from numpyro import handlers
9+
from numpyro.contrib.control_flow.util import PytreeTrace
10+
from numpyro.primitives import _PYRO_STACK, apply_stack
11+
12+
13+
def _subs_wrapper(subs_map, site):
14+
if isinstance(subs_map, dict) and site["name"] in subs_map:
15+
return subs_map[site["name"]]
16+
elif callable(subs_map):
17+
rng_key = site["kwargs"].get("rng_key")
18+
subs_map = (
19+
handlers.seed(subs_map, rng_seed=rng_key)
20+
if rng_key is not None
21+
else subs_map
22+
)
23+
return subs_map(site)
24+
return None
25+
26+
27+
def wrap_fn(fn, substitute_stack):
28+
def wrapper(wrapped_operand):
29+
rng_key, operand = wrapped_operand
30+
31+
with handlers.block():
32+
seeded_fn = handlers.seed(fn, rng_key) if rng_key is not None else fn
33+
for subs_type, subs_map in substitute_stack:
34+
subs_fn = partial(_subs_wrapper, subs_map)
35+
if subs_type == "condition":
36+
seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
37+
elif subs_type == "substitute":
38+
seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)
39+
40+
with handlers.trace() as trace:
41+
value = seeded_fn(operand)
42+
43+
return value, PytreeTrace(trace)
44+
45+
return wrapper
46+
47+
48+
def cond_wrapper(
49+
pred,
50+
true_fun,
51+
false_fun,
52+
operand,
53+
rng_key=None,
54+
substitute_stack=None,
55+
enum=False,
56+
first_available_dim=None,
57+
):
58+
if enum:
59+
# TODO: support enumeration. note that pred passed to lax.cond must be scalar
60+
# which means that even simple conditions like `x == 0` can get complicated if
61+
# x is an enumerated discrete random variable
62+
raise RuntimeError("The cond primitive does not currently support enumeration")
63+
64+
if substitute_stack is None:
65+
substitute_stack = []
66+
67+
wrapped_true_fun = wrap_fn(true_fun, substitute_stack)
68+
wrapped_false_fun = wrap_fn(false_fun, substitute_stack)
69+
wrapped_operand = device_put((rng_key, operand))
70+
return lax.cond(pred, wrapped_true_fun, wrapped_false_fun, wrapped_operand)
71+
72+
73+
def cond(pred, true_fun, false_fun, operand):
74+
"""
75+
This primitive conditionally applies ``true_fun`` or ``false_fun``. See
76+
:func:`jax.lax.cond` for more information.
77+
78+
**Usage**:
79+
80+
.. doctest::
81+
82+
>>> import numpyro
83+
>>> import numpyro.distributions as dist
84+
>>> from jax import random
85+
>>> from numpyro.contrib.control_flow import cond
86+
>>> from numpyro.infer import SVI, Trace_ELBO
87+
>>>
88+
>>> def model():
89+
... def true_fun(_):
90+
... return numpyro.sample("x", dist.Normal(20.0))
91+
...
92+
... def false_fun(_):
93+
... return numpyro.sample("x", dist.Normal(0.0))
94+
...
95+
... cluster = numpyro.sample("cluster", dist.Normal())
96+
... return cond(cluster > 0, true_fun, false_fun, None)
97+
>>>
98+
>>> def guide():
99+
... m1 = numpyro.param("m1", 10.0)
100+
... s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
101+
... m2 = numpyro.param("m2", 10.0)
102+
... s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)
103+
...
104+
... def true_fun(_):
105+
... return numpyro.sample("x", dist.Normal(m1, s1))
106+
...
107+
... def false_fun(_):
108+
... return numpyro.sample("x", dist.Normal(m2, s2))
109+
...
110+
... cluster = numpyro.sample("cluster", dist.Normal())
111+
... return cond(cluster > 0, true_fun, false_fun, None)
112+
>>>
113+
>>> svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
114+
>>> params, losses = svi.run(random.PRNGKey(0), num_steps=2500)
115+
116+
.. warning:: This is an experimental utility function that allows users to use
117+
JAX control flow with NumPyro's effect handlers. Currently, `sample` and
118+
`deterministic` sites within `true_fun` and `false_fun` are supported. If you
119+
notice that any effect handlers or distributions are unsupported, please file
120+
an issue.
121+
122+
.. warning:: The ``cond`` primitive does not currently support enumeration and can
123+
not be used inside a ``numpyro.plate`` context.
124+
125+
.. note:: All ``sample`` sites must belong to the same distribution class. For
126+
example the following is not supported
127+
128+
.. code-block:: python
129+
130+
cond(
131+
True,
132+
lambda _: numpyro.sample("x", dist.Normal()),
133+
lambda _: numpyro.sample("x", dist.Laplace()),
134+
None,
135+
)
136+
137+
:param bool pred: Boolean scalar type indicating which branch function to apply
138+
:param callable true_fun: A function to be applied if ``pred`` is true.
139+
:param callable false_fun: A function to be applied if ``pred`` is false.
140+
:param operand: Operand input to either branch depending on ``pred``. This can
141+
be any JAX PyTree (e.g. list / dict of arrays).
142+
:return: Output of the applied branch function.
143+
"""
144+
if not _PYRO_STACK:
145+
value, _ = cond_wrapper(pred, true_fun, false_fun, operand)
146+
return value
147+
148+
initial_msg = {
149+
"type": "control_flow",
150+
"fn": cond_wrapper,
151+
"args": (pred, true_fun, false_fun, operand),
152+
"kwargs": {"rng_key": None, "substitute_stack": []},
153+
"value": None,
154+
}
155+
156+
msg = apply_stack(initial_msg)
157+
value, pytree_trace = msg["value"]
158+
159+
for msg in pytree_trace.trace.values():
160+
apply_stack(msg)
161+
162+
return value

numpyro/contrib/control_flow/scan.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,45 +14,13 @@
1414
tree_unflatten,
1515
)
1616
import jax.numpy as jnp
17-
from jax.tree_util import register_pytree_node_class
1817

1918
from numpyro import handlers
19+
from numpyro.contrib.control_flow.util import PytreeTrace
2020
from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack
2121
from numpyro.util import not_jax_tracer
2222

2323

24-
@register_pytree_node_class
25-
class PytreeTrace:
26-
def __init__(self, trace):
27-
self.trace = trace
28-
29-
def tree_flatten(self):
30-
trace, aux_trace = {}, {}
31-
for name, site in self.trace.items():
32-
if site["type"] in ["sample", "deterministic"]:
33-
trace[name], aux_trace[name] = {}, {"_control_flow_done": True}
34-
for key in site:
35-
if key in ["fn", "args", "value", "intermediates"]:
36-
trace[name][key] = site[key]
37-
# scanned sites have stop field because we trace them inside a block handler
38-
elif key != "stop":
39-
aux_trace[name][key] = site[key]
40-
# keep the site order information because in JAX, flatten and unflatten do not preserve
41-
# the order of keys in a dict
42-
site_names = list(trace.keys())
43-
return (trace,), (aux_trace, site_names)
44-
45-
@classmethod
46-
def tree_unflatten(cls, aux_data, children):
47-
aux_trace, site_names = aux_data
48-
(trace,) = children
49-
trace_with_aux = {}
50-
for name in site_names:
51-
trace[name].update(aux_trace[name])
52-
trace_with_aux[name] = trace[name]
53-
return cls(trace_with_aux)
54-
55-
5624
def _subs_wrapper(subs_map, i, length, site):
5725
value = None
5826
if isinstance(subs_map, dict) and site["name"] in subs_map:
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright Contributors to the Pyro project.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from jax.tree_util import register_pytree_node_class
5+
6+
7+
@register_pytree_node_class
8+
class PytreeTrace:
9+
def __init__(self, trace):
10+
self.trace = trace
11+
12+
def tree_flatten(self):
13+
trace, aux_trace = {}, {}
14+
for name, site in self.trace.items():
15+
if site["type"] in ["sample", "deterministic"]:
16+
trace[name], aux_trace[name] = {}, {"_control_flow_done": True}
17+
for key in site:
18+
if key in ["fn", "args", "value", "intermediates"]:
19+
trace[name][key] = site[key]
20+
# scanned sites have stop field because we trace them inside a block handler
21+
elif key != "stop":
22+
if key == "kwargs":
23+
kwargs = site["kwargs"].copy()
24+
if "rng_key" in kwargs:
25+
# rng_key is not traced else it is collected by the
26+
# scan primitive which doesn't make sense
27+
# set to None to avoid leaks during tracing by JAX
28+
kwargs["rng_key"] = None
29+
aux_trace[name][key] = kwargs
30+
else:
31+
aux_trace[name][key] = site[key]
32+
# keep the site order information because in JAX, flatten and unflatten do not preserve
33+
# the order of keys in a dict
34+
site_names = list(trace.keys())
35+
return (trace,), (aux_trace, site_names)
36+
37+
@classmethod
38+
def tree_unflatten(cls, aux_data, children):
39+
aux_trace, site_names = aux_data
40+
(trace,) = children
41+
trace_with_aux = {}
42+
for name in site_names:
43+
trace[name].update(aux_trace[name])
44+
trace_with_aux[name] = trace[name]
45+
return cls(trace_with_aux)

test/contrib/test_control_flow.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
import jax.numpy as jnp
99

1010
import numpyro
11-
from numpyro.contrib.control_flow.scan import scan
11+
from numpyro.contrib.control_flow import cond, scan
1212
import numpyro.distributions as dist
1313
from numpyro.handlers import seed, substitute, trace
14-
from numpyro.infer import MCMC, NUTS, Predictive
14+
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
1515
from numpyro.infer.util import potential_energy
1616

1717

@@ -132,3 +132,72 @@ def iteration(c_prev, c_in):
132132
result,
133133
[[1.7, 0.3]],
134134
)
135+
136+
137+
def test_cond():
138+
def model():
139+
def true_fun(_):
140+
x = numpyro.sample("x", dist.Normal(4.0))
141+
numpyro.deterministic("z", x - 4.0)
142+
143+
def false_fun(_):
144+
x = numpyro.sample("x", dist.Normal(0.0))
145+
numpyro.deterministic("z", x)
146+
147+
cluster = numpyro.sample("cluster", dist.Normal())
148+
cond(cluster > 0, true_fun, false_fun, None)
149+
150+
def guide():
151+
m1 = numpyro.param("m1", 2.0)
152+
s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
153+
m2 = numpyro.param("m2", 2.0)
154+
s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)
155+
156+
def true_fun(_):
157+
numpyro.sample("x", dist.Normal(m1, s1))
158+
159+
def false_fun(_):
160+
numpyro.sample("x", dist.Normal(m2, s2))
161+
162+
cluster = numpyro.sample("cluster", dist.Normal())
163+
cond(cluster > 0, true_fun, false_fun, None)
164+
165+
svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
166+
params, losses = svi.run(random.PRNGKey(0), num_steps=2500)
167+
168+
predictive = Predictive(
169+
model,
170+
guide=guide,
171+
params=params,
172+
num_samples=1000,
173+
return_sites=["cluster", "x", "z"],
174+
)
175+
result = predictive(random.PRNGKey(0))
176+
177+
assert result["cluster"].shape == (1000,)
178+
assert result["x"].shape == (1000,)
179+
assert result["z"].shape == (1000,)
180+
181+
mcmc = MCMC(
182+
NUTS(model),
183+
num_warmup=500,
184+
num_samples=2500,
185+
num_chains=4,
186+
chain_method="sequential",
187+
)
188+
mcmc.run(random.PRNGKey(0))
189+
190+
x = mcmc.get_samples()["x"]
191+
assert x.shape == (10_000,)
192+
assert_allclose(
193+
[
194+
x.mean(),
195+
x.std(),
196+
x[x > 2.0].mean(),
197+
x[x > 2.0].std(),
198+
x[x < 2.0].mean(),
199+
x[x < 2.0].std(),
200+
],
201+
[2.0, jnp.sqrt(5.0), 4.0, 1.0, 0.0, 1.0],
202+
atol=0.1,
203+
)

0 commit comments

Comments
 (0)