
Commit 4dbbb17

✍️ update testing functions and scripts
1 parent 4d07e9c commit 4dbbb17

4 files changed: +76 -14 lines changed

examples/conformer/test.py

Lines changed: 18 additions & 9 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 import os
+from tqdm import tqdm
 import argparse
 from tensorflow_asr.utils import env_util, file_util

@@ -58,6 +59,7 @@
 from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
 from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer, CharFeaturizer
 from tensorflow_asr.models.transducer.conformer import Conformer
+from tensorflow_asr.utils import app_util

 config = Config(args.config)
 speech_featurizer = TFSpeechFeaturizer(config.speech_config)
@@ -97,13 +99,20 @@
 batch_size = args.bs or config.learning_config.running_config.batch_size
 test_data_loader = test_dataset.create(batch_size)

-results = conformer.predict(test_data_loader)
-
 with file_util.save_file(file_util.preprocess_paths(args.output)) as filepath:
-    print(f"Saving result to {args.output} ...")
-    with open(filepath, "w") as openfile:
-        openfile.write("PATH\tDURATION\tGROUNDTRUTH\tGREEDY\tBEAMSEARCH\n")
-        for i, entry in test_dataset.entries:
-            groundtruth, greedy, beamsearch = results[i]
-            path, duration, _ = entry
-            openfile.write(f"{path}\t{duration}\t{groundtruth}\t{greedy}\t{beamsearch}\n")
+    overwrite = False
+    if tf.io.gfile.exists(filepath):
+        overwrite = input("Overwrite existing result file? (y/n): ").lower() == "y"
+    if overwrite:
+        results = conformer.predict(test_data_loader, verbose=1)
+        print(f"Saving result to {args.output} ...")
+        with open(filepath, "w") as openfile:
+            openfile.write("PATH\tDURATION\tGROUNDTRUTH\tGREEDY\tBEAMSEARCH\n")
+            progbar = tqdm(total=test_dataset.total_steps, unit="batch")
+            for i, pred in enumerate(results):
+                groundtruth, greedy, beamsearch = [x.decode('utf-8') for x in pred]
+                path, duration, _ = test_dataset.entries[i]
+                openfile.write(f"{path}\t{duration}\t{groundtruth}\t{greedy}\t{beamsearch}\n")
+                progbar.update(1)
+            progbar.close()
+    app_util.evaluate_results(filepath)
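
For reference, the updated script writes one tab-separated row per utterance under the PATH/DURATION/GROUNDTRUTH/GREEDY/BEAMSEARCH header and then passes the same file to app_util.evaluate_results. A minimal sketch of reading such a results file back outside the project, assuming a hypothetical local output file named test.tsv:

import csv

# Hypothetical results file path; the script writes it to the location given by args.output.
with open("test.tsv", "r", encoding="utf-8", newline="") as f:
    reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        # Columns follow the header the script writes:
        # PATH, DURATION, GROUNDTRUTH, GREEDY, BEAMSEARCH
        print(row["PATH"], row["DURATION"])
        print("  groundtruth:", row["GROUNDTRUTH"])
        print("  greedy     :", row["GREEDY"])
        print("  beamsearch :", row["BEAMSEARCH"])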

tensorflow_asr/metrics/error_rates.py

Lines changed: 1 addition & 1 deletion
@@ -30,4 +30,4 @@ def update_state(self, decode: tf.Tensor, target: tf.Tensor):
         self.denominator.assign_add(d)

     def result(self):
-        return tf.math.divide_no_nan(self.numerator, self.denominator) * 100
+        return tf.math.divide_no_nan(self.numerator, self.denominator)
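
With the * 100 factor removed, ErrorRate.result() now returns the raw error ratio instead of a percentage, so callers scale it themselves when they want percent. A minimal sketch of that usage, assuming the same ErrorRate/wer combination imported in app_util.py below (the example strings are made up):

import tensorflow as tf

from tensorflow_asr.metrics.error_rates import ErrorRate
from tensorflow_asr.utils.metric_util import wer

# Made-up decoded/reference pair, purely for illustration.
decode = tf.convert_to_tensor(["hello word"], dtype=tf.string)
target = tf.convert_to_tensor(["hello world"], dtype=tf.string)

metric = ErrorRate(wer, name="greedy_wer", dtype=tf.float32)
metric.update_state(decode=decode, target=target)

# result() is now a plain ratio; multiply by 100 to report a percentage.
ratio = metric.result().numpy()
print(f"greedy_wer: {ratio:.4f} ({ratio * 100:.2f}%)")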

tensorflow_asr/models/base_model.py

Lines changed: 12 additions & 4 deletions
@@ -19,6 +19,10 @@


 class BaseModel(tf.keras.Model):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._metrics = {}
+
     def save(self,
              filepath,
              overwrite=True,
@@ -66,7 +70,10 @@ def load_weights(self,

     @property
     def metrics(self):
-        return [self.loss_metric]
+        return self._metrics.values()
+
+    def add_metric(self, metric: tf.keras.metrics.Metric):
+        self._metrics[metric.name] = metric

     def _build(self, *args, **kwargs):
         raise NotImplementedError()
@@ -76,7 +83,8 @@ def compile(self, loss, optimizer, run_eagerly=None, **kwargs):
         if not env_util.has_tpu():
             optimizer = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), "dynamic")
             self.use_loss_scale = True
-        self.loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
+        loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
+        self._metrics = {loss_metric.name: loss_metric}
         super().compile(optimizer=optimizer, loss=loss, run_eagerly=run_eagerly, **kwargs)

     # -------------------------------- STEP FUNCTIONS -------------------------------------
@@ -92,14 +100,14 @@ def train_step(self, batch):
         if self.use_loss_scale:
             gradients = self.optimizer.get_unscaled_gradients(gradients)
         self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
-        self.loss_metric.update_state(loss)
+        self._metrics["loss"].update_state(loss)
         return {m.name: m.result() for m in self.metrics}

     def test_step(self, batch):
         inputs, y_true = batch
         y_pred = self(inputs, training=False)
         loss = self.loss(y_true, y_pred)
-        self.loss_metric.update_state(loss)
+        self._metrics["loss"].update_state(loss)
         return {m.name: m.result() for m in self.metrics}

     def predict_step(self, batch):
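
The BaseModel changes replace the single loss_metric attribute with a name-keyed _metrics dict: compile seeds it with the loss mean, the metrics property exposes its values, and add_metric registers extra metrics under their names. A stripped-down sketch of the same dict-backed pattern on a plain tf.keras.Model, only to illustrate how the registry feeds the {m.name: m.result() ...} dicts returned by the step functions; DemoModel, register_metric, and the toy data are hypothetical, not part of the repository:

import tensorflow as tf


class DemoModel(tf.keras.Model):
    # Hypothetical model using the same dict-backed metrics pattern as BaseModel.

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(1)
        self._metrics = {}

    def register_metric(self, metric: tf.keras.metrics.Metric):
        # Mirrors BaseModel.add_metric: keep metrics addressable by name.
        self._metrics[metric.name] = metric

    @property
    def metrics(self):
        # Keras resets and reports every metric returned here.
        return list(self._metrics.values())

    def compile(self, **kwargs):
        loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
        self._metrics = {loss_metric.name: loss_metric}
        super().compile(**kwargs)

    def call(self, inputs, training=False):
        return self.dense(inputs)

    def train_step(self, batch):
        x, y = batch
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = tf.reduce_mean(tf.keras.losses.mean_squared_error(y, y_pred))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self._metrics["loss"].update_state(loss)
        # Every registered metric shows up in the returned logs / progress bar.
        return {m.name: m.result() for m in self.metrics}


model = DemoModel()
model.compile(optimizer="adam")
model.fit(tf.random.normal([32, 4]), tf.random.normal([32, 1]), epochs=1, verbose=0)
print(model._metrics["loss"].result().numpy())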

tensorflow_asr/utils/app_util.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tqdm import tqdm
+import tensorflow as tf
+
+from .metric_util import wer, cer
+from ..metrics.error_rates import ErrorRate
+from .file_util import read_file
+
+
+def evaluate_results(filepath: str):
+    print(f"Evaluating result from {filepath} ...")
+    metrics = {
+        "greedy_wer": ErrorRate(wer, name="greedy_wer", dtype=tf.float32),
+        "greedy_cer": ErrorRate(cer, name="greedy_cer", dtype=tf.float32),
+        "beamsearch_wer": ErrorRate(wer, name="beamsearch_wer", dtype=tf.float32),
+        "beamsearch_cer": ErrorRate(cer, name="beamsearch_cer", dtype=tf.float32)
+    }
+    with read_file(filepath) as path:
+        with open(path, "r", encoding="utf-8") as openfile:
+            lines = openfile.read().splitlines()
+            lines = lines[1:]  # skip header
+            for eachline in tqdm(lines):
+                _, _, groundtruth, greedy, beamsearch = eachline.split("\t")
+                groundtruth = tf.convert_to_tensor([groundtruth], dtype=tf.string)
+                greedy = tf.convert_to_tensor([greedy], dtype=tf.string)
+                beamsearch = tf.convert_to_tensor([beamsearch], dtype=tf.string)
+                metrics["greedy_wer"].update_state(decode=greedy, target=groundtruth)
+                metrics["greedy_cer"].update_state(decode=greedy, target=groundtruth)
+                metrics["beamsearch_wer"].update_state(decode=beamsearch, target=groundtruth)
+                metrics["beamsearch_cer"].update_state(decode=beamsearch, target=groundtruth)
+    for key, value in metrics.items():
+        print(f"{key}: {value.result().numpy()}")
