Commit b895042

Add support for Optax optimizers (#1034)
* Add support for Optax optimizers
* Test Optax
* Fix isort linter
* Add optax to docs requirements
* Simplify optax wrapper
* Better handling of import error
* Pin Sphinx version
* Address comments
* Add optax_to_numpyro to docs
* isort
* Edit optax snippet in SVI docstring
* Update docs
* Use optax.OptState for type hints
* Allow jax.experimental.optimizers.Optimizer instances to be used directly in SVI
* Rerun CI
* Pin Jinja2 version
1 parent: 673880c

File tree: 8 files changed (+215 −10 lines)

docs/requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,8 @@ flax
 funsor
 jax>=0.1.65
 jaxlib>=0.1.45
-nbsphinx==0.8.1
+optax==0.0.6
+nbsphinx>=0.8.4
 sphinx-gallery
 tfp-nightly  # TODO: change this to tensorflow-probability when it is stable
 tqdm

docs/source/optimizers.rst

Lines changed: 5 additions & 1 deletion
@@ -64,4 +64,8 @@ SM3
 .. autoclass:: numpyro.optim.SM3
     :members:
     :undoc-members:
-    :inherited-members:
+    :inherited-members:
+
+Optax support
+-------------
+.. autofunction:: numpyro.contrib.optim.optax_to_numpyro

numpyro/contrib/optim.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+This module provides a wrapper for Optax optimizers so that they can be used with
+NumPyro inference algorithms.
+"""
+
+from typing import Tuple, TypeVar
+
+import optax
+
+from numpyro.optim import _NumPyroOptim
+
+_Params = TypeVar("_Params")
+_State = Tuple[_Params, optax.OptState]
+
+
+def optax_to_numpyro(transformation: optax.GradientTransformation) -> _NumPyroOptim:
+    """
+    This function produces a ``numpyro.optim._NumPyroOptim`` instance from an
+    ``optax.GradientTransformation`` so that it can be used with
+    ``numpyro.infer.svi.SVI``. It is a lightweight wrapper that recreates the
+    ``(init_fn, update_fn, get_params_fn)`` interface defined by
+    :mod:`jax.experimental.optimizers`.
+
+    :param transformation: An ``optax.GradientTransformation`` instance to wrap.
+    :return: An instance of ``numpyro.optim._NumPyroOptim`` wrapping the supplied
+        Optax optimizer.
+    """
+
+    def init_fn(params: _Params) -> _State:
+        opt_state = transformation.init(params)
+        return params, opt_state
+
+    def update_fn(step, grads: _Params, state: _State) -> _State:
+        params, opt_state = state
+        updates, opt_state = transformation.update(grads, opt_state, params)
+        updated_params = optax.apply_updates(params, updates)
+        return updated_params, opt_state
+
+    def get_params_fn(state: _State) -> _Params:
+        params, _ = state
+        return params
+
+    return _NumPyroOptim(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn)
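
For readers skimming the diff, here is a minimal usage sketch of the wrapper above (not part of the commit; the toy quadratic loss and starting values are invented for illustration). It drives the returned _NumPyroOptim by hand through the same init/update/get_params methods the commit's new tests exercise:

from jax import grad
import jax.numpy as jnp

import optax

from numpyro.contrib.optim import optax_to_numpyro


def loss(params):
    # toy objective with its minimum at params = 0
    return jnp.sum(params["x"] ** 2)


# wrap a chained Optax transformation, as in the commit's doctest snippet
optim = optax_to_numpyro(optax.chain(optax.clip(10.0), optax.adam(1e-2)))
state = optim.init({"x": jnp.array([1.0, 2.0, 3.0])})
for _ in range(1000):
    g = grad(loss)(optim.get_params(state))
    state = optim.update(g, state)
print(optim.get_params(state))  # should be close to zero after enough steps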

numpyro/infer/svi.py

Lines changed: 37 additions & 2 deletions
@@ -6,6 +6,7 @@
 
 import tqdm
 
+import jax
 from jax import jit, lax, random
 import jax.numpy as jnp
 from jax.tree_util import tree_map
@@ -14,6 +15,7 @@
 from numpyro.distributions.transforms import biject_to
 from numpyro.handlers import replay, seed, trace
 from numpyro.infer.util import transform_fn
+from numpyro.optim import _NumPyroOptim
 
 SVIState = namedtuple("SVIState", ["optim_state", "rng_key"])
 """
@@ -80,7 +82,14 @@ class SVI(object):
     :param model: Python callable with Pyro primitives for the model.
     :param guide: Python callable with Pyro primitives for the guide
         (recognition network).
-    :param optim: an instance of :class:`~numpyro.optim._NumpyroOptim`.
+    :param optim: An instance of :class:`~numpyro.optim._NumpyroOptim`, a
+        ``jax.experimental.optimizers.Optimizer`` or an Optax
+        ``GradientTransformation``. If you pass an Optax optimizer it will
+        automatically be wrapped using :func:`numpyro.contrib.optim.optax_to_numpyro`.
+
+            >>> from optax import adam, chain, clip
+            >>> svi = SVI(model, guide, chain(clip(10.0), adam(1e-3)), loss=Trace_ELBO())
+
     :param loss: ELBO loss, i.e. negative Evidence Lower Bound, to minimize.
     :param static_kwargs: static arguments for the model / guide, i.e. arguments
         that remain constant during fitting.
@@ -91,10 +100,36 @@ def __init__(self, model, guide, optim, loss, **static_kwargs):
         self.model = model
         self.guide = guide
         self.loss = loss
-        self.optim = optim
         self.static_kwargs = static_kwargs
         self.constrain_fn = None
 
+        if isinstance(optim, _NumPyroOptim):
+            self.optim = optim
+        elif isinstance(optim, jax.experimental.optimizers.Optimizer):
+            self.optim = _NumPyroOptim(lambda *args: args, *optim)
+        else:
+            try:
+                import optax
+
+                from numpyro.contrib.optim import optax_to_numpyro
+            except ImportError:
+                raise ImportError(
+                    "It looks like you tried to use an optimizer that isn't an "
+                    "instance of numpyro.optim._NumPyroOptim or "
+                    "jax.experimental.optimizers.Optimizer. There is experimental "
+                    "support for Optax optimizers, but you need to install Optax. "
+                    "It can be installed with `pip install optax`."
+                )
+
+            if not isinstance(optim, optax.GradientTransformation):
+                raise TypeError(
+                    "Expected either an instance of numpyro.optim._NumPyroOptim, "
+                    "jax.experimental.optimizers.Optimizer or "
+                    "optax.GradientTransformation. Got {}".format(type(optim))
+                )
+
+            self.optim = optax_to_numpyro(optim)
+
     def init(self, rng_key, *args, **kwargs):
         """
         Gets the initial SVI state.
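
To make the new constructor behavior concrete, here is a minimal end-to-end sketch (adapted from the beta-Bernoulli tests added in this commit, not itself part of the diff; the data and step count are illustrative) that passes an Optax GradientTransformation straight to SVI, which wraps it via optax_to_numpyro:

from jax import random
import jax.numpy as jnp

import optax

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO

data = jnp.array([1.0] * 8 + [0.0] * 2)


def model(data):
    f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
    numpyro.sample("obs", dist.Bernoulli(f), obs=data)


def guide(data):
    alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
    beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
    numpyro.sample("beta", dist.Beta(alpha_q, beta_q))


# an Optax optimizer is accepted directly; SVI.__init__ wraps it
svi = SVI(model, guide, optax.chain(optax.clip(10.0), optax.adam(0.05)), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0), data)
for _ in range(2000):
    svi_state, _ = svi.update(svi_state, data)
# alpha_q / (alpha_q + beta_q) should approach the empirical mean 0.8
print(svi.get_params(svi_state))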

setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ profile = black
 skip_glob = .ipynb_checkpoints
 known_first_party = funsor, numpyro, test
 known_third_party = opt_einsum
-known_jax = flax, haiku, jax, tensorflow_probability
+known_jax = flax, haiku, jax, optax, tensorflow_probability
 sections = FUTURE, STDLIB, THIRDPARTY, JAX, FIRSTPARTY, LOCALFOLDER
 force_sort_within_sections = true
 combine_as_imports = true

setup.py

Lines changed: 3 additions & 1 deletion
@@ -39,8 +39,9 @@
     extras_require={
         "doc": [
             "ipython",  # sphinx needs this to render codes
+            "jinja2<3.0.0",
             "nbsphinx",
-            "sphinx",
+            "sphinx<4.0.0",
             "sphinx_rtd_theme",
             "sphinx-gallery",
         ],
@@ -58,6 +59,7 @@
         # TODO: bump funsor version before the release
         "funsor @ git+https://github.com/pyro-ppl/funsor.git@d5574988665dd822ec64e41f2b54b9dc929959dc",
         "graphviz",
+        "optax==0.0.6",
         # TODO: change this to tensorflow_probability>0.12.1 when the next version
         # of tfp is released. The current release is not compatible with jax>=0.2.12.
         "tfp-nightly",

test/contrib/test_optim.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+from numpy.testing import assert_allclose
+import pytest
+
+from jax import grad, jit, partial, random
+from jax.lax import fori_loop
+import jax.numpy as jnp
+from jax.test_util import check_close
+
+import numpyro
+import numpyro.distributions as dist
+from numpyro.distributions import constraints
+from numpyro.infer import SVI, RenyiELBO, Trace_ELBO
+
+try:
+    import optax
+
+    from numpyro.contrib.optim import optax_to_numpyro
+
+    # the optimizer test is parameterized by different optax optimizers, but we have
+    # to define them here to ensure that `optax` is defined. pytest.mark.parameterize
+    # decorators are run even if tests are skipped at the top of the file.
+    optimizers = [
+        (optax.adam, (1e-2,), {}),
+        # clipped adam
+        (optax.chain, (optax.clip(10.0), optax.adam(1e-2)), {}),
+        (optax.adagrad, (1e-1,), {}),
+        # SGD with momentum
+        (optax.sgd, (1e-2,), {"momentum": 0.9}),
+        (optax.rmsprop, (1e-2,), {"decay": 0.95}),
+        # RMSProp with momentum
+        (optax.rmsprop, (1e-4,), {"decay": 0.9, "momentum": 0.9}),
+        (optax.sgd, (1e-2,), {}),
+    ]
+except ImportError:
+    pytestmark = pytest.mark.skip(reason="optax is not installed")
+    optimizers = []
+
+
+def loss(params):
+    return jnp.sum(params["x"] ** 2 + params["y"] ** 2)
+
+
+@partial(jit, static_argnums=(1,))
+def step(opt_state, optim):
+    params = optim.get_params(opt_state)
+    g = grad(loss)(params)
+    return optim.update(g, opt_state)
+
+
+@pytest.mark.parametrize("optim_class, args, kwargs", optimizers)
+def test_optim_multi_params(optim_class, args, kwargs):
+    params = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([-1, -1.0, -1.0])}
+    opt = optax_to_numpyro(optim_class(*args, **kwargs))
+    opt_state = opt.init(params)
+    for i in range(2000):
+        opt_state = step(opt_state, opt)
+    for _, param in opt.get_params(opt_state).items():
+        assert jnp.allclose(param, jnp.zeros(3))
+
+
+@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
+def test_beta_bernoulli(elbo):
+    data = jnp.array([1.0] * 8 + [0.0] * 2)
+
+    def model(data):
+        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
+        numpyro.sample("obs", dist.Bernoulli(f), obs=data)
+
+    def guide(data):
+        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
+        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
+        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
+
+    adam = optax.adam(0.05)
+    svi = SVI(model, guide, adam, elbo)
+    svi_state = svi.init(random.PRNGKey(1), data)
+    assert_allclose(svi.optim.get_params(svi_state.optim_state)["alpha_q"], 0.0)
+
+    def body_fn(i, val):
+        svi_state, _ = svi.update(val, data)
+        return svi_state
+
+    svi_state = fori_loop(0, 2000, body_fn, svi_state)
+    params = svi.get_params(svi_state)
+    assert_allclose(
+        params["alpha_q"] / (params["alpha_q"] + params["beta_q"]),
+        0.8,
+        atol=0.05,
+        rtol=0.05,
+    )
+
+
+def test_jitted_update_fn():
+    data = jnp.array([1.0] * 8 + [0.0] * 2)
+
+    def model(data):
+        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
+        numpyro.sample("obs", dist.Bernoulli(f), obs=data)
+
+    def guide(data):
+        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
+        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
+        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
+
+    adam = optax.adam(0.05)
+    svi = SVI(model, guide, adam, Trace_ELBO())
+    svi_state = svi.init(random.PRNGKey(1), data)
+    expected = svi.get_params(svi.update(svi_state, data)[0])
+
+    actual = svi.get_params(jit(svi.update)(svi_state, data=data)[0])
+    check_close(actual, expected, atol=1e-5)

test/infer/test_svi.py

Lines changed: 7 additions & 4 deletions
@@ -4,6 +4,7 @@
 from numpy.testing import assert_allclose
 import pytest
 
+import jax
 from jax import jit, random, value_and_grad
 import jax.numpy as jnp
 from jax.test_util import check_close
@@ -41,7 +42,10 @@ def renyi_loss_fn(x):
 
 
 @pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
-def test_beta_bernoulli(elbo):
+@pytest.mark.parametrize(
+    "optimizer", [optim.Adam(0.05), jax.experimental.optimizers.adam(0.05)]
+)
+def test_beta_bernoulli(elbo, optimizer):
     data = jnp.array([1.0] * 8 + [0.0] * 2)
 
     def model(data):
@@ -53,10 +57,9 @@ def guide(data):
         beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
         numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
 
-    adam = optim.Adam(0.05)
-    svi = SVI(model, guide, adam, elbo)
+    svi = SVI(model, guide, optimizer, elbo)
     svi_state = svi.init(random.PRNGKey(1), data)
-    assert_allclose(adam.get_params(svi_state.optim_state)["alpha_q"], 0.0)
+    assert_allclose(svi.optim.get_params(svi_state.optim_state)["alpha_q"], 0.0)
 
     def body_fn(i, val):
         svi_state, _ = svi.update(val, data)
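
The second parametrization above exercises the other new path in SVI.__init__: a bare jax.experimental.optimizers.Optimizer triple. A short sketch of what that unpacking amounts to (illustrative, not part of the commit; it assumes a jax version that still ships jax.experimental.optimizers, which later moved to jax.example_libraries):

import jax.numpy as jnp
from jax.experimental import optimizers

from numpyro.optim import _NumPyroOptim

# Optimizer is an (init_fun, update_fun, get_params) triple; the identity
# factory lambda hands it to _NumPyroOptim unchanged, exactly as the new
# elif branch in SVI.__init__ does.
triple = optimizers.adam(0.05)
optim = _NumPyroOptim(lambda *args: args, *triple)

state = optim.init({"w": jnp.zeros(3)})
state = optim.update({"w": jnp.zeros(3)}, state)  # dummy zero gradient
print(optim.get_params(state))  # params unchanged by a zero gradient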
