@@ -29,6 +29,19 @@ def test_config(self):
         optimizer = LossScaleOptimizer(inner_optimizer)
         self.run_class_serialization_test(optimizer)
 
+    def test_apply_with_no_vars(self):
+        self._skip_test_for_stateless(False)
+
+        inner_optimizer = SGD(learning_rate=0.5)
+        optimizer = LossScaleOptimizer(inner_optimizer)
+        grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale]
+        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
+        optimizer.build(vars)
+        optimizer.apply(grads)
+        self.assertAllClose(
+            vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
+        )
+
     @parameterized.named_parameters(("stateless", True), ("stateful", False))
     def test_finite_step(self, stateless):
         self._skip_test_for_stateless(stateless)
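The expected values in the new test_apply_with_no_vars can be checked by hand: the gradients are pre-multiplied by optimizer.initial_scale, the loss scale optimizer divides them back down before delegating to the inner SGD(learning_rate=0.5), so the update is vars - 0.5 * [1, 6, 7, 2]. A minimal NumPy sketch of that arithmetic (the concrete scale value below is an assumption; the test reads it from the optimizer):

import numpy as np

initial_scale = 32768.0  # assumed value; the test uses optimizer.initial_scale
grads = np.array([1.0, 6.0, 7.0, 2.0]) * initial_scale
variables = np.array([1.0, 2.0, 3.0, 4.0])
unscaled = grads / initial_scale           # back to [1.0, 6.0, 7.0, 2.0]
print(variables - 0.5 * unscaled)          # [ 0.5 -1.  -0.5  3. ], matching assertAllClose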
@@ -40,7 +53,9 @@ def test_finite_step(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
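The remaining hunks all make the same change to the stateless path: stateless_apply now receives plain tensor values (via v.value) for both the optimizer state and the trainable variables, rather than the Variable objects themselves. A minimal sketch of that calling pattern, assuming the public Keras 3 APIs (keras.Variable, keras.ops.array, SGD, LossScaleOptimizer) behave like the backend/ops/SGD/LossScaleOptimizer names this test file imports, and that the current backend supports the stateless path (which is what _skip_test_for_stateless guards):

import keras
from keras import ops
from keras.optimizers import SGD, LossScaleOptimizer

optimizer = LossScaleOptimizer(SGD(learning_rate=0.5))
train_vars = [keras.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.build(train_vars)

grads = [ops.array([1.0, 6.0, 7.0, 2.0])]
# Pass plain values, not Variable objects, into the stateless API ...
new_values, new_opt_state = optimizer.stateless_apply(
    [v.value for v in optimizer.variables], grads, [v.value for v in train_vars]
)
# ... then write the returned optimizer state back onto the Variables,
# mirroring what the stateless branches of these tests do.
for ref_v, value in zip(optimizer.variables, new_opt_state):
    ref_v.assign(value)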
@@ -60,7 +75,9 @@ def test_finite_step_with_inner_loss_scale(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
            )
         else:
             optimizer.apply(grads, vars)
@@ -79,7 +96,9 @@ def test_infinite_step(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
@@ -98,7 +117,9 @@ def test_finite_step_with_overwrite(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
@@ -112,12 +133,14 @@ def test_downscaling(self, stateless):
         optimizer = LossScaleOptimizer(inner_optimizer, initial_scale=400.0)
         vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
         optimizer.build(vars)
-        opt_vars = optimizer.variables
+        opt_var_values = [v.value for v in optimizer.variables]
         grads = [ops.array([np.inf, np.inf, np.inf, np.inf])]
         for _ in range(4):
             if stateless:
-                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
-                for ref_v, v in zip(optimizer.variables, opt_vars):
+                _, opt_var_values = optimizer.stateless_apply(
+                    opt_var_values, grads, [v.value for v in vars]
+                )
+                for ref_v, v in zip(optimizer.variables, opt_var_values):
                     ref_v.assign(v)
             else:
                 optimizer.apply(grads, vars)
@@ -135,12 +158,14 @@ def test_upscaling(self, stateless):
         )
         vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
         optimizer.build(vars)
-        opt_vars = optimizer.variables
+        opt_var_values = [v.value for v in optimizer.variables]
         grads = [ops.array([1.0, 6.0, 7.0, 2.0])]
         for _ in range(8):
             if stateless:
-                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
-                for ref_v, v in zip(optimizer.variables, opt_vars):
+                _, opt_var_values = optimizer.stateless_apply(
+                    opt_var_values, grads, [v.value for v in vars]
+                )
+                for ref_v, v in zip(optimizer.variables, opt_var_values):
                     ref_v.assign(v)
             else:
                 optimizer.apply(grads, vars)
@@ -154,16 +179,104 @@ def test_iterations_update(self, stateless):
         optimizer = LossScaleOptimizer(inner_optimizer)
         vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
         optimizer.build(vars)
-        opt_vars = optimizer.variables
+        opt_var_values = [v.value for v in optimizer.variables]
         grads = [ops.array([1.0, 6.0, 7.0, 2.0])]
 
         self.assertEqual(optimizer.iterations.value, 0)
 
         for i in range(3):
             if stateless:
-                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
-                for ref_v, v in zip(optimizer.variables, opt_vars):
+                _, opt_var_values = optimizer.stateless_apply(
+                    opt_var_values, grads, [v.value for v in vars]
+                )
+                for ref_v, v in zip(optimizer.variables, opt_var_values):
                     ref_v.assign(v)
             else:
                 optimizer.apply(grads, vars)
             self.assertEqual(optimizer.iterations.value, i + 1)
+
+    def test_serialization(self):
+        inner_optimizer = SGD(learning_rate=0.5)
+        optimizer = LossScaleOptimizer(
+            inner_optimizer,
+            initial_scale=3.0,
+            dynamic_growth_steps=2,
+            name="test_opt",
+        )
+        config = optimizer.get_config()
+        self.assertLen(config, 4)
+        self.assertEqual(config["name"], "test_opt")
+        self.assertEqual(config["initial_scale"], 3.0)
+        self.assertEqual(config["dynamic_growth_steps"], 2)
+        self.assertIn("inner_optimizer", config)
+        LossScaleOptimizer.from_config(config)
+
+    def test_init_dynamic_arg(self):
+        inner_optimizer = SGD(learning_rate=0.5)
+
+        # dynamic=True is supported
+        LossScaleOptimizer(inner_optimizer, dynamic=True)
+
+        # dynamic=False is not supported
+        with self.assertRaisesRegex(ValueError, "set `loss_scale_factor`"):
+            LossScaleOptimizer(inner_optimizer, dynamic=False)
+
+    def test_init_unsupported_arg(self):
+        inner_optimizer = SGD(learning_rate=0.5)
+        with self.assertRaisesRegex(ValueError, "arguments: `foo`, `bar`"):
+            LossScaleOptimizer(inner_optimizer, foo=True, bar=3)
+
+    @parameterized.named_parameters(
+        ("weight_decay", "weight_decay", 0.5),
+        ("clipnorm", "clipnorm", 0.5),
+        ("global_clipnorm", "global_clipnorm", 0.5),
+        ("clipvalue", "clipvalue", 0.5),
+        ("use_ema", "use_ema", True),
+        ("ema_momentum", "ema_momentum", 0.5),
+        ("ema_overwrite_frequency", "ema_overwrite_frequency", 2),
+        ("loss_scale_factor", "loss_scale_factor", 0.5),
+        ("gradient_accumulation_steps", "gradient_accumulation_steps", 2),
+    )
+    def test_init_base_optimizer_unsupported_args(self, arg_name, arg_value):
+        inner_optimizer = SGD(learning_rate=0.5)
+        with self.assertRaisesRegex(ValueError, "on the `inner_optimizer`"):
+            LossScaleOptimizer(inner_optimizer, **{arg_name: arg_value})
+
+    def test_deserialization_backwards_compatibility(self):
+        # Test deserializing with a config that has all the unsupported
+        # arguments from the base optimizer (which are no longer serialized)
+        config = {
+            "name": "loss_scale_optimizer",
+            "weight_decay": None,
+            "clipnorm": None,
+            "global_clipnorm": None,
+            "clipvalue": None,
+            "use_ema": False,
+            "ema_momentum": 0.99,
+            "ema_overwrite_frequency": None,
+            "loss_scale_factor": None,
+            "gradient_accumulation_steps": None,
+            "inner_optimizer": {
+                "module": "keras.optimizers",
+                "class_name": "SGD",
+                "config": {
+                    "name": "SGD",
+                    "learning_rate": 0.5,
+                    "weight_decay": None,
+                    "clipnorm": None,
+                    "global_clipnorm": None,
+                    "clipvalue": None,
+                    "use_ema": False,
+                    "ema_momentum": 0.99,
+                    "ema_overwrite_frequency": None,
+                    "loss_scale_factor": None,
+                    "gradient_accumulation_steps": None,
+                    "momentum": 0.0,
+                    "nesterov": False,
+                },
+                "registered_name": None,
+            },
+            "initial_scale": 2.0,
+            "dynamic_growth_steps": 2,
+        }
+        LossScaleOptimizer.from_config(config)