Commit 263855e

⚡ Supported Gradients Accumulation
1 parent 26193fb commit 263855e

10 files changed: 62 additions & 50 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ TensorFlowASR implements some automatic speech recognition architectures such as

 ## What's New?

+- (11/14/2020) Supported Gradient Accumulation for Training in Larger Batch Size
 - (11/3/2020) Reduce differences between `librosa.stft` and `tf.signal.stft`
 - (10/31/2020) Update DeepSpeech2 and Supported Jasper [https://arxiv.org/abs/1904.03288](https://arxiv.org/abs/1904.03288)
 - (10/18/2020) Supported Streaming Transducer [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621)
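For readers new to the feature: gradient accumulation simulates a large effective batch by summing gradients over several micro-batches and applying a single optimizer update at the end. A minimal standalone sketch of the idea in TensorFlow (illustrative only, not this repository's API):

import tensorflow as tf

# Toy model and optimizer for illustration.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build(input_shape=(None, 4))
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()

ACCUM_STEPS = 4  # one weight update per 4 micro-batches

# One zero buffer per trainable variable, excluded from training itself.
buffers = [tf.Variable(tf.zeros_like(v), trainable=False)
           for v in model.trainable_variables]

def micro_step(x, y):
    """Compute gradients for one micro-batch and add them to the buffers."""
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    for buf, g in zip(buffers, grads):
        buf.assign_add(g / ACCUM_STEPS)  # average over the accumulation window

def apply_and_reset():
    """Apply the accumulated gradients once, then clear the buffers."""
    optimizer.apply_gradients(zip(buffers, model.trainable_variables))
    for buf in buffers:
        buf.assign(tf.zeros_like(buf))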

examples/conformer/train_ga_conformer.py

Lines changed: 5 additions & 1 deletion
@@ -41,6 +41,9 @@
 parser.add_argument("--ebs", type=int, default=None,
                     help="Evaluation batch size per replica")

+parser.add_argument("--acs", type=int, default=None,
+                    help="Train accumulation steps")
+
 parser.add_argument("--devices", type=int, nargs="*", default=[0],
                     help="Devices' ids to apply distributed training")

@@ -125,4 +128,5 @@
 conformer_trainer.compile(model=conformer, optimizer=optimizer,
                           max_to_keep=args.max_ckpts)

-conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
+conformer_trainer.fit(train_dataset, eval_dataset,
+                      train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs)
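With the new flag, a training run might be launched as below; only the flags visible in this diff are shown, and any other required arguments of the script (e.g. a config path) are omitted, with values purely illustrative:

python examples/conformer/train_ga_conformer.py \
    --tbs 2 --ebs 2 --acs 4 --devices 0 1

Under the new loader logic, `--tbs` is the per-replica micro-batch size and the optimizer steps once every `--acs` batches, so the effective batch per update is roughly tbs × acs × number of devices.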

examples/conformer/train_ga_subword_conformer.py

Lines changed: 5 additions & 1 deletion
@@ -41,6 +41,9 @@
 parser.add_argument("--ebs", type=int, default=None,
                     help="Evaluation batch size per replica")

+parser.add_argument("--acs", type=int, default=None,
+                    help="Train accumulation steps")
+
 parser.add_argument("--devices", type=int, nargs="*", default=[0],
                     help="Devices' ids to apply distributed training")

@@ -141,4 +144,5 @@
 conformer_trainer.compile(model=conformer, optimizer=optimizer,
                           max_to_keep=args.max_ckpts)

-conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
+conformer_trainer.fit(train_dataset, eval_dataset,
+                      train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs)

examples/streaming_transducer/train_ga_streaming_transducer.py

Lines changed: 5 additions & 1 deletion
@@ -40,6 +40,9 @@
 parser.add_argument("--ebs", type=int, default=None,
                     help="Evaluation batch size per replica")

+parser.add_argument("--acs", type=int, default=None,
+                    help="Train accumulation steps")
+
 parser.add_argument("--devices", type=int, nargs="*", default=[0],
                     help="Devices' ids to apply distributed training")

@@ -116,4 +119,5 @@
 streaming_transducer_trainer.compile(model=streaming_transducer, optimizer=optimizer,
                                      max_to_keep=args.max_ckpts)

-streaming_transducer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
+streaming_transducer_trainer.fit(train_dataset, eval_dataset,
+                                 train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs)

examples/streaming_transducer/train_ga_subword_streaming_transducer.py

Lines changed: 5 additions & 1 deletion
@@ -40,6 +40,9 @@
 parser.add_argument("--ebs", type=int, default=None,
                     help="Evaluation batch size per replica")

+parser.add_argument("--acs", type=int, default=None,
+                    help="Train accumulation steps")
+
 parser.add_argument("--devices", type=int, nargs="*", default=[0],
                     help="Devices' ids to apply distributed training")

@@ -132,4 +135,5 @@
 streaming_transducer_trainer.compile(model=streaming_transducer, optimizer=optimizer,
                                      max_to_keep=args.max_ckpts)

-streaming_transducer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
+streaming_transducer_trainer.fit(train_dataset, eval_dataset,
+                                 train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs)

setup.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
     long_description = fh.read()

 requirements = [
-    # "tensorflow>=2.3.0",
+    "tensorflow>=2.3.0",
     "tensorflow-datasets>=3.2.1,<4.0.0",
     "tensorflow-addons>=0.10.0",
     "setuptools>=47.1.1",
@@ -38,7 +38,7 @@

 setuptools.setup(
     name="TensorFlowASR",
-    version="0.2.10",
+    version="0.3.0",
     author="Huy Le Nguyen",
     author_email="[email protected]",
     description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/optimizers/accumulation.py

Lines changed: 3 additions & 1 deletion
@@ -20,12 +20,14 @@ def __init__(self, trainable_variables):
         self.gradients = [
             tf.Variable(
                 tf.zeros_like(g),
+                trainable=False,
                 synchronization=tf.VariableSynchronization.ON_READ
             ) for g in trainable_variables
         ]

     def reset(self):
-        for g in self.gradients: g.assign(tf.zeros_like(g))
+        for i, g in enumerate(self.gradients):
+            self.gradients[i].assign(tf.zeros_like(g))

     def accumulate(self, step_gradients):
         for i, g in enumerate(step_gradients):
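Pieced together from the hunk above, the accumulator after this commit looks roughly like the sketch below; the class name and the body of `accumulate` fall outside the diff context, so those parts are assumptions:

import tensorflow as tf

class GradientAccumulation:  # class name assumed; only its methods appear in the diff
    def __init__(self, trainable_variables):
        # One non-trainable buffer per model variable. ON_READ synchronization
        # lets each replica update its local copy and aggregate on read.
        self.gradients = [
            tf.Variable(
                tf.zeros_like(g),
                trainable=False,
                synchronization=tf.VariableSynchronization.ON_READ
            ) for g in trainable_variables
        ]

    def reset(self):
        for i, g in enumerate(self.gradients):
            self.gradients[i].assign(tf.zeros_like(g))

    def accumulate(self, step_gradients):
        # Body assumed: add each micro-batch gradient into its buffer.
        for i, g in enumerate(step_gradients):
            if g is not None:
                self.gradients[i].assign_add(g)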

tensorflow_asr/runners/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -28,8 +28,8 @@ def save_from_checkpoint(func,
         max_to_keep: number of checkpoints to keep
         **kwargs: contains built models, optimizers
     """
-    steps = tf.Variable(0, dtype=tf.int64)  # Step must be int64
-    epochs = tf.Variable(1)
+    steps = tf.Variable(0, trainable=False, dtype=tf.int64)  # Step must be int64
+    epochs = tf.Variable(1, trainable=False)
     checkpoint_dir = os.path.join(outdir, "checkpoints")
     if not os.path.exists(checkpoint_dir):
         raise ValueError(f"checkpoint directory not found: {checkpoint_dir}")

tensorflow_asr/runners/base_runners.py

Lines changed: 6 additions & 4 deletions
@@ -72,7 +72,8 @@ def __init__(self,
         super(BaseTrainer, self).__init__(config)
         self.set_strategy(strategy)
         # Steps and Epochs start from 0
-        self.steps = tf.Variable(0, dtype=tf.int64)  # Step must be int64 to use tf.summary
+        # Step must be int64 to use tf.summary
+        self.steps = tf.Variable(0, trainable=False, dtype=tf.int64)
         self.train_steps_per_epoch = None
         self.eval_steps_per_epoch = None
         # Dataset
@@ -120,13 +121,14 @@ def set_train_data_loader(self, train_dataset, train_bs=None, train_acs=None):
             self.config.batch_size = train_bs  # Update batch size fed from arguments

         if not train_acs: train_acs = self.config.accumulation_steps
-        assert train_bs % train_acs == 0, "Batch size must be a multiple of Accumulation Steps"
-        self.accumulation_bs = train_bs // train_acs
         self.config.accumulation_steps = train_acs  # update accum steps fed from arguments

         self.train_data = train_dataset.create(self.global_batch_size)
         self.train_data_loader = self.strategy.experimental_distribute_dataset(self.train_data)
-        self.train_steps_per_epoch = train_dataset.total_steps
+        if hasattr(self, "accumulation"):
+            self.train_steps_per_epoch = train_dataset.total_steps // self.config.accumulation_steps
+        else:
+            self.train_steps_per_epoch = train_dataset.total_steps

     def set_eval_data_loader(self, eval_dataset, eval_bs=None):
         """ Set eval data loader (MUST).

tensorflow_asr/runners/transducer_runners.py

Lines changed: 28 additions & 37 deletions
@@ -90,48 +90,39 @@ def compile(self,
 class TransducerTrainerGA(TransducerTrainer):
     """ Transducer Trainer that uses Gradients Accumulation """

-    @tf.function(experimental_relax_shapes=True)
-    def _train_step(self, batch):
-        _, bfeatures, binput_length, blabels, blabel_length, bpred_inp = batch
-
+    @tf.function
+    def _train_function(self, iterator):
+        for _ in range(self.config.accumulation_steps):
+            batch = next(iterator)
+            self.strategy.run(self._train_step, args=(batch,))
+        self.strategy.run(self._apply_gradients, args=())
+
+    @tf.function
+    def _apply_gradients(self):
+        self.optimizer.apply_gradients(
+            zip(self.accumulation.gradients, self.model.trainable_variables))
         self.accumulation.reset()

-        for accum_step in range(self.config.accumulation_steps):
+    @tf.function(experimental_relax_shapes=True)
+    def _train_step(self, batch):
+        _, features, input_length, labels, label_length, pred_inp = batch

-            indices = tf.expand_dims(
-                tf.range(
-                    accum_step * self.accumulation_bs,
-                    (accum_step + 1) * self.accumulation_bs,
-                    dtype=tf.int32
-                ),
-                axis=-1
+        with tf.GradientTape() as tape:
+            logits = self.model([features, pred_inp], training=True)
+            tape.watch(logits)
+            per_train_loss = rnnt_loss(
+                logits=logits, labels=labels, label_length=label_length,
+                logit_length=(input_length // self.model.time_reduction_factor),
+                blank=self.text_featurizer.blank
+            )
+            train_loss = tf.nn.compute_average_loss(
+                per_train_loss,
+                global_batch_size=self.global_batch_size
            )

-            features = tf.gather_nd(bfeatures, indices)
-            input_length = tf.gather_nd(binput_length, indices)
-            labels = tf.gather_nd(blabels, indices)
-            label_length = tf.gather_nd(blabel_length, indices)
-            pred_inp = tf.gather_nd(bpred_inp, indices)
-
-            with tf.GradientTape() as tape:
-                logits = self.model([features, pred_inp], training=True)
-                tape.watch(logits)
-                per_train_loss = rnnt_loss(
-                    logits=logits, labels=labels, label_length=label_length,
-                    logit_length=(input_length // self.model.time_reduction_factor),
-                    blank=self.text_featurizer.blank
-                )
-                train_loss = tf.nn.compute_average_loss(
-                    per_train_loss,
-                    global_batch_size=self.global_batch_size
-                )
-
-            step_gradients = tape.gradient(train_loss, self.model.trainable_variables)
-            self.accumulation.accumulate(step_gradients)
-            self.train_metrics["transducer_loss"].update_state(per_train_loss)
-
-        self.optimizer.apply_gradients(
-            zip(self.accumulation.gradients, self.model.trainable_variables))
+        gradients = tape.gradient(train_loss, self.model.trainable_variables)
+        self.accumulation.accumulate(gradients)
+        self.train_metrics["transducer_loss"].update_state(per_train_loss)

     def compile(self,
                 model: Transducer,
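After this restructuring, each `_train_function` call drains `accumulation_steps` micro-batches from the distributed iterator and then applies one update, replacing the old approach of slicing one oversized batch with `tf.gather_nd`. The base class's outer loop is not part of this diff, so the driver below is a hypothetical sketch of how the pieces fit together:

# Hypothetical outer loop; names mirror the trainer attributes seen above.
iterator = iter(trainer.train_data_loader)
for _ in range(trainer.train_steps_per_epoch):
    trainer._train_function(iterator)  # accumulation_steps micro-batches + 1 update
    trainer.steps.assign_add(1)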
