Skip to content

Commit 73a9308

Browse files
committed
🏅 Fix OOM error when evaluating Tacotron2 (#416)
1 parent 93396e7 commit 73a9308

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

examples/tacotron2/train_tacotron2.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,16 @@ def _one_step_evaluate_per_replica(self, batch):
124124

125125
self.update_eval_metrics(dict_metrics_losses)
126126

127+
def _one_step_predict_per_replica(self, batch):
128+
"""One step predict per GPU
129+
130+
Tacotron-2 used teacher-forcing when training and evaluation.
131+
So we need pass `training=True` for inference step.
132+
133+
"""
134+
outputs = self._model(**batch, training=True)
135+
return outputs
136+
127137
def compute_per_example_losses(self, batch, outputs):
128138
"""Compute per example losses and return dict_metrics_losses
129139
Note that all element of the loss MUST has a shape [batch_size] and

0 commit comments

Comments
 (0)