Skip to content

Commit a7f4ec2

Browse files
committed
remove duplicate function
1 parent 35ca5c7 commit a7f4ec2

File tree

1 file changed

+1
-73
lines changed

1 file changed

+1
-73
lines changed

efficientdet/utils.py

Lines changed: 1 addition & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import numpy as np
2121
import tensorflow.compat.v1 as tf
2222
import tensorflow.compat.v2 as tf2
23-
from tensorflow.python.eager import tape as tape_lib # pylint:disable=g-direct-tensorflow-import
2423
from tensorflow.python.tpu import tpu_function # pylint:disable=g-direct-tensorflow-import
2524
# pylint: disable=logging-format-interpolation
2625

@@ -643,83 +642,12 @@ def build_model_with_precision(pp, mm, ii, *args, **kwargs):
643642
return outputs
644643

645644

646-
def _recompute_grad(f):
647-
"""An eager-compatible version of recompute_grad.
648-
649-
For f(*args, **kwargs), this supports gradients with respect to args or
650-
kwargs, but kwargs are currently only supported in eager-mode.
651-
Note that for keras layer and model objects, this is handled automatically.
652-
653-
Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
654-
be able to access the member variables of that object, because `g` returns
655-
through the wrapper function `inner`. When recomputing gradients through
656-
objects that inherit from tf2, we suggest keeping a reference to the
657-
underlying object around for the purpose of accessing these variables.
658-
659-
Args:
660-
f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs.
661-
662-
Returns:
663-
A function `g` that wraps `f`, but which recomputes `f` on the backwards
664-
pass of a gradient call.
665-
"""
666-
667-
@tf.custom_gradient
668-
def inner(*args, **kwargs):
669-
"""Inner function closure for calculating gradients."""
670-
current_var_scope = tf.get_variable_scope()
671-
with tape_lib.stop_recording():
672-
result = f(*args, **kwargs)
673-
674-
def grad_wrapper(*wrapper_args, **grad_kwargs):
675-
"""Wrapper function to accomodate lack of kwargs in graph mode decorator."""
676-
677-
@tf.custom_gradient
678-
def inner_recompute_grad(*dresult):
679-
"""Nested custom gradient function for computing grads in reverse and forward mode autodiff."""
680-
# Gradient calculation for reverse mode autodiff.
681-
variables = grad_kwargs.get('variables')
682-
with tf.GradientTape() as t:
683-
id_args = tf.nest.map_structure(tf.identity, args)
684-
t.watch(id_args)
685-
if variables is not None:
686-
t.watch(variables)
687-
with tf.control_dependencies(dresult):
688-
with tf.variable_scope(current_var_scope):
689-
result = f(*id_args, **kwargs)
690-
kw_vars = []
691-
if variables is not None:
692-
kw_vars = list(variables)
693-
grads = t.gradient(
694-
result,
695-
list(id_args) + kw_vars,
696-
output_gradients=dresult,
697-
unconnected_gradients=tf.UnconnectedGradients.ZERO)
698-
699-
def transpose(*t_args, **t_kwargs):
700-
"""Gradient function calculation for forward mode autodiff."""
701-
# Just throw an error since gradients / activations are not stored on
702-
# tape for recompute.
703-
raise NotImplementedError(
704-
'recompute_grad tried to transpose grad of {}. '
705-
'Consider not using recompute_grad in forward mode'
706-
'autodiff'.format(f.__name__))
707-
708-
return (grads[:len(id_args)], grads[len(id_args):]), transpose
709-
710-
return inner_recompute_grad(*wrapper_args)
711-
712-
return result, grad_wrapper
713-
714-
return inner
715-
716-
717645
def recompute_grad(recompute=False):
718646
"""Decorator determine whether use gradient checkpoint."""
719647

720648
def _wrapper(f):
721649
if recompute:
722-
return _recompute_grad(f)
650+
return tf.recompute_grad(f)
723651
return f
724652

725653
return _wrapper

0 commit comments

Comments
 (0)