@@ -20,7 +20,6 @@
 import numpy as np
 import tensorflow.compat.v1 as tf
 import tensorflow.compat.v2 as tf2
-from tensorflow.python.eager import tape as tape_lib  # pylint:disable=g-direct-tensorflow-import
 from tensorflow.python.tpu import tpu_function  # pylint:disable=g-direct-tensorflow-import
 # pylint: disable=logging-format-interpolation
 
@@ -643,83 +642,12 @@ def build_model_with_precision(pp, mm, ii, *args, **kwargs):
   return outputs
 
 
-def _recompute_grad(f):
-  """An eager-compatible version of recompute_grad.
-
-  For f(*args, **kwargs), this supports gradients with respect to args or
-  kwargs, but kwargs are currently only supported in eager-mode.
-  Note that for keras layer and model objects, this is handled automatically.
-
-  Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
-  be able to access the member variables of that object, because `g` returns
-  through the wrapper function `inner`. When recomputing gradients through
-  objects that inherit from tf2, we suggest keeping a reference to the
-  underlying object around for the purpose of accessing these variables.
-
-  Args:
-    f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs.
-
-  Returns:
-    A function `g` that wraps `f`, but which recomputes `f` on the backwards
-    pass of a gradient call.
-  """
-
-  @tf.custom_gradient
-  def inner(*args, **kwargs):
-    """Inner function closure for calculating gradients."""
-    current_var_scope = tf.get_variable_scope()
-    with tape_lib.stop_recording():
-      result = f(*args, **kwargs)
-
-    def grad_wrapper(*wrapper_args, **grad_kwargs):
-      """Wrapper function to accommodate lack of kwargs in graph mode decorator."""
-
-      @tf.custom_gradient
-      def inner_recompute_grad(*dresult):
-        """Nested custom gradient function for computing grads in reverse and forward mode autodiff."""
-        # Gradient calculation for reverse mode autodiff.
-        variables = grad_kwargs.get('variables')
-        with tf.GradientTape() as t:
-          id_args = tf.nest.map_structure(tf.identity, args)
-          t.watch(id_args)
-          if variables is not None:
-            t.watch(variables)
-          with tf.control_dependencies(dresult):
-            with tf.variable_scope(current_var_scope):
-              result = f(*id_args, **kwargs)
-        kw_vars = []
-        if variables is not None:
-          kw_vars = list(variables)
-        grads = t.gradient(
-            result,
-            list(id_args) + kw_vars,
-            output_gradients=dresult,
-            unconnected_gradients=tf.UnconnectedGradients.ZERO)
-
-        def transpose(*t_args, **t_kwargs):
-          """Gradient function calculation for forward mode autodiff."""
-          # Just raise an error, since gradients / activations are not stored
-          # on the tape for recompute.
-          raise NotImplementedError(
-              'recompute_grad tried to transpose grad of {}. '
-              'Consider not using recompute_grad in forward mode '
-              'autodiff'.format(f.__name__))
-        return (grads[:len(id_args)], grads[len(id_args):]), transpose
-
-      return inner_recompute_grad(*wrapper_args)
-
-    return result, grad_wrapper
-
-  return inner
-
-
 def recompute_grad(recompute=False):
   """Decorator to determine whether to use gradient checkpointing."""
 
   def _wrapper(f):
     if recompute:
-      return _recompute_grad(f)
+      return tf.recompute_grad(f)
     return f
 
   return _wrapper
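
Note: after this change, recompute_grad(recompute=True) delegates to TensorFlow's built-in tf.recompute_grad instead of the hand-rolled _recompute_grad. Below is a minimal usage sketch; the module name `utils`, the block body, and the tensor shapes are illustrative assumptions, not part of this change.

import tensorflow.compat.v2 as tf

from utils import recompute_grad  # hypothetical module name for the file above


@recompute_grad(recompute=True)  # wraps `block` with tf.recompute_grad
def block(x):
  # Intermediate activations are discarded after the forward pass and
  # recomputed during the backward pass, trading compute for memory.
  w = tf.ones([64, 64])
  return tf.nn.relu(tf.matmul(x, w))


x = tf.ones([8, 64])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = block(x)
dx = tape.gradient(y, x)  # the backward pass re-runs `block` to rebuild activations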