|
| 1 | +# Copyright Contributors to the Pyro project. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +from functools import partial |
| 5 | + |
| 6 | +from jax import device_put, lax |
| 7 | + |
| 8 | +from numpyro import handlers |
| 9 | +from numpyro.contrib.control_flow.util import PytreeTrace |
| 10 | +from numpyro.primitives import _PYRO_STACK, apply_stack |
| 11 | + |
| 12 | + |
| 13 | +def _subs_wrapper(subs_map, site): |
| 14 | + if isinstance(subs_map, dict) and site["name"] in subs_map: |
| 15 | + return subs_map[site["name"]] |
| 16 | + elif callable(subs_map): |
| 17 | + rng_key = site["kwargs"].get("rng_key") |
| 18 | + subs_map = ( |
| 19 | + handlers.seed(subs_map, rng_seed=rng_key) |
| 20 | + if rng_key is not None |
| 21 | + else subs_map |
| 22 | + ) |
| 23 | + return subs_map(site) |
| 24 | + return None |
| 25 | + |
| 26 | + |
| 27 | +def wrap_fn(fn, substitute_stack): |
| 28 | + def wrapper(wrapped_operand): |
| 29 | + rng_key, operand = wrapped_operand |
| 30 | + |
| 31 | + with handlers.block(): |
| 32 | + seeded_fn = handlers.seed(fn, rng_key) if rng_key is not None else fn |
| 33 | + for subs_type, subs_map in substitute_stack: |
| 34 | + subs_fn = partial(_subs_wrapper, subs_map) |
| 35 | + if subs_type == "condition": |
| 36 | + seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn) |
| 37 | + elif subs_type == "substitute": |
| 38 | + seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn) |
| 39 | + |
| 40 | + with handlers.trace() as trace: |
| 41 | + value = seeded_fn(operand) |
| 42 | + |
| 43 | + return value, PytreeTrace(trace) |
| 44 | + |
| 45 | + return wrapper |
| 46 | + |
| 47 | + |
| 48 | +def cond_wrapper( |
| 49 | + pred, |
| 50 | + true_fun, |
| 51 | + false_fun, |
| 52 | + operand, |
| 53 | + rng_key=None, |
| 54 | + substitute_stack=None, |
| 55 | + enum=False, |
| 56 | + first_available_dim=None, |
| 57 | +): |
| 58 | + if enum: |
| 59 | + # TODO: support enumeration. note that pred passed to lax.cond must be scalar |
| 60 | + # which means that even simple conditions like `x == 0` can get complicated if |
| 61 | + # x is an enumerated discrete random variable |
| 62 | + raise RuntimeError("The cond primitive does not currently support enumeration") |
| 63 | + |
| 64 | + if substitute_stack is None: |
| 65 | + substitute_stack = [] |
| 66 | + |
| 67 | + wrapped_true_fun = wrap_fn(true_fun, substitute_stack) |
| 68 | + wrapped_false_fun = wrap_fn(false_fun, substitute_stack) |
| 69 | + wrapped_operand = device_put((rng_key, operand)) |
| 70 | + return lax.cond(pred, wrapped_true_fun, wrapped_false_fun, wrapped_operand) |
| 71 | + |
| 72 | + |
| 73 | +def cond(pred, true_fun, false_fun, operand): |
| 74 | + """ |
| 75 | + This primitive conditionally applies ``true_fun`` or ``false_fun``. See |
| 76 | + :func:`jax.lax.cond` for more information. |
| 77 | +
|
| 78 | + **Usage**: |
| 79 | +
|
| 80 | + .. doctest:: |
| 81 | +
|
| 82 | + >>> import numpyro |
| 83 | + >>> import numpyro.distributions as dist |
| 84 | + >>> from jax import random |
| 85 | + >>> from numpyro.contrib.control_flow import cond |
| 86 | + >>> from numpyro.infer import SVI, Trace_ELBO |
| 87 | + >>> |
| 88 | + >>> def model(): |
| 89 | + ... def true_fun(_): |
| 90 | + ... return numpyro.sample("x", dist.Normal(20.0)) |
| 91 | + ... |
| 92 | + ... def false_fun(_): |
| 93 | + ... return numpyro.sample("x", dist.Normal(0.0)) |
| 94 | + ... |
| 95 | + ... cluster = numpyro.sample("cluster", dist.Normal()) |
| 96 | + ... return cond(cluster > 0, true_fun, false_fun, None) |
| 97 | + >>> |
| 98 | + >>> def guide(): |
| 99 | + ... m1 = numpyro.param("m1", 10.0) |
| 100 | + ... s1 = numpyro.param("s1", 0.1, constraint=dist.constraints.positive) |
| 101 | + ... m2 = numpyro.param("m2", 10.0) |
| 102 | + ... s2 = numpyro.param("s2", 0.1, constraint=dist.constraints.positive) |
| 103 | + ... |
| 104 | + ... def true_fun(_): |
| 105 | + ... return numpyro.sample("x", dist.Normal(m1, s1)) |
| 106 | + ... |
| 107 | + ... def false_fun(_): |
| 108 | + ... return numpyro.sample("x", dist.Normal(m2, s2)) |
| 109 | + ... |
| 110 | + ... cluster = numpyro.sample("cluster", dist.Normal()) |
| 111 | + ... return cond(cluster > 0, true_fun, false_fun, None) |
| 112 | + >>> |
| 113 | + >>> svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100)) |
| 114 | + >>> params, losses = svi.run(random.PRNGKey(0), num_steps=2500) |
| 115 | +
|
| 116 | + .. warning:: This is an experimental utility function that allows users to use |
| 117 | + JAX control flow with NumPyro's effect handlers. Currently, `sample` and |
| 118 | + `deterministic` sites within `true_fun` and `false_fun` are supported. If you |
| 119 | + notice that any effect handlers or distributions are unsupported, please file |
| 120 | + an issue. |
| 121 | +
|
| 122 | + .. warning:: The ``cond`` primitive does not currently support enumeration and can |
| 123 | + not be used inside a ``numpyro.plate`` context. |
| 124 | +
|
| 125 | + .. note:: All ``sample`` sites must belong to the same distribution class. For |
| 126 | + example the following is not supported |
| 127 | +
|
| 128 | + .. code-block:: python |
| 129 | +
|
| 130 | + cond( |
| 131 | + True, |
| 132 | + lambda _: numpyro.sample("x", dist.Normal()), |
| 133 | + lambda _: numpyro.sample("x", dist.Laplace()), |
| 134 | + None, |
| 135 | + ) |
| 136 | +
|
| 137 | + :param bool pred: Boolean scalar type indicating which branch function to apply |
| 138 | + :param callable true_fun: A function to be applied if ``pred`` is true. |
| 139 | + :param callable false_fun: A function to be applied if ``pred`` is false. |
| 140 | + :param operand: Operand input to either branch depending on ``pred``. This can |
| 141 | + be any JAX PyTree (e.g. list / dict of arrays). |
| 142 | + :return: Output of the applied branch function. |
| 143 | + """ |
| 144 | + if not _PYRO_STACK: |
| 145 | + value, _ = cond_wrapper(pred, true_fun, false_fun, operand) |
| 146 | + return value |
| 147 | + |
| 148 | + initial_msg = { |
| 149 | + "type": "control_flow", |
| 150 | + "fn": cond_wrapper, |
| 151 | + "args": (pred, true_fun, false_fun, operand), |
| 152 | + "kwargs": {"rng_key": None, "substitute_stack": []}, |
| 153 | + "value": None, |
| 154 | + } |
| 155 | + |
| 156 | + msg = apply_stack(initial_msg) |
| 157 | + value, pytree_trace = msg["value"] |
| 158 | + |
| 159 | + for msg in pytree_trace.trace.values(): |
| 160 | + apply_stack(msg) |
| 161 | + |
| 162 | + return value |
0 commit comments