
Commit 3e1788b

🤘 Support multi-GPU gradient accumulation for the trainer.
1 parent 518ee49 commit 3e1788b

File tree

3 files changed: +294 -72 lines changed
Lines changed: 1 addition & 0 deletions

@@ -1 +1,2 @@
 from tensorflow_tts.optimizers.adamweightdecay import AdamWeightDecay, WarmUp
+from tensorflow_tts.optimizers.gradient_accumulate import GradientAccumulator
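
The added line re-exports the new GradientAccumulator next to the existing AdamWeightDecay and WarmUp exports, so downstream code can presumably import it straight from tensorflow_tts.optimizers. A minimal sketch, assuming the package from this commit is installed:

# Hypothetical usage of the re-export added above.
from tensorflow_tts.optimizers import AdamWeightDecay, GradientAccumulator, WarmUp

accumulator = GradientAccumulator()
print(int(accumulator.step))  # 0 -- no gradients accumulated yet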
Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@

"""Gradient accumulation for TF2 custom training loops.

Copied from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py.
"""

import re

import tensorflow as tf


class GradientAccumulator(object):
    """Gradient accumulation utility.

    When used with a distribution strategy, the accumulator should be called in a
    replica context. Gradients will be accumulated locally on each replica,
    without synchronization. Users should then call ``.gradients``, scale the
    gradients if required, and pass the result to ``apply_gradients``.
    """

    # We use the ON_READ synchronization policy so that no synchronization is
    # performed on assignment. To get the value, we call .value() which returns the
    # value on the current replica without synchronization.

    def __init__(self):
        """Initializes the accumulator."""
        self._gradients = []
        self._accum_steps = None

    @property
    def step(self):
        """Number of accumulated steps."""
        if self._accum_steps is None:
            self._accum_steps = tf.Variable(
                tf.constant(0, dtype=tf.int64),
                trainable=False,
                synchronization=tf.VariableSynchronization.ON_READ,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            )
        return self._accum_steps.value()

    @property
    def gradients(self):
        """The accumulated gradients on the current replica."""
        if not self._gradients:
            raise ValueError(
                "The accumulator should be called first to initialize the gradients"
            )
        return list(
            gradient.value() if gradient is not None else gradient
            for gradient in self._gradients
        )

    def __call__(self, gradients):
        """Accumulates :obj:`gradients` on the current replica."""
        if not self._gradients:
            _ = self.step  # Create the step variable.
            self._gradients.extend(
                [
                    tf.Variable(
                        tf.zeros_like(gradient),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ,
                        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                    )
                    if gradient is not None
                    else gradient
                    for gradient in gradients
                ]
            )
        if len(gradients) != len(self._gradients):
            raise ValueError(
                "Expected %d gradients, but got %d"
                % (len(self._gradients), len(gradients))
            )

        for accum_gradient, gradient in zip(self._gradients, gradients):
            if accum_gradient is not None and gradient is not None:
                accum_gradient.assign_add(gradient, read_value=False)

        self._accum_steps.assign_add(1)

    def reset(self):
        """Resets the accumulated gradients on the current replica."""
        if not self._gradients:
            return
        self._accum_steps.assign(0)
        for gradient in self._gradients:
            if gradient is not None:
                gradient.assign(tf.zeros_like(gradient), read_value=False)
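
The commit title mentions multi-GPU support in the trainer, but the trainer file itself is not among the hunks shown here. Below is a minimal, hedged sketch of how this accumulator could be driven from a distributed custom training loop; the model, dataset, optimizer, and ACCUM_STEPS constant are illustrative placeholders, not code from this commit.

# A hedged sketch, not the trainer code from this commit. It only illustrates
# the calling pattern described in the class docstring: accumulate per replica,
# scale, apply, then reset.
import tensorflow as tf

from tensorflow_tts.optimizers import GradientAccumulator

ACCUM_STEPS = 4  # micro-batches to accumulate before one optimizer update

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.build(input_shape=(None, 8))
    optimizer = tf.keras.optimizers.Adam(1e-3)
    accumulator = GradientAccumulator()


def train_step(batch):
    """Per-replica step: compute gradients for one micro-batch and accumulate."""
    x, y = batch
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    gradients = tape.gradient(loss, model.trainable_variables)
    accumulator(gradients)  # accumulated locally; no cross-replica sync here
    return loss


def apply_accumulated_gradients():
    """Scale the accumulated gradients by the step count, apply, then reset."""
    scaled = [
        g / tf.cast(accumulator.step, g.dtype) if g is not None else g
        for g in accumulator.gradients
    ]
    # Real multi-GPU code would also account for the number of replicas and the
    # global batch size when scaling; omitted here for brevity.
    optimizer.apply_gradients(zip(scaled, model.trainable_variables))
    accumulator.reset()


# Toy data purely for illustration; a real trainer would typically wrap the
# step in tf.function for performance.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([64, 8]), tf.random.normal([64, 1]))
).batch(8)

for step, batch in enumerate(strategy.experimental_distribute_dataset(dataset), 1):
    strategy.run(train_step, args=(batch,))
    if step % ACCUM_STEPS == 0:
        strategy.run(apply_accumulated_gradients)

The key points match the class docstring: the accumulator is called in a replica context, ``.gradients`` is read and scaled before ``apply_gradients``, and ``reset()`` clears both the gradients and the step counter between updates.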

0 commit comments