Commit bdb3329

Pass value arg to optax, allowing use of reduce_on_plateau (#1974)
* Pass value arg to optax, allowing use of reduce_on_plateau
* Address some PR comments
* Simplify, improve typing
* Address more PR comments
* Updates from PR comments
* Special-case the reduce on plateau scheduler for JIT test
* Pass loss value in SteinVI
1 parent f5ae79b commit bdb3329

3 files changed: +96 -39 lines changed
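For context beyond the diff, a minimal end-to-end sketch of what this change enables is shown below. The toy model, guide, and data are illustrative placeholders, not part of the commit; the optimizer chain mirrors the one added to the test suite, and it works because `eval_and_update` now forwards the loss value that `reduce_on_plateau` needs.

import jax.numpy as jnp
import optax
import optax.contrib
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro import optim
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal


def model(data):
    # Toy model: infer the location of a unit-variance normal.
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


guide = AutoNormal(model)

# An optax chain whose schedule depends on the loss value; after this commit
# the wrapped optimizer receives that value on every update.
optimizer = optim.optax_to_numpyro(
    optax.chain(
        optax.adam(1e-2),
        optax.contrib.reduce_on_plateau(patience=5, accumulation_size=200),
    )
)

svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 1000, jnp.array([0.2, -0.1, 0.4]))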

numpyro/contrib/einstein/steinvi.py

Lines changed: 1 addition & 1 deletion
@@ -467,7 +467,7 @@ def update(self, state: SteinVIState, *args, **kwargs) -> SteinVIState:
             **kwargs,
             **self.static_kwargs,
         )
-        optim_state = self.optim.update(grads, optim_state)
+        optim_state = self.optim.update(grads, optim_state, value=loss_val)
         return SteinVIState(
             optim_state, rng_key, state.loss_temperature, state.repulsion_temperature
         ), loss_val
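Since `SteinVI.update` now forwards `loss_val`, the same plateau-aware optimizer can be used with SteinVI as well. A minimal sketch; the toy model, `AutoNormal` guide, `RBFKernel`, and particle count follow the usual NumPyro Stein example and are illustrative assumptions, not part of this commit.

import jax.numpy as jnp
import optax
import optax.contrib
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro import optim
from numpyro.contrib.einstein import RBFKernel, SteinVI
from numpyro.infer.autoguide import AutoNormal


def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


stein = SteinVI(
    model,
    AutoNormal(model),
    # The loss value is now threaded through to the optax chain on each step.
    optim.optax_to_numpyro(
        optax.chain(
            optax.adam(1e-2),
            optax.contrib.reduce_on_plateau(patience=5, accumulation_size=200),
        )
    ),
    RBFKernel(),
    num_stein_particles=5,
)
result = stein.run(random.PRNGKey(0), 1000, jnp.array([0.2, -0.1, 0.4]))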

numpyro/optim.py

Lines changed: 46 additions & 10 deletions
@@ -9,7 +9,7 @@
 
 from collections import namedtuple
 from collections.abc import Callable
-from typing import Any
+from typing import Any, Optional, Protocol
 
 import jax
 from jax import jacfwd, lax, value_and_grad
@@ -50,11 +50,28 @@ def _wrapper(x):
         return value_and_grad(f, has_aux=True)(x)
 
 
+class UpdateExtraArgsFn(Protocol):
+    """An update function accepting additional keyword arguments."""
+
+    def __call__(
+        self,
+        arr: ArrayLike,
+        params: _Params,
+        state: _OptState,
+        **extra_args: Any,
+    ) -> _OptState:
+        """
+        Based on https://github.com/google-deepmind/optax/blob/2e66ce897e83b4901d37dcbb477a7432497848d6/optax/_src/base.py#L110-L147,
+        this protocol expresses an update function that *may* take extra arguments.
+        """
+
+
 class _NumPyroOptim(object):
     def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
         self.init_fn: Callable[[_Params], _IterOptState]
-        self.update_fn: Callable[[ArrayLike, _Params, _OptState], _OptState]
+        self.update_fn: UpdateExtraArgsFn
         self.get_params_fn: Callable[[_OptState], _Params]
+        self.update_with_value: bool = kwargs.pop("update_with_value", False)
         self.init_fn, self.update_fn, self.get_params_fn = optim_fn(*args, **kwargs)
 
     def init(self, params: _Params) -> _IterOptState:
@@ -67,7 +84,9 @@ def init(self, params: _Params) -> _IterOptState:
         opt_state = self.init_fn(params)
         return jnp.array(0), opt_state
 
-    def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
+    def update(
+        self, g: _Params, state: _IterOptState, value: Optional[ArrayLike] = None
+    ) -> _IterOptState:
         """
         Gradient update for the optimizer.
 
@@ -76,7 +95,11 @@ def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
         :return: new optimizer state after the update.
         """
         i, opt_state = state
-        opt_state = self.update_fn(i, g, opt_state)
+        if self.update_with_value:
+            assert value is not None
+            opt_state = self.update_fn(i, g, opt_state, value=value)
+        else:
+            opt_state = self.update_fn(i, g, opt_state)
         return i + 1, opt_state
 
     def eval_and_update(
@@ -104,7 +127,7 @@ def eval_and_update(
         (out, aux), grads = _value_and_grad(
             fn, x=params, forward_mode_differentiation=forward_mode_differentiation
         )
-        return (out, aux), self.update(grads, state)
+        return (out, aux), self.update(grads, state, value=out)
 
     def eval_and_stable_update(
         self,
@@ -128,7 +151,7 @@
         )
         out, state = lax.cond(
             jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(),
-            lambda _: (out, self.update(grads, state)),
+            lambda _: (out, self.update(grads, state, value=out)),
             lambda _: (jnp.nan, state),
             None,
         )
@@ -178,7 +201,9 @@ def __init__(self, *args, clip_norm: float = 10.0, **kwargs) -> None:
         self.clip_norm = clip_norm
         super(ClippedAdam, self).__init__(optimizers.adam, *args, **kwargs)
 
-    def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
+    def update(
+        self, g: _Params, state: _IterOptState, value: Optional[ArrayLike] = None
+    ) -> _IterOptState:
         i, opt_state = state
         # clip norm
         g = jax.tree.map(lambda g_: jnp.clip(g_, -self.clip_norm, self.clip_norm), g)
@@ -352,15 +377,26 @@ def init_fn(params: _Params) -> tuple[_Params, Any]:
         return params, opt_state
 
     def update_fn(
-        step: ArrayLike, grads: ArrayLike, state: tuple[_Params, Any]
+        step: ArrayLike,
+        grads: ArrayLike,
+        state: tuple[_Params, Any],
+        value: ArrayLike,
     ) -> tuple[_Params, Any]:
         params, opt_state = state
-        updates, opt_state = transformation.update(grads, opt_state, params)
+        updates, opt_state = optax.with_extra_args_support(transformation).update(
+            grads, opt_state, params, value=value
+        )
         updated_params = optax.apply_updates(params, updates)
         return updated_params, opt_state
 
     def get_params_fn(state: tuple[_Params, Any]) -> _Params:
         params, _ = state
         return params
 
-    return _NumPyroOptim(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn)
+    return _NumPyroOptim(
+        lambda x, y, z: (x, y, z),
+        init_fn,
+        update_fn,
+        get_params_fn,
+        update_with_value=True,
+    )
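To illustrate the mechanics above, the following sketch steps a wrapped optimizer by hand; the parameter, gradient, and loss values are stand-ins. With `update_with_value=True`, `_NumPyroOptim.update` requires `value=` and forwards it to the optax chain via `optax.with_extra_args_support`.

import jax.numpy as jnp
import optax
import optax.contrib

from numpyro import optim

# optax_to_numpyro now constructs the wrapper with update_with_value=True.
opt = optim.optax_to_numpyro(
    optax.chain(
        optax.adam(1e-2),
        optax.contrib.reduce_on_plateau(patience=5, accumulation_size=200),
    )
)
assert opt.update_with_value

params = {"w": jnp.zeros(3)}
state = opt.init(params)

grads = {"w": jnp.ones(3)}  # stand-in gradients
loss = jnp.asarray(0.5)     # stand-in loss value
# The value keyword is passed through to the optax transformation, which is
# what reduce_on_plateau uses to decide when to shrink the learning rate.
state = opt.update(grads, state, value=loss)
new_params = opt.get_params(state)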

test/test_optimizers.py

Lines changed: 49 additions & 28 deletions
@@ -12,21 +12,32 @@
 
 try:
     import optax
+    import optax.contrib
 
     # the optimizer test is parameterized by different optax optimizers, but we have
     # to define them here to ensure that `optax` is defined. pytest.mark.parameterize
     # decorators are run even if tests are skipped at the top of the file.
     optax_optimizers = [
-        (optax.adam, (1e-2,), {}),
+        (optax.adam, (1e-2,), {}, False),
         # clipped adam
-        (optax.chain, (optax.clip(10.0), optax.adam(1e-2)), {}),
-        (optax.adagrad, (1e-1,), {}),
+        (optax.chain, (optax.clip(10.0), optax.adam(1e-2)), {}, False),
+        (optax.adagrad, (1e-1,), {}, False),
         # SGD with momentum
-        (optax.sgd, (1e-2,), {"momentum": 0.9}),
-        (optax.rmsprop, (1e-2,), {"decay": 0.95}),
+        (optax.sgd, (1e-2,), {"momentum": 0.9}, False),
+        (optax.rmsprop, (1e-2,), {"decay": 0.95}, False),
         # RMSProp with momentum
-        (optax.rmsprop, (1e-4,), {"decay": 0.9, "momentum": 0.9}),
-        (optax.sgd, (1e-2,), {}),
+        (optax.rmsprop, (1e-4,), {"decay": 0.9, "momentum": 0.9}, False),
+        (optax.sgd, (1e-2,), {}, False),
+        # reduce learning rate on plateau
+        (
+            optax.chain,
+            (
+                optax.adam(1e-2),
+                optax.contrib.reduce_on_plateau(patience=5, accumulation_size=200),
+            ),
+            {},
+            True,
+        ),
     ]
 except ImportError:
     pytestmark = pytest.mark.skip(reason="optax is not installed")
@@ -41,24 +52,27 @@ def loss(params):
 def step(opt_state, optim):
     params = optim.get_params(opt_state)
     g = grad(loss)(params)
-    return optim.update(g, opt_state)
+    if optim.update_with_value:
+        return optim.update(g, opt_state, value=loss(params))
+    else:
+        return optim.update(g, opt_state)
 
 
 @pytest.mark.parametrize(
-    "optim_class, args, kwargs",
+    "optim_class, args, kwargs, uses_value_arg",
     [
-        (optim.Adam, (1e-2,), {}),
-        (optim.ClippedAdam, (1e-2,), {}),
-        (optim.Adagrad, (1e-1,), {}),
-        (optim.Momentum, (1e-2, 0.5), {}),
-        (optim.RMSProp, (1e-2, 0.95), {}),
-        (optim.RMSPropMomentum, (1e-4,), {}),
-        (optim.SGD, (1e-2,), {}),
+        (optim.Adam, (1e-2,), {}, False),
+        (optim.ClippedAdam, (1e-2,), {}, False),
+        (optim.Adagrad, (1e-1,), {}, False),
+        (optim.Momentum, (1e-2, 0.5), {}, False),
+        (optim.RMSProp, (1e-2, 0.95), {}, False),
+        (optim.RMSPropMomentum, (1e-4,), {}, False),
+        (optim.SGD, (1e-2,), {}, False),
     ]
     + optax_optimizers,
 )
 @pytest.mark.filterwarnings("ignore:.*tree_multimap:FutureWarning")
-def test_optim_multi_params(optim_class, args, kwargs):
+def test_optim_multi_params(optim_class, args, kwargs, uses_value_arg):
     params = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([-1, -1.0, -1.0])}
     opt = optim_class(*args, **kwargs)
     if not isinstance(opt, optim._NumPyroOptim):
@@ -73,20 +87,20 @@ def test_optim_multi_params(optim_class, args, kwargs):
 # note: this is somewhat of a bruteforce test. testing directly from
 # _NumpyroOptim would probably be better
 @pytest.mark.parametrize(
-    "optim_class, args, kwargs",
+    "optim_class, args, kwargs, uses_value_arg",
     [
-        (optim.Adam, (1e-2,), {}),
-        (optim.ClippedAdam, (1e-2,), {}),
-        (optim.Adagrad, (1e-1,), {}),
-        (optim.Momentum, (1e-2, 0.5), {}),
-        (optim.RMSProp, (1e-2, 0.95), {}),
-        (optim.RMSPropMomentum, (1e-4,), {}),
-        (optim.SGD, (1e-2,), {}),
+        (optim.Adam, (1e-2,), {}, False),
+        (optim.ClippedAdam, (1e-2,), {}, False),
+        (optim.Adagrad, (1e-1,), {}, False),
+        (optim.Momentum, (1e-2, 0.5), {}, False),
+        (optim.RMSProp, (1e-2, 0.95), {}, False),
+        (optim.RMSPropMomentum, (1e-4,), {}, False),
+        (optim.SGD, (1e-2,), {}, False),
    ]
     + optax_optimizers,
 )
 @pytest.mark.filterwarnings("ignore:.*tree_multimap:FutureWarning")
-def test_numpyrooptim_no_double_jit(optim_class, args, kwargs):
+def test_numpyrooptim_no_double_jit(optim_class, args, kwargs, uses_value_arg):
     opt = optim_class(*args, **kwargs)
     if not isinstance(opt, optim._NumPyroOptim):
         opt = optim.optax_to_numpyro(opt)
@@ -99,11 +113,18 @@ def my_fn(state, g):
         nonlocal my_fn_calls
         my_fn_calls += 1
 
-        state = opt.update(g, state)
+        if opt.update_with_value:
+            state = opt.update(g, state, value=0.01)
+        else:
+            state = opt.update(g, state)
         return state
 
     state = my_fn(state, jnp.ones(10) * 1.0)
     state = my_fn(state, jnp.ones(10) * 2.0)
     state = my_fn(state, jnp.ones(10) * 3.0)
 
-    assert my_fn_calls == 1
+    if uses_value_arg:
+        # Dtype is different on the first call vs the rest of the calls
+        assert my_fn_calls == 2
+    else:
+        assert my_fn_calls == 1
