 try:
     import optax
+    import optax.contrib
 
     # the optimizer test is parameterized by different optax optimizers, but we have
     # to define them here to ensure that `optax` is defined. pytest.mark.parameterize
     # decorators are run even if tests are skipped at the top of the file.
     optax_optimizers = [
-        (optax.adam, (1e-2,), {}),
+        (optax.adam, (1e-2,), {}, False),
         # clipped adam
-        (optax.chain, (optax.clip(10.0), optax.adam(1e-2)), {}),
-        (optax.adagrad, (1e-1,), {}),
+        (optax.chain, (optax.clip(10.0), optax.adam(1e-2)), {}, False),
+        (optax.adagrad, (1e-1,), {}, False),
         # SGD with momentum
-        (optax.sgd, (1e-2,), {"momentum": 0.9}),
-        (optax.rmsprop, (1e-2,), {"decay": 0.95}),
+        (optax.sgd, (1e-2,), {"momentum": 0.9}, False),
+        (optax.rmsprop, (1e-2,), {"decay": 0.95}, False),
         # RMSProp with momentum
-        (optax.rmsprop, (1e-4,), {"decay": 0.9, "momentum": 0.9}),
-        (optax.sgd, (1e-2,), {}),
+        (optax.rmsprop, (1e-4,), {"decay": 0.9, "momentum": 0.9}, False),
+        (optax.sgd, (1e-2,), {}, False),
+        # reduce learning rate on plateau
+        (
+            optax.chain,
+            (
+                optax.adam(1e-2),
+                optax.contrib.reduce_on_plateau(patience=5, accumulation_size=200),
+            ),
+            {},
+            True,
+        ),
     ]
 except ImportError:
     pytestmark = pytest.mark.skip(reason="optax is not installed")
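For context, here is a minimal standalone sketch (not part of this diff) of what the new `reduce_on_plateau` entry does in plain optax: the transformation scales the chained optimizer's updates once the loss stops improving, so its `update` needs the current loss passed as `value=`. The toy parameters and `loss_fn` below are illustrative only.

```python
import jax
import jax.numpy as jnp
import optax
import optax.contrib

params = {"w": jnp.zeros(3)}

# Same chain as the new test case: Adam followed by a plateau-based LR scaler.
tx = optax.chain(
    optax.adam(1e-2),
    optax.contrib.reduce_on_plateau(patience=5, accumulation_size=200),
)
opt_state = tx.init(params)

def loss_fn(p):
    return jnp.sum((p["w"] - 1.0) ** 2)

loss, grads = jax.value_and_grad(loss_fn)(params)
# The extra `value=` keyword is what the new `uses_value_arg` flag tracks.
updates, opt_state = tx.update(grads, opt_state, params, value=loss)
params = optax.apply_updates(params, updates)
```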
@@ -41,24 +52,27 @@ def loss(params):
 def step(opt_state, optim):
     params = optim.get_params(opt_state)
     g = grad(loss)(params)
-    return optim.update(g, opt_state)
+    if optim.update_with_value:
+        return optim.update(g, opt_state, value=loss(params))
+    else:
+        return optim.update(g, opt_state)
 
 
 @pytest.mark.parametrize(
-    "optim_class, args, kwargs",
+    "optim_class, args, kwargs, uses_value_arg",
     [
-        (optim.Adam, (1e-2,), {}),
-        (optim.ClippedAdam, (1e-2,), {}),
-        (optim.Adagrad, (1e-1,), {}),
-        (optim.Momentum, (1e-2, 0.5), {}),
-        (optim.RMSProp, (1e-2, 0.95), {}),
-        (optim.RMSPropMomentum, (1e-4,), {}),
-        (optim.SGD, (1e-2,), {}),
+        (optim.Adam, (1e-2,), {}, False),
+        (optim.ClippedAdam, (1e-2,), {}, False),
+        (optim.Adagrad, (1e-1,), {}, False),
+        (optim.Momentum, (1e-2, 0.5), {}, False),
+        (optim.RMSProp, (1e-2, 0.95), {}, False),
+        (optim.RMSPropMomentum, (1e-4,), {}, False),
+        (optim.SGD, (1e-2,), {}, False),
     ]
     + optax_optimizers,
 )
 @pytest.mark.filterwarnings("ignore:.*tree_multimap:FutureWarning")
-def test_optim_multi_params(optim_class, args, kwargs):
+def test_optim_multi_params(optim_class, args, kwargs, uses_value_arg):
     params = {"x": jnp.array([1.0, 1.0, 1.0]), "y": jnp.array([-1, -1.0, -1.0])}
     opt = optim_class(*args, **kwargs)
     if not isinstance(opt, optim._NumPyroOptim):
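To make the intent of the `step` change above concrete, this sketch writes the same loop against the NumPyro wrapper, assuming (as these tests do) that the optimizer returned by `optax_to_numpyro` exposes an `update_with_value` flag and an optional `value=` keyword on `update`; the toy model and loop length are illustrative.

```python
import jax.numpy as jnp
from jax import grad
import optax
import optax.contrib
from numpyro import optim

tx = optax.chain(
    optax.adam(1e-2),
    optax.contrib.reduce_on_plateau(patience=5, accumulation_size=200),
)
opt = optim.optax_to_numpyro(tx)

params = {"x": jnp.array([1.0, 1.0, 1.0])}
state = opt.init(params)

def loss(p):
    return jnp.sum(p["x"] ** 2)

for _ in range(100):
    p = opt.get_params(state)
    g = grad(loss)(p)
    if opt.update_with_value:
        # plateau-style schedules need the current loss to decide when to decay
        state = opt.update(g, state, value=loss(p))
    else:
        state = opt.update(g, state)
```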
@@ -73,20 +87,20 @@ def test_optim_multi_params(optim_class, args, kwargs):
 # note: this is somewhat of a bruteforce test. testing directly from
 # _NumpyroOptim would probably be better
 @pytest.mark.parametrize(
-    "optim_class, args, kwargs",
+    "optim_class, args, kwargs, uses_value_arg",
     [
-        (optim.Adam, (1e-2,), {}),
-        (optim.ClippedAdam, (1e-2,), {}),
-        (optim.Adagrad, (1e-1,), {}),
-        (optim.Momentum, (1e-2, 0.5), {}),
-        (optim.RMSProp, (1e-2, 0.95), {}),
-        (optim.RMSPropMomentum, (1e-4,), {}),
-        (optim.SGD, (1e-2,), {}),
+        (optim.Adam, (1e-2,), {}, False),
+        (optim.ClippedAdam, (1e-2,), {}, False),
+        (optim.Adagrad, (1e-1,), {}, False),
+        (optim.Momentum, (1e-2, 0.5), {}, False),
+        (optim.RMSProp, (1e-2, 0.95), {}, False),
+        (optim.RMSPropMomentum, (1e-4,), {}, False),
+        (optim.SGD, (1e-2,), {}, False),
     ]
     + optax_optimizers,
 )
 @pytest.mark.filterwarnings("ignore:.*tree_multimap:FutureWarning")
-def test_numpyrooptim_no_double_jit(optim_class, args, kwargs):
+def test_numpyrooptim_no_double_jit(optim_class, args, kwargs, uses_value_arg):
     opt = optim_class(*args, **kwargs)
     if not isinstance(opt, optim._NumPyroOptim):
         opt = optim.optax_to_numpyro(opt)
@@ -99,11 +113,18 @@ def my_fn(state, g):
         nonlocal my_fn_calls
         my_fn_calls += 1
 
-        state = opt.update(g, state)
+        if opt.update_with_value:
+            state = opt.update(g, state, value=0.01)
+        else:
+            state = opt.update(g, state)
         return state
 
     state = my_fn(state, jnp.ones(10) * 1.0)
     state = my_fn(state, jnp.ones(10) * 2.0)
     state = my_fn(state, jnp.ones(10) * 3.0)
 
-    assert my_fn_calls == 1
+    if uses_value_arg:
+        # Dtype is different on the first call vs the rest of the calls
+        assert my_fn_calls == 2
+    else:
+        assert my_fn_calls == 1
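The relaxed assertion above attributes the second compilation, per the in-diff comment, to a dtype changing between the first call and later calls. As a general illustration of that jit behaviour (standalone, not part of the PR; names are made up), a change in an argument's dtype forces a retrace while repeated calls with the same signature hit the cache:

```python
import jax
import jax.numpy as jnp

calls = 0

@jax.jit
def f(x):
    global calls
    calls += 1  # only runs while tracing, i.e. once per distinct input signature
    return x * 2

f(jnp.zeros(3, dtype=jnp.int32))
f(jnp.zeros(3, dtype=jnp.float32))  # dtype changed -> new trace/compile
f(jnp.zeros(3, dtype=jnp.float32))  # same signature -> cached, no retrace
assert calls == 2
```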