
Commit 9144c7d

✍️ update metrics for training using builtin keras
1 parent: b2bfd92

5 files changed: +3878 −33 lines

tensorflow_asr/losses/keras/ctc_losses.py

Lines changed: 1 addition & 3 deletions

@@ -13,13 +13,11 @@
 # limitations under the License.
 
 import tensorflow as tf
-from tensorflow.python.keras.utils import losses_utils
-
 from .. import ctc_loss
 
 
 class CtcLoss(tf.keras.losses.Loss):
-    def __init__(self, blank=0, global_batch_size=None, reduction=losses_utils.ReductionV2.NONE, name=None):
+    def __init__(self, blank=0, global_batch_size=None, reduction=tf.keras.losses.Reduction.NONE, name=None):
         super(CtcLoss, self).__init__(reduction=reduction, name=name)
         self.blank = blank
         self.global_batch_size = global_batch_size
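
The only functional change in this file is swapping the private losses_utils.ReductionV2.NONE import for the public tf.keras.losses.Reduction.NONE alias; both name the same "no reduction" setting, so the loss keeps returning per-example values. For orientation, here is a rough self-contained sketch of a CTC loss wrapper built the same way; the inner tf.nn.ctc_loss call, the y_true/y_pred dict keys, and the compute_average_loss scaling are illustrative assumptions, not the exact code behind this repository's ctc_loss.

import tensorflow as tf

class ExampleCtcLoss(tf.keras.losses.Loss):
    """Sketch only: keeps per-example losses (Reduction.NONE) and averages
    them over the *global* batch size, which is what distributed training
    needs when each replica sees only part of the batch."""

    def __init__(self, blank=0, global_batch_size=None,
                 reduction=tf.keras.losses.Reduction.NONE, name=None):
        super().__init__(reduction=reduction, name=name)
        self.blank = blank
        self.global_batch_size = global_batch_size

    def call(self, y_true, y_pred):
        # Per-example CTC loss, shape [batch]; no reduction applied here.
        per_example = tf.nn.ctc_loss(
            labels=tf.cast(y_true["labels"], tf.int32),
            logits=y_pred["logit"],
            label_length=y_true["labels_length"],
            logit_length=y_pred["logit_length"],
            logits_time_major=False,
            blank_index=self.blank,
        )
        # Divide by the global batch size so gradients aggregate correctly
        # across replicas.
        return tf.nn.compute_average_loss(
            per_example, global_batch_size=self.global_batch_size)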

tensorflow_asr/losses/keras/rnnt_losses.py

Lines changed: 1 addition & 3 deletions

@@ -13,13 +13,11 @@
 # limitations under the License.
 
 import tensorflow as tf
-from tensorflow.python.keras.utils import losses_utils
-
 from .. import rnnt_loss
 
 
 class RnntLoss(tf.keras.losses.Loss):
-    def __init__(self, blank=0, global_batch_size=None, reduction=losses_utils.ReductionV2.NONE, name=None):
+    def __init__(self, blank=0, global_batch_size=None, reduction=tf.keras.losses.Reduction.NONE, name=None):
         super(RnntLoss, self).__init__(reduction=reduction, name=name)
         self.blank = blank
         self.global_batch_size = global_batch_size
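
The RNN-T loss wrapper gets the same one-line treatment as the CTC one. A tiny hypothetical check (batch size made up; import paths follow the package layout shown in this commit) that the two wrappers share the same constructor contract:

import tensorflow as tf
from tensorflow_asr.losses.keras.ctc_losses import CtcLoss
from tensorflow_asr.losses.keras.rnnt_losses import RnntLoss

ctc = CtcLoss(blank=0, global_batch_size=32)
rnnt = RnntLoss(blank=0, global_batch_size=32)
# Neither wrapper reduces the loss itself; what to do with the per-example
# values is left to the wrapped ctc_loss / rnnt_loss functions.
assert ctc.reduction == rnnt.reduction == tf.keras.losses.Reduction.NONE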

tensorflow_asr/models/keras/ctc.py

Lines changed: 16 additions & 16 deletions

@@ -22,27 +22,25 @@
 
 class CtcModel(BaseCtcModel):
     """ Keras CTC Model Warper """
+    @property
+    def metrics(self):
+        return [self.loss_metric]
 
-    def compile(self, optimizer, global_batch_size, blank=0, use_loss_scale=False,
-                loss_weights=None, weighted_metrics=None, run_eagerly=None, **kwargs):
+    def compile(self, optimizer, global_batch_size, blank=0, use_loss_scale=False, run_eagerly=None, **kwargs):
         loss = CtcLoss(blank=blank, global_batch_size=global_batch_size)
         self.use_loss_scale = use_loss_scale
         if self.use_loss_scale:
-            optimizer = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), 'dynamic')
-        super(CtcModel, self).compile(
-            optimizer=optimizer, loss=loss,
-            loss_weights=loss_weights, weighted_metrics=weighted_metrics,
-            run_eagerly=run_eagerly,
-            **kwargs
-        )
+            optimizer = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), "dynamic")
+        self.loss_metric = tf.keras.metrics.Mean(name="ctc_loss", dtype=tf.float32)
+        super(CtcModel, self).compile(optimizer=optimizer, loss=loss, run_eagerly=run_eagerly, **kwargs)
 
     def train_step(self, batch):
         x, y_true = batch
         with tf.GradientTape() as tape:
-            logit = self(x['input'], training=True)
+            logit = self(x["input"], training=True)
             y_pred = {
-                'logit': logit,
-                'logit_length': get_reduced_length(x['input_length'], self.time_reduction_factor)
+                "logit": logit,
+                "logit_length": get_reduced_length(x["input_length"], self.time_reduction_factor)
             }
             loss = self.loss(y_true, y_pred)
             if self.use_loss_scale:
@@ -53,14 +51,16 @@ def train_step(self, batch):
         else:
             gradients = tape.gradient(loss, self.trainable_weights)
         self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
-        return {"ctc_loss": loss}
+        self.loss_metric.update_state(loss)
+        return {m.name: m.result() for m in self.metrics}
 
     def test_step(self, batch):
         x, y_true = batch
         logit = self(x, training=False)
         y_pred = {
-            'logit': logit,
-            'logit_length': get_reduced_length(x['input_length'], self.time_reduction_factor)
+            "logit": logit,
+            "logit_length": get_reduced_length(x["input_length"], self.time_reduction_factor)
         }
         loss = self.loss(y_true, y_pred)
-        return {"ctc_loss": loss}
+        self.loss_metric.update_state(loss)
+        return {m.name: m.result() for m in self.metrics}
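
Both train_step and test_step now push the batch loss into a stateful tf.keras.metrics.Mean and report {m.name: m.result()}, so what fit()/evaluate() log is the running average for the epoch, and Keras takes care of resetting it. Below is a minimal, self-contained sketch of that pattern on a toy model (the Dense layer, the squared-error loss, and the random data are placeholders, not this repository's model):

import numpy as np
import tensorflow as tf

class TrackedLossModel(tf.keras.Model):
    """Sketch of the metric-tracking pattern adopted by this commit."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(1)
        self.loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)

    @property
    def metrics(self):
        # Everything listed here is reset by Keras at the start of each epoch
        # and before each evaluation pass.
        return [self.loss_metric]

    def call(self, inputs, training=False):
        return self.dense(inputs)

    def train_step(self, batch):
        x, y = batch
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = tf.reduce_mean(tf.square(y - y_pred))
        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))
        self.loss_metric.update_state(loss)
        # The progress bar shows the running mean, not this batch's raw loss.
        return {m.name: m.result() for m in self.metrics}

model = TrackedLossModel()
model.compile(optimizer="sgd")
model.fit(np.random.rand(64, 4), np.random.rand(64, 1), batch_size=8, epochs=2)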

tensorflow_asr/models/keras/transducer.py

Lines changed: 11 additions & 11 deletions

@@ -23,6 +23,9 @@
 
 class Transducer(BaseTransducer):
     """ Keras Transducer Model Warper """
+    @property
+    def metrics(self):
+        return [self.loss_metric]
 
     def _build(self, input_shape, prediction_shape=[None], batch_size=None):
         inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32)
@@ -48,18 +51,13 @@ def call(self, inputs, training=False, **kwargs):
             "logit_length": get_reduced_length(inputs["input_length"], self.time_reduction_factor)
         }
 
-    def compile(self, optimizer, global_batch_size, blank=0, use_loss_scale=False,
-                loss_weights=None, weighted_metrics=None, run_eagerly=None, **kwargs):
+    def compile(self, optimizer, global_batch_size, blank=0, use_loss_scale=False, run_eagerly=None, **kwargs):
         loss = RnntLoss(blank=blank, global_batch_size=global_batch_size)
         self.use_loss_scale = use_loss_scale
         if self.use_loss_scale:
-            optimizer = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), 'dynamic')
-        super(Transducer, self).compile(
-            optimizer=optimizer, loss=loss,
-            loss_weights=loss_weights, weighted_metrics=weighted_metrics,
-            run_eagerly=run_eagerly,
-            **kwargs
-        )
+            optimizer = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), "dynamic")
+        self.loss_metric = tf.keras.metrics.Mean(name="rnnt_loss", dtype=tf.float32)
+        super(Transducer, self).compile(optimizer=optimizer, loss=loss, run_eagerly=run_eagerly, **kwargs)
 
     def train_step(self, batch):
         x, y_true = batch
@@ -79,7 +77,8 @@ def train_step(self, batch):
         else:
             gradients = tape.gradient(loss, self.trainable_weights)
         self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
-        return {"rnnt_loss": loss}
+        self.loss_metric.update_state(loss)
+        return {m.name: m.result() for m in self.metrics}
 
     def test_step(self, batch):
         x, y_true = batch
@@ -90,4 +89,5 @@ def test_step(self, batch):
             "prediction_length": x["prediction_length"],
         }, training=False)
         loss = self.loss(y_true, y_pred)
-        return {"rnnt_loss": loss}
+        self.loss_metric.update_state(loss)
+        return {m.name: m.result() for m in self.metrics}
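
A user-visible effect shared by both model wrappers: the value logged under "ctc_loss" / "rnnt_loss" is now the epoch's running average kept by tf.keras.metrics.Mean rather than the raw loss of the most recent batch, and it is reset between epochs because the metric is exposed through the metrics property. A tiny standalone illustration with made-up batch losses:

import tensorflow as tf

loss_metric = tf.keras.metrics.Mean(name="rnnt_loss", dtype=tf.float32)
for batch_loss in [4.0, 2.0, 0.0]:        # made-up per-batch loss values
    loss_metric.update_state(batch_loss)
    print(float(loss_metric.result()))    # 4.0, 3.0, 2.0 -> running mean
loss_metric.reset_states()                # what Keras does between epochs
print(float(loss_metric.result()))        # 0.0 after reset

During fit() with validation data, the same metric name is reported from test_step with the usual "val_" prefix (e.g. val_rnnt_loss).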
