Commit f279e93

Bug fixes with variable handling in LossScaleOptimizer. (#21706)
`overwrite_with_gradient` was ineffective on JAX in real-world conditions, i.e. within `model.fit`. In the training loop, `stateless_apply` is passed `trainable_variables` as arrays containing the values of the trainable variables, not the variables themselves, so the per-variable flag has to be read from the optimizer's own variables (`self._trainable_variables`) instead.

`apply(grads)` without the `trainable_variables` argument would not apply anything, because that code path uses `self._trainable_variables`, which was empty for `LossScaleOptimizer`. This is fixed by calling `super().build(...)` in `LossScaleOptimizer.build`.

Also, `LossScaleOptimizer.__init__` now fails when other arguments from the base optimizer are passed, since they were never actually supported. They are also no longer returned by `get_config`.
Parent: a62a4a3
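For illustration (not part of the commit itself), below is a minimal sketch of the stateful code path the `apply(grads)` fix addresses, mirroring the new `test_apply_with_no_vars` test but using the public `keras.Variable` / `keras.ops` API instead of the test's internal imports; the `SGD` inner optimizer and the concrete numbers are arbitrary choices.

from keras import Variable, ops
from keras.optimizers import SGD, LossScaleOptimizer

vars = [Variable([1.0, 2.0, 3.0, 4.0])]
optimizer = LossScaleOptimizer(SGD(learning_rate=0.5))

# `build()` now also calls `super().build(var_list)`, which records the
# trainable variables on the optimizer itself.
optimizer.build(vars)

# Gradients pre-multiplied by the loss scale, as the training loop produces them.
grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale]

# With the fix, `apply(grads)` (no `trainable_variables` argument) updates the
# variables recorded at build time: roughly [0.5, -1.0, -0.5, 3.0] here.
optimizer.apply(grads)

Before this commit, the same `apply(grads)` call was a silent no-op, because `self._trainable_variables` was never populated.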

File tree: 3 files changed, +186 −26 lines

- keras/src/optimizers/base_optimizer.py
- keras/src/optimizers/loss_scale_optimizer.py
- keras/src/optimizers/loss_scale_optimizer_test.py
keras/src/optimizers/base_optimizer.py

Lines changed: 14 additions & 0 deletions
@@ -631,6 +631,20 @@ def _backend_increment_gradient_accumulators(self, grads, acc_grads):
             g_acc.assign(n_g_acc)

     def stateless_apply(self, optimizer_variables, grads, trainable_variables):
+        """Stateless version of `apply` that returns modified variables.
+
+        Args:
+            optimizer_variables: list of tensors containing the current values
+                for the optimizer variables. These are native tensors and not
+                `keras.Variable`s.
+            grads: list of gradients to apply.
+            trainable_variables: list of tensors containing the current values
+                for the model variables. These are native tensors and not
+                `keras.Variable`s.
+
+        Returns: A tuple containing two list of tensors, the updated
+            `trainable_variables` and the updated `optimizer_variables`.
+        """
         self._check_super_called()

         if not self.built:
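As a usage illustration of the contract this new docstring describes (a sketch, not part of the diff), the snippet below follows the pattern used in the updated tests: only plain tensor values cross the `stateless_apply` boundary, and the caller assigns the returned values back onto the reference `keras.Variable` objects. The `SGD` inner optimizer and the values are arbitrary; the stateless path is the one `model.fit` exercises on the JAX backend.

from keras import Variable, ops
from keras.optimizers import SGD, LossScaleOptimizer

vars = [Variable([1.0, 2.0, 3.0, 4.0])]
optimizer = LossScaleOptimizer(SGD(learning_rate=0.5))
optimizer.build(vars)

grads = [ops.array([1.0, 6.0, 7.0, 2.0])]

# Pass current *values* (`v.value`), not the `keras.Variable` objects.
new_var_values, new_opt_values = optimizer.stateless_apply(
    [v.value for v in optimizer.variables],
    grads,
    [v.value for v in vars],
)

# The call is purely functional: the caller writes the results back.
for ref_v, value in zip(vars, new_var_values):
    ref_v.assign(value)
for ref_v, value in zip(optimizer.variables, new_opt_values):
    ref_v.assign(value)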

keras/src/optimizers/loss_scale_optimizer.py

Lines changed: 46 additions & 13 deletions
@@ -48,6 +48,7 @@ def __init__(
         inner_optimizer,
         initial_scale=2.0**15,
         dynamic_growth_steps=2000,
+        name=None,
         **kwargs,
     ):
         if not kwargs.pop("dynamic", True):
@@ -56,7 +57,42 @@ def __init__(
                 "Instead, simply set `loss_scale_factor` directly on the "
                 "`inner_optimizer`."
             )
-        super().__init__(learning_rate=0.0, **kwargs)
+
+        # Backwards compatibility code for deserialization.
+        # LossScaleOptimizer used to return all these parameters in `get_config`
+        # from `super.get_config` even though they are all non-functional. We
+        # no longer let user set them, but we have to allow the default values
+        # to be passed during deserialization to support older models.
+        base_optimizer_defaults = {
+            "weight_decay": None,
+            "clipnorm": None,
+            "global_clipnorm": None,
+            "clipvalue": None,
+            "use_ema": False,
+            "ema_momentum": 0.99,
+            "ema_overwrite_frequency": None,
+            "loss_scale_factor": None,
+            "gradient_accumulation_steps": None,
+        }
+        for arg_name, default_value in base_optimizer_defaults.items():
+            if arg_name not in kwargs:
+                continue
+            arg_value = kwargs.pop(arg_name)
+            if (
+                default_value is None and arg_value is not None
+            ) or arg_value != default_value:
+                raise ValueError(
+                    f"LossScaleOptimizer does not support `{arg_name}`. "
+                    f"Instead, set `{arg_name}` on the `inner_optimizer`."
+                )
+
+        if kwargs:
+            raise ValueError(
+                "LossScaleOptimizer does not support arguments: "
+                f"`{'`, `'.join(kwargs.keys())}`."
+            )
+
+        super().__init__(learning_rate=0.0, name=name)
         self.inner_optimizer = inner_optimizer
         self.initial_scale = initial_scale
         self.dynamic_growth_steps = dynamic_growth_steps
@@ -81,7 +117,7 @@ def build(self, var_list):
             name="dynamic_scale",
         )
         self.inner_optimizer.build(var_list)
-        self.built = True
+        super().build(var_list)

     @property
     def variables(self):
@@ -136,7 +172,7 @@ def increment():
             g
             if g is None or self._overwrite_variable_with_gradient(v)
             else ops.divide(g, scale)
-            for g, v in zip(grads, trainable_variables)
+            for g, v in zip(grads, self._trainable_variables)
         ]
         (
             new_trainable_variables,
@@ -284,19 +320,16 @@ def finalize_variable_values(self, var_list):
         self.inner_optimizer.finalize_variable_values(var_list)

     def get_config(self):
-        config = super().get_config()
+        # Do not use super().get_config() as only "name" is supported.
         inner_optimizer_config = serialization_lib.serialize_keras_object(
             self.inner_optimizer
         )
-        config.update(
-            {
-                "inner_optimizer": inner_optimizer_config,
-                "initial_scale": self.initial_scale,
-                "dynamic_growth_steps": self.dynamic_growth_steps,
-            }
-        )
-        del config["learning_rate"]
-        return config
+        return {
+            "name": self.name,
+            "inner_optimizer": inner_optimizer_config,
+            "initial_scale": self.initial_scale,
+            "dynamic_growth_steps": self.dynamic_growth_steps,
+        }

     @classmethod
     def from_config(cls, config, custom_objects=None):
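To make the behavior change concrete, here is a hedged sketch (not part of the diff) of the new `__init__` validation and the slimmed-down `get_config`; the `clipnorm` value is an arbitrary example of a base-optimizer option.

from keras.optimizers import SGD, LossScaleOptimizer

# Supported: base-optimizer options such as clipping belong on the inner optimizer.
optimizer = LossScaleOptimizer(SGD(learning_rate=0.5, clipnorm=1.0))

# No longer silently accepted: the same option on the wrapper now raises.
try:
    LossScaleOptimizer(SGD(learning_rate=0.5), clipnorm=1.0)
except ValueError as err:
    print(err)  # LossScaleOptimizer does not support `clipnorm`. ...

# `get_config()` now contains only the four supported keys.
print(sorted(optimizer.get_config().keys()))
# ['dynamic_growth_steps', 'initial_scale', 'inner_optimizer', 'name']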

keras/src/optimizers/loss_scale_optimizer_test.py

Lines changed: 126 additions & 13 deletions
@@ -29,6 +29,19 @@ def test_config(self):
         optimizer = LossScaleOptimizer(inner_optimizer)
         self.run_class_serialization_test(optimizer)

+    def test_apply_with_no_vars(self):
+        self._skip_test_for_stateless(False)
+
+        inner_optimizer = SGD(learning_rate=0.5)
+        optimizer = LossScaleOptimizer(inner_optimizer)
+        grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale]
+        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
+        optimizer.build(vars)
+        optimizer.apply(grads)
+        self.assertAllClose(
+            vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
+        )
+
     @parameterized.named_parameters(("stateless", True), ("stateful", False))
     def test_finite_step(self, stateless):
         self._skip_test_for_stateless(stateless)
@@ -40,7 +53,9 @@ def test_finite_step(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
@@ -60,7 +75,9 @@ def test_finite_step_with_inner_loss_scale(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
@@ -79,7 +96,9 @@ def test_infinite_step(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
            )
         else:
             optimizer.apply(grads, vars)
@@ -98,7 +117,9 @@ def test_finite_step_with_overwrite(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
@@ -112,12 +133,14 @@ def test_downscaling(self, stateless):
         optimizer = LossScaleOptimizer(inner_optimizer, initial_scale=400.0)
         vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
         optimizer.build(vars)
-        opt_vars = optimizer.variables
+        opt_var_values = [v.value for v in optimizer.variables]
         grads = [ops.array([np.inf, np.inf, np.inf, np.inf])]
         for _ in range(4):
             if stateless:
-                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
-                for ref_v, v in zip(optimizer.variables, opt_vars):
+                _, opt_var_values = optimizer.stateless_apply(
+                    opt_var_values, grads, [v.value for v in vars]
+                )
+                for ref_v, v in zip(optimizer.variables, opt_var_values):
                     ref_v.assign(v)
             else:
                 optimizer.apply(grads, vars)
@@ -135,12 +158,14 @@ def test_upscaling(self, stateless):
         )
         vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
         optimizer.build(vars)
-        opt_vars = optimizer.variables
+        opt_var_values = [v.value for v in optimizer.variables]
         grads = [ops.array([1.0, 6.0, 7.0, 2.0])]
         for _ in range(8):
             if stateless:
-                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
-                for ref_v, v in zip(optimizer.variables, opt_vars):
+                _, opt_var_values = optimizer.stateless_apply(
+                    opt_var_values, grads, [v.value for v in vars]
+                )
+                for ref_v, v in zip(optimizer.variables, opt_var_values):
                     ref_v.assign(v)
             else:
                 optimizer.apply(grads, vars)
@@ -154,16 +179,104 @@ def test_iterations_update(self, stateless):
         optimizer = LossScaleOptimizer(inner_optimizer)
         vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
         optimizer.build(vars)
-        opt_vars = optimizer.variables
+        opt_var_values = [v.value for v in optimizer.variables]
         grads = [ops.array([1.0, 6.0, 7.0, 2.0])]

         self.assertEqual(optimizer.iterations.value, 0)

         for i in range(3):
             if stateless:
-                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
-                for ref_v, v in zip(optimizer.variables, opt_vars):
+                _, opt_var_values = optimizer.stateless_apply(
+                    opt_var_values, grads, [v.value for v in vars]
+                )
+                for ref_v, v in zip(optimizer.variables, opt_var_values):
                     ref_v.assign(v)
             else:
                 optimizer.apply(grads, vars)
             self.assertEqual(optimizer.iterations.value, i + 1)
+
+    def test_serialization(self):
+        inner_optimizer = SGD(learning_rate=0.5)
+        optimizer = LossScaleOptimizer(
+            inner_optimizer,
+            initial_scale=3.0,
+            dynamic_growth_steps=2,
+            name="test_opt",
+        )
+        config = optimizer.get_config()
+        self.assertLen(config, 4)
+        self.assertEqual(config["name"], "test_opt")
+        self.assertEqual(config["initial_scale"], 3.0)
+        self.assertEqual(config["dynamic_growth_steps"], 2)
+        self.assertIn("inner_optimizer", config)
+        LossScaleOptimizer.from_config(config)
+
+    def test_init_dynamic_arg(self):
+        inner_optimizer = SGD(learning_rate=0.5)
+
+        # dynamic=True is supported
+        LossScaleOptimizer(inner_optimizer, dynamic=True)
+
+        # dynamic=False is not supported
+        with self.assertRaisesRegex(ValueError, "set `loss_scale_factor`"):
+            LossScaleOptimizer(inner_optimizer, dynamic=False)
+
+    def test_init_unsupported_arg(self):
+        inner_optimizer = SGD(learning_rate=0.5)
+        with self.assertRaisesRegex(ValueError, "arguments: `foo`, `bar`"):
+            LossScaleOptimizer(inner_optimizer, foo=True, bar=3)
+
+    @parameterized.named_parameters(
+        ("weight_decay", "weight_decay", 0.5),
+        ("clipnorm", "clipnorm", 0.5),
+        ("global_clipnorm", "global_clipnorm", 0.5),
+        ("clipvalue", "clipvalue", 0.5),
+        ("use_ema", "use_ema", True),
+        ("ema_momentum", "ema_momentum", 0.5),
+        ("ema_overwrite_frequency", "ema_overwrite_frequency", 2),
+        ("loss_scale_factor", "loss_scale_factor", 0.5),
+        ("gradient_accumulation_steps", "gradient_accumulation_steps", 2),
+    )
+    def test_init_base_optimizer_unsupported_args(self, arg_name, arg_value):
+        inner_optimizer = SGD(learning_rate=0.5)
+        with self.assertRaisesRegex(ValueError, "on the `inner_optimizer`"):
+            LossScaleOptimizer(inner_optimizer, **{arg_name: arg_value})
+
+    def test_deserialization_backwards_compatibility(self):
+        # Test deserializing with a config that has all the unsupported
+        # arguments from the base optimizer (which are no longer serialized)
+        config = {
+            "name": "loss_scale_optimizer",
+            "weight_decay": None,
+            "clipnorm": None,
+            "global_clipnorm": None,
+            "clipvalue": None,
+            "use_ema": False,
+            "ema_momentum": 0.99,
+            "ema_overwrite_frequency": None,
+            "loss_scale_factor": None,
+            "gradient_accumulation_steps": None,
+            "inner_optimizer": {
+                "module": "keras.optimizers",
+                "class_name": "SGD",
+                "config": {
+                    "name": "SGD",
+                    "learning_rate": 0.5,
+                    "weight_decay": None,
+                    "clipnorm": None,
+                    "global_clipnorm": None,
+                    "clipvalue": None,
+                    "use_ema": False,
+                    "ema_momentum": 0.99,
+                    "ema_overwrite_frequency": None,
+                    "loss_scale_factor": None,
+                    "gradient_accumulation_steps": None,
+                    "momentum": 0.0,
+                    "nesterov": False,
+                },
+                "registered_name": None,
+            },
+            "initial_scale": 2.0,
+            "dynamic_growth_steps": 2,
+        }
+        LossScaleOptimizer.from_config(config)
