|
8 | 8 | from keras_core import operations as ops |
9 | 9 | from keras_core.api_export import keras_core_export |
10 | 10 | from keras_core.optimizers.schedules import learning_rate_schedule |
| 11 | +from keras_core.saving import serialization_lib |
11 | 12 | from keras_core.utils.naming import auto_name |
12 | 13 | from keras_core.utils.tracking import Tracker |
13 | 14 |
|
@@ -329,34 +330,27 @@ def _filter_empty_gradients(self, grads_and_vars): |
329 | 330 |
|
330 | 331 | def _clip_gradients(self, grads): |
331 | 332 | if self.clipnorm and self.clipnorm > 0: |
332 | | - raise NotImplementedError # TODO |
333 | | - # clipped_grads = [] |
334 | | - # for g in grads: |
335 | | - # if g is None: |
336 | | - # clipped_grads.append(g) |
337 | | - # else: |
338 | | - # clipped_grads.append(tf.clip_by_norm(g, self.clipnorm)) |
339 | | - # return clipped_grads |
| 333 | + clipped_grads = [] |
| 334 | + for g in grads: |
| 335 | + if g is None: |
| 336 | + clipped_grads.append(g) |
| 337 | + else: |
| 338 | + clipped_grads.append(clip_by_norm(g, self.clipnorm)) |
| 339 | + return clipped_grads |
340 | 340 |
|
341 | 341 | if self.global_clipnorm and self.global_clipnorm > 0: |
342 | | - raise NotImplementedError # TODO |
343 | | - # return tf.clip_by_global_norm(grads, self.global_clipnorm)[0] |
| 342 | + return clip_by_global_norm(grads, self.global_clipnorm) |
344 | 343 |
|
345 | 344 | if self.clipvalue and self.clipvalue > 0: |
346 | | - raise NotImplementedError # TODO |
347 | | - # clipped_grads = [] |
348 | | - # for g in grads: |
349 | | - # if g is None: |
350 | | - # clipped_grads.append(g) |
351 | | - # else: |
352 | | - # clipped_grads.append( |
353 | | - # tf.clip_by_value( |
354 | | - # g, |
355 | | - # clip_value_min=-self.clipvalue, |
356 | | - # clip_value_max=self.clipvalue, |
357 | | - # ) |
358 | | - # ) |
359 | | - # return clipped_grads |
| 345 | + clipped_grads = [] |
| 346 | + for g in grads: |
| 347 | + if g is None: |
| 348 | + clipped_grads.append(g) |
| 349 | + else: |
| 350 | + clipped_grads.append( |
| 351 | + ops.clip(g, -self.clipvalue, self.clipvalue) |
| 352 | + ) |
| 353 | + return clipped_grads |
360 | 354 | return grads |
361 | 355 |
|
362 | 356 | def exclude_from_weight_decay(self, var_list=None, var_names=None): |
@@ -491,8 +485,9 @@ def get_config(self): |
491 | 485 | elif ops.is_tensor(self._learning_rate): |
492 | 486 | learning_rate = float(self._learning_rate) |
493 | 487 | elif callable(self._learning_rate): |
494 | | - # TODO: serialize custom object |
495 | | - learning_rate = self._learning_rate |
| 488 | + learning_rate = serialization_lib.serialize_keras_object( |
| 489 | + self._learning_rate |
| 490 | + ) |
496 | 491 |
|
497 | 492 | config = { |
498 | 493 | "name": self.name, |
@@ -524,7 +519,9 @@ def from_config(cls, config, custom_objects=None): |
524 | 519 | """ |
525 | 520 | if "learning_rate" in config: |
526 | 521 | if isinstance(config["learning_rate"], dict): |
527 | | - config["learning_rate"] = learning_rate_schedule.deserialize( |
| 522 | + config[ |
| 523 | + "learning_rate" |
| 524 | + ] = serialization_lib.deserialize_keras_object( |
528 | 525 | config["learning_rate"], custom_objects=custom_objects |
529 | 526 | ) |
530 | 527 | return cls(**config) |
@@ -561,3 +558,41 @@ def from_config(cls, config, custom_objects=None): |
561 | 558 | variables in-place). When using the built-in `fit()` training loop, |
562 | 559 | this happens automatically after the last epoch, |
563 | 560 | and you don't need to do anything.""" |
| 561 | + |
| 562 | + |
def clip_by_norm(values, clip_norm, axes=None):
    """Clip `values` so that its L2 norm over `axes` is at most `clip_norm`.

    Equivalent to `values * min(1, clip_norm / l2norm(values))`, written so
    that no division by a norm smaller than `clip_norm` ever happens.
    """
    squared_sum = ops.sum(values * values, axes, keepdims=True)
    nonzero = squared_sum > 0
    # Two-tap `where` trick: never evaluate sqrt (and its gradient) at 0,
    # which would produce NaN gradients.
    safe_sum = ops.where(nonzero, squared_sum, ops.ones_like(squared_sum))
    norm = ops.where(nonzero, ops.sqrt(safe_sum), squared_sum)
    return (values * clip_norm) / ops.maximum(norm, clip_norm)
| 573 | + |
| 574 | + |
def global_norm(value_list):
    """Computes the global norm of multiple tensors.

    The global norm is the square root of the sum of squares of all
    elements across every non-`None` tensor in `value_list`.
    """
    per_tensor = [
        ops.sum(ops.square(v)) for v in value_list if v is not None
    ]
    return ops.sqrt(ops.sum(ops.stack(per_tensor)))
| 583 | + |
| 584 | + |
def clip_by_global_norm(value_list, clip_norm):
    """Rescale all tensors in `value_list` so their joint (global) L2 norm
    does not exceed `clip_norm`. `None` entries are preserved as `None`.
    """
    norm = global_norm(value_list)
    # For a finite norm this equals min(clip_norm / norm, 1).
    finite_scale = clip_norm * ops.minimum(1.0 / norm, 1.0 / clip_norm)
    # Adding (norm - norm) is a no-op for finite norms, but deliberately
    # poisons the scale with NaN when the norm is inf/-inf/NaN.
    scale = finite_scale + (norm - norm)
    return [None if v is None else v * scale for v in value_list]
0 commit comments