
Commit 401d536

Add value/norm/global norm clipping in optimizers.
1 parent 35351c4 commit 401d536

File tree

4 files changed: +86 -69 lines changed


keras_core/optimizers/adamw_test.py

Lines changed: 8 additions & 14 deletions
@@ -74,19 +74,13 @@ def test_correctness_with_golden(self):
         optimizer.apply_gradients(zip([grads], [x]))
 
     def test_clip_norm(self):
-        # TODO: implement clip_gradients, then uncomment
-        pass
-
-        # optimizer = AdamW(clipnorm=1)
-        # grad = [np.array([100.0, 100.0])]
-        # clipped_grad = optimizer._clip_gradients(grad)
-        # self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])
+        optimizer = AdamW(clipnorm=1)
+        grad = [np.array([100.0, 100.0])]
+        clipped_grad = optimizer._clip_gradients(grad)
+        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])
 
     def test_clip_value(self):
-        # TODO: implement clip_gradients, then uncomment
-        pass
-
-        # optimizer = AdamW(clipvalue=1)
-        # grad = [np.array([100.0, 100.0])]
-        # clipped_grad = optimizer._clip_gradients(grad)
-        # self.assertAllClose(clipped_grad[0], [1.0, 1.0])
+        optimizer = AdamW(clipvalue=1)
+        grad = [np.array([100.0, 100.0])]
+        clipped_grad = optimizer._clip_gradients(grad)
+        self.assertAllClose(clipped_grad[0], [1.0, 1.0])
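
As a sanity check on the values these tests expect: the gradient [100.0, 100.0] has L2 norm 100 * sqrt(2) ≈ 141.42, so rescaling it to norm 1 gives [sqrt(2)/2, sqrt(2)/2] ≈ [0.7071, 0.7071], while value clipping simply clamps each element to the range [-1, 1]. Below is a minimal NumPy sketch of that arithmetic; the *_ref helpers are illustrative stand-ins, not keras_core APIs.

import numpy as np


def clip_by_norm_ref(g, clip_norm):
    # Rescale the whole tensor so that its L2 norm is at most clip_norm.
    norm = np.sqrt(np.sum(g * g))
    return g * clip_norm / max(norm, clip_norm)


def clip_by_value_ref(g, clip_value):
    # Clamp every element to [-clip_value, clip_value].
    return np.clip(g, -clip_value, clip_value)


grad = np.array([100.0, 100.0])
print(clip_by_norm_ref(grad, 1.0))   # [0.70710678 0.70710678], i.e. [2**0.5 / 2, 2**0.5 / 2]
print(clip_by_value_ref(grad, 1.0))  # [1. 1.]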

keras_core/optimizers/optimizer.py

Lines changed: 62 additions & 27 deletions
@@ -8,6 +8,7 @@
 from keras_core import operations as ops
 from keras_core.api_export import keras_core_export
 from keras_core.optimizers.schedules import learning_rate_schedule
+from keras_core.saving import serialization_lib
 from keras_core.utils.naming import auto_name
 from keras_core.utils.tracking import Tracker
 
@@ -329,34 +330,27 @@ def _filter_empty_gradients(self, grads_and_vars):
 
     def _clip_gradients(self, grads):
         if self.clipnorm and self.clipnorm > 0:
-            raise NotImplementedError  # TODO
-            # clipped_grads = []
-            # for g in grads:
-            #     if g is None:
-            #         clipped_grads.append(g)
-            #     else:
-            #         clipped_grads.append(tf.clip_by_norm(g, self.clipnorm))
-            # return clipped_grads
+            clipped_grads = []
+            for g in grads:
+                if g is None:
+                    clipped_grads.append(g)
+                else:
+                    clipped_grads.append(clip_by_norm(g, self.clipnorm))
+            return clipped_grads
 
         if self.global_clipnorm and self.global_clipnorm > 0:
-            raise NotImplementedError  # TODO
-            # return tf.clip_by_global_norm(grads, self.global_clipnorm)[0]
+            return clip_by_global_norm(grads, self.global_clipnorm)
 
         if self.clipvalue and self.clipvalue > 0:
-            raise NotImplementedError  # TODO
-            # clipped_grads = []
-            # for g in grads:
-            #     if g is None:
-            #         clipped_grads.append(g)
-            #     else:
-            #         clipped_grads.append(
-            #             tf.clip_by_value(
-            #                 g,
-            #                 clip_value_min=-self.clipvalue,
-            #                 clip_value_max=self.clipvalue,
-            #             )
-            #         )
-            # return clipped_grads
+            clipped_grads = []
+            for g in grads:
+                if g is None:
+                    clipped_grads.append(g)
+                else:
+                    clipped_grads.append(
+                        ops.clip(g, -self.clipvalue, self.clipvalue)
+                    )
+            return clipped_grads
         return grads
 
     def exclude_from_weight_decay(self, var_list=None, var_names=None):
@@ -491,8 +485,9 @@ def get_config(self):
         elif ops.is_tensor(self._learning_rate):
             learning_rate = float(self._learning_rate)
         elif callable(self._learning_rate):
-            # TODO: serialize custom object
-            learning_rate = self._learning_rate
+            learning_rate = serialization_lib.serialize_keras_object(
+                self._learning_rate
+            )
 
         config = {
             "name": self.name,
@@ -524,7 +519,9 @@ def from_config(cls, config, custom_objects=None):
         """
         if "learning_rate" in config:
             if isinstance(config["learning_rate"], dict):
-                config["learning_rate"] = learning_rate_schedule.deserialize(
+                config[
+                    "learning_rate"
+                ] = serialization_lib.deserialize_keras_object(
                     config["learning_rate"], custom_objects=custom_objects
                 )
         return cls(**config)
@@ -561,3 +558,41 @@ def from_config(cls, config, custom_objects=None):
     variables in-place). When using the built-in `fit()` training loop,
     this happens automatically after the last epoch,
     and you don't need to do anything."""
+
+
+def clip_by_norm(values, clip_norm, axes=None):
+    # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
+    l2sum = ops.sum(values * values, axes, keepdims=True)
+    pred = l2sum > 0
+    # Two-tap tf.where trick to bypass NaN gradients
+    l2sum_safe = ops.where(pred, l2sum, ops.ones_like(l2sum))
+    l2norm = ops.where(pred, ops.sqrt(l2sum_safe), l2sum)
+    intermediate = values * clip_norm
+    values_clip = intermediate / ops.maximum(l2norm, clip_norm)
+    return values_clip
+
+
+def global_norm(value_list):
+    """Computes the global norm of multiple tensors."""
+    squared_norms = []
+    for v in value_list:
+        if v is not None:
+            squared_norms.append(ops.sum(ops.square(v)))
+    squared_norm = ops.sum(ops.stack(squared_norms))
+    return ops.sqrt(squared_norm)
+
+
+def clip_by_global_norm(value_list, clip_norm):
+    use_norm = global_norm(value_list)
+    # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
+    scale_for_finite = clip_norm * ops.minimum(1.0 / use_norm, 1.0 / clip_norm)
+    # If use_norm is any finite number, this is a no-op. For inf/-inf/NaN,
+    # this will make scale NaN.
+    scale = scale_for_finite + (use_norm - use_norm)
+    values_clipped = []
+    for v in value_list:
+        if v is None:
+            values_clipped.append(None)
+        else:
+            values_clipped.append(v * scale)
+    return values_clipped
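
With this change, the three clipping options on the base optimizer (clipnorm, global_clipnorm, clipvalue) are functional instead of raising NotImplementedError, and _clip_gradients consults at most one of them per call, in that order, as the early returns above show. The sketch below mirrors how the tests call the private _clip_gradients helper; the package-level import path and the printed formatting are assumptions, and the expected numbers are approximate.

import numpy as np

from keras_core.optimizers import AdamW  # import path assumed

# Per-tensor L2 norms are 5 and 12; the global norm is sqrt(5**2 + 12**2) = 13.
grads = [np.array([3.0, 4.0]), np.array([0.0, 12.0])]

# Per-tensor norm clipping: each gradient is rescaled to L2 norm <= 1.
print(AdamW(clipnorm=1.0)._clip_gradients(grads))
# ~ [0.6, 0.8] and [0.0, 1.0]

# Global norm clipping: one shared scale factor, clip_norm / global_norm = 5 / 13.
print(AdamW(global_clipnorm=5.0)._clip_gradients(grads))
# ~ [1.154, 1.538] and [0.0, 4.615]

# Value clipping: every element is clamped to [-2.0, 2.0].
print(AdamW(clipvalue=2.0)._clip_gradients(grads))
# ~ [2.0, 2.0] and [0.0, 2.0]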

keras_core/optimizers/rmsprop_test.py

Lines changed: 8 additions & 14 deletions
@@ -64,19 +64,13 @@ def test_correctness_with_golden(self):
         optimizer.apply_gradients(zip([grads], [x]))
 
     def test_clip_norm(self):
-        # TODO: implement clip_gradients, then uncomment
-        pass
-
-        # optimizer = RMSprop(clipnorm=1)
-        # grad = [np.array([100.0, 100.0])]
-        # clipped_grad = optimizer._clip_gradients(grad)
-        # self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])
+        optimizer = RMSprop(clipnorm=1)
+        grad = [np.array([100.0, 100.0])]
+        clipped_grad = optimizer._clip_gradients(grad)
+        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])
 
     def test_clip_value(self):
-        # TODO: implement clip_gradients, then uncomment
-        pass
-
-        # optimizer = RMSprop(clipvalue=1)
-        # grad = [np.array([100.0, 100.0])]
-        # clipped_grad = optimizer._clip_gradients(grad)
-        # self.assertAllClose(clipped_grad[0], [1.0, 1.0])
+        optimizer = RMSprop(clipvalue=1)
+        grad = [np.array([100.0, 100.0])]
+        clipped_grad = optimizer._clip_gradients(grad)
+        self.assertAllClose(clipped_grad[0], [1.0, 1.0])

keras_core/optimizers/sgd_test.py

Lines changed: 8 additions & 14 deletions
@@ -72,19 +72,13 @@ def test_correctness_with_golden(self):
         optimizer.apply_gradients(zip([grads], [x]))
 
     def test_clip_norm(self):
-        # TODO: implement clip_gradients, then uncomment
-        pass
-
-        # optimizer = SGD(clipnorm=1)
-        # grad = [np.array([100.0, 100.0])]
-        # clipped_grad = optimizer._clip_gradients(grad)
-        # self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])
+        optimizer = SGD(clipnorm=1)
+        grad = [np.array([100.0, 100.0])]
+        clipped_grad = optimizer._clip_gradients(grad)
+        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])
 
     def test_clip_value(self):
-        # TODO: implement clip_gradients, then uncomment
-        pass
-
-        # optimizer = SGD(clipvalue=1)
-        # grad = [np.array([100.0, 100.0])]
-        # clipped_grad = optimizer._clip_gradients(grad)
-        # self.assertAllClose(clipped_grad[0], [1.0, 1.0])
+        optimizer = SGD(clipvalue=1)
+        grad = [np.array([100.0, 100.0])]
+        clipped_grad = optimizer._clip_gradients(grad)
+        self.assertAllClose(clipped_grad[0], [1.0, 1.0])
