Commit 4d07e9c

✍️ update testing script
1 parent d86d621 commit 4d07e9c

Showing 9 changed files with 91 additions and 113 deletions.

examples/conformer/test.py

Lines changed: 37 additions & 25 deletions
@@ -14,9 +14,9 @@

 import os
 import argparse
-from tensorflow_asr.utils import setup_environment, setup_devices
+from tensorflow_asr.utils import env_util, file_util

-setup_environment()
+env_util.setup_environment()
 import tensorflow as tf

 DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

@@ -33,65 +33,77 @@

 parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")

+parser.add_argument("--bs", type=int, default=None, help="Test batch size")
+
 parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")

+parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
+
 parser.add_argument("--device", type=int, default=0, help="Device's id to run test on")

 parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu")

-parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
-
-parser.add_argument("--output_name", type=str, default="test", help="Result filename name prefix")
+parser.add_argument("--output", type=str, default="test.tsv", help="Result filepath")

 args = parser.parse_args()

+assert args.saved
+
 tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})

-setup_devices([args.device], cpu=args.cpu)
+env_util.setup_devices([args.device], cpu=args.cpu)

 from tensorflow_asr.configs.config import Config
 from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
 from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer
-from tensorflow_asr.runners.base_runners import BaseTester
-from tensorflow_asr.models.conformer import Conformer
+from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer, CharFeaturizer
+from tensorflow_asr.models.transducer.conformer import Conformer

 config = Config(args.config)
 speech_featurizer = TFSpeechFeaturizer(config.speech_config)

 if args.sentence_piece:
-    print("Loading SentencePiece model ...")
-    text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords)
-elif args.subwords and os.path.exists(args.subwords):
-    print("Loading subwords ...")
-    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
+    print("Use SentencePiece ...")
+    text_featurizer = SentencePieceFeaturizer(config.decoder_config)
+elif args.subwords:
+    print("Use subwords ...")
+    text_featurizer = SubwordFeaturizer(config.decoder_config)
 else:
-    raise ValueError("subwords must be set")
+    print("Use characters ...")
+    text_featurizer = CharFeaturizer(config.decoder_config)

 tf.random.set_seed(0)
-assert args.saved

 if args.tfrecords:
     test_dataset = ASRTFRecordDataset(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.test_dataset_config)
     )
 else:
     test_dataset = ASRSliceDataset(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.test_dataset_config)
     )

 # build model
 conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
 conformer._build(speech_featurizer.shape)
 conformer.load_weights(args.saved)
-conformer.summary(line_length=120)
+conformer.summary(line_length=100)
 conformer.add_featurizers(speech_featurizer, text_featurizer)

-conformer_tester = BaseTester(
-    config=config.learning_config.running_config,
-    output_name=args.output_name
-)
-conformer_tester.compile(conformer)
-conformer_tester.run(test_dataset)
+batch_size = args.bs or config.learning_config.running_config.batch_size
+test_data_loader = test_dataset.create(batch_size)
+
+results = conformer.predict(test_data_loader)
+
+with file_util.save_file(file_util.preprocess_paths(args.output)) as filepath:
+    print(f"Saving result to {args.output} ...")
+    with open(filepath, "w") as openfile:
+        openfile.write("PATH\tDURATION\tGROUNDTRUTH\tGREEDY\tBEAMSEARCH\n")
+        for i, entry in test_dataset.entries:
+            groundtruth, greedy, beamsearch = results[i]
+            path, duration, _ = entry
+            openfile.write(f"{path}\t{duration}\t{groundtruth}\t{greedy}\t{beamsearch}\n")
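Note on the new output format: the runner classes are gone, and the script now emits a plain TSV that any downstream tool can consume. A minimal sketch of reading it back with only the standard library, assuming nothing beyond the column layout written above (`test.tsv` is the script's default output path):

```python
import csv

# Parse the TSV produced by examples/conformer/test.py.
with open("test.tsv", newline="") as f:
    reader = csv.DictReader(f, delimiter="\t")
    for row in reader:
        # One utterance per row: audio path, duration, and three transcripts.
        exact = row["GREEDY"] == row["GROUNDTRUTH"]
        print(f"{row['PATH']}: greedy {'matches' if exact else 'differs from'} ground truth")
```

A proper evaluation would feed the GREEDY and BEAMSEARCH columns into a WER/CER metric rather than the exact-match check used here.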

tensorflow_asr/losses/ctc_loss.py

Lines changed: 4 additions & 6 deletions
@@ -21,13 +21,11 @@ def __init__(self, blank=0, global_batch_size=None, name=None):
         self.global_batch_size = global_batch_size

     def call(self, y_true, y_pred):
-        logits, logits_length = y_pred.values()
-        labels, labels_length = y_true.values()
         loss = ctc_loss(
-            y_pred=logits,
-            input_length=logits_length,
-            y_true=labels,
-            label_length=labels_length,
+            y_pred=y_pred["logits"],
+            input_length=y_pred["logits_length"],
+            y_true=y_true["labels"],
+            label_length=y_true["labels_length"],
             blank=self.blank,
             name=self.name
         )
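The motivation for the keyed lookups: unpacking `dict.values()` positionally depends on insertion order, so a caller that builds `y_pred` with the keys in a different order would silently swap logits and lengths. A toy illustration of that failure mode (plain dicts, no TensorFlow required):

```python
# Two dicts with the same keys in different insertion orders.
a = {"logits": "L", "logits_length": 10}
b = {"logits_length": 10, "logits": "L"}

# Positional unpacking silently swaps the fields for `b` ...
logits, logits_length = b.values()
assert (logits, logits_length) == (10, "L")  # wrong values, yet no error

# ... while keyed access is order-independent.
assert a["logits"] == b["logits"] == "L"
```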

tensorflow_asr/losses/rnnt_loss.py

Lines changed: 4 additions & 6 deletions
@@ -37,13 +37,11 @@ def __init__(self, blank=0, global_batch_size=None, name=None):
         self.global_batch_size = global_batch_size

     def call(self, y_true, y_pred):
-        logits, logits_length = y_pred.values()
-        labels, labels_length = y_true.values()
         loss = rnnt_loss(
-            logits=logits,
-            logit_length=logits_length,
-            labels=labels,
-            label_length=labels_length,
+            logits=y_pred["logits"],
+            logit_length=y_pred["logits_length"],
+            labels=y_true["labels"],
+            label_length=y_true["labels_length"],
             blank=self.blank,
             name=self.name
         )
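Both loss classes now share the same dict contract for `y_true` and `y_pred`. A sketch of the expected structure with dummy tensors, assuming only the key names indexed in the two diffs above (in the repo these dicts come from the model's `call` and the dataset pipeline):

```python
import tensorflow as tf

batch, max_time, vocab, max_label = 2, 50, 29, 12

# Keys must match what the loss `call` methods index above.
y_pred = {
    "logits": tf.random.normal([batch, max_time, vocab]),
    "logits_length": tf.constant([50, 42], dtype=tf.int32),
}
y_true = {
    "labels": tf.zeros([batch, max_label], dtype=tf.int32),
    "labels_length": tf.constant([12, 9], dtype=tf.int32),
}
```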

tensorflow_asr/models/base_model.py

Lines changed: 5 additions & 2 deletions
@@ -111,9 +111,12 @@ def predict_step(self, batch):
             [tf.Tensor]: stacked tensor of shape [B, 3] with each row is the text [truth, greedy, beam_search]
         """
         inputs, y_true = batch
-        labels = self.text_featurizer.iextract(y_true)
+        labels = self.text_featurizer.iextract(y_true["labels"])
         greedy_decoding = self.recognize(inputs)
-        beam_search_decoding = self.recognize_beam(inputs)
+        if self.text_featurizer.decoder_config.beam_width == 0:
+            beam_search_decoding = tf.map_fn(lambda _: tf.convert_to_tensor("", dtype=tf.string), labels)
+        else:
+            beam_search_decoding = self.recognize_beam(inputs)
         return tf.stack([labels, greedy_decoding, beam_search_decoding], axis=-1)

     def recognize(self, features, input_lengths, **kwargs):
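The new branch keeps `predict_step`'s [B, 3] output shape even when beam search is disabled (`beam_width == 0`): the third column is simply filled with empty strings instead of running the decoder. A standalone sketch of that fallback, assuming a batch of already-decoded label strings:

```python
import tensorflow as tf

labels = tf.constant(["hello world", "good morning"])

# One empty string per batch element, so
# tf.stack([labels, greedy, beam], axis=-1) still yields shape [B, 3].
beam_search_decoding = tf.map_fn(
    lambda _: tf.convert_to_tensor("", dtype=tf.string), labels
)
print(beam_search_decoding.numpy())  # [b'' b'']
```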

tensorflow_asr/models/ctc/ctc.py

Lines changed: 7 additions & 8 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Union
+from typing import Dict, Union
 import numpy as np
 import tensorflow as tf

@@ -69,19 +69,18 @@ def add_featurizers(self,
         self.text_featurizer = text_featurizer

     def call(self, inputs, training=False, **kwargs):
-        inputs, inputs_length, _, _ = inputs.values()
-        logits = self.encoder(inputs, training=training, **kwargs)
+        logits = self.encoder(inputs["inputs"], training=training, **kwargs)
         logits = self.decoder(logits, training=training, **kwargs)
         return data_util.create_logits(
             logits=logits,
-            logits_length=math_util.get_reduced_length(inputs_length, self.time_reduction_factor)
+            logits_length=math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor)
         )

     # -------------------------------- GREEDY -------------------------------------

     @tf.function
-    def recognize(self, features: tf.Tensor, input_length: Optional[tf.Tensor]):
-        logits = self(features, training=False)
+    def recognize(self, inputs: Dict[str, tf.Tensor]):
+        logits = self(inputs["inputs"], training=False)
         probs = tf.nn.softmax(logits)

         def map_fn(prob): return tf.numpy_function(self._perform_greedy, inp=[prob], Tout=tf.string)

@@ -119,8 +118,8 @@ def recognize_tflite(self, signal):
     # -------------------------------- BEAM SEARCH -------------------------------------

     @tf.function
-    def recognize_beam(self, features: tf.Tensor, input_length: Optional[tf.Tensor], lm: bool = False):
-        logits = self(features, training=False)
+    def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False):
+        logits = self(inputs["inputs"], training=False)
         probs = tf.nn.softmax(logits)

         def map_fn(prob): return tf.numpy_function(self._perform_beam_search, inp=[prob, lm], Tout=tf.string)
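After this refactor the greedy and beam interfaces take a single dict instead of separate feature/length tensors. A hypothetical call site, assuming `model` is a built CTC model with loaded weights (the tensors below are placeholders, not real features):

```python
import tensorflow as tf

# Placeholder featurized batch; in the repo this comes from the dataset pipeline.
features = tf.zeros([1, 100, 80, 1], dtype=tf.float32)
features_length = tf.constant([100], dtype=tf.int32)

# `model` is assumed to be a built tensorflow_asr CTC model with loaded weights.
inputs = {"inputs": features, "inputs_length": features_length}
transcripts = model.recognize(inputs)            # greedy decoding
beam_transcripts = model.recognize_beam(inputs)  # beam-search decoding
```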

tensorflow_asr/models/transducer/contextnet.py

Lines changed: 10 additions & 24 deletions
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List
+from typing import Dict, List
 import tensorflow as tf

 from ..encoders.contextnet import ContextNetEncoder, L2
 from .transducer import Transducer
+from ...utils import math_util


 class ContextNet(Transducer):

@@ -95,11 +96,7 @@ def encoder_inference(self, features: tf.Tensor, input_length: tf.Tensor):
     # -------------------------------- GREEDY -------------------------------------

     @tf.function
-    def recognize(self,
-                  features: tf.Tensor,
-                  input_length: tf.Tensor,
-                  parallel_iterations: int = 10,
-                  swap_memory: bool = True):
+    def recognize(self, inputs: Dict[str, tf.Tensor]):
         """
         RNN Transducer Greedy decoding
         Args:

@@ -108,12 +105,9 @@ def recognize(self,
         Returns:
             tf.Tensor: a batch of decoded transcripts
         """
-        encoded = self.encoder([features, input_length], training=False)
-        return self._perform_greedy_batch(
-            encoded, input_length,
-            parallel_iterations=parallel_iterations,
-            swap_memory=swap_memory
-        )
+        encoded = self.encoder([inputs["inputs"], inputs["inputs_length"]], training=False)
+        encoded_length = math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor)
+        return self._perform_greedy_batch(encoded=encoded, encoded_length=encoded_length)

     def recognize_tflite(self, signal, predicted, prediction_states):
         """

@@ -161,12 +155,7 @@ def recognize_tflite_with_timestamp(self, signal, predicted, states):
     # -------------------------------- BEAM SEARCH -------------------------------------

     @tf.function
-    def recognize_beam(self,
-                       features: tf.Tensor,
-                       input_length: tf.Tensor,
-                       lm: bool = False,
-                       parallel_iterations: int = 10,
-                       swap_memory: bool = True):
+    def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False):
         """
         RNN Transducer Beam Search
         Args:

@@ -176,9 +165,6 @@ def recognize_beam(self,
         Returns:
             tf.Tensor: a batch of decoded transcripts
         """
-        encoded = self.encoder([features, input_length], training=False)
-        return self._perform_beam_search_batch(
-            encoded, input_length, lm,
-            parallel_iterations=parallel_iterations,
-            swap_memory=swap_memory
-        )
+        encoded = self.encoder([inputs["inputs"], inputs["inputs_length"]], training=False)
+        encoded_length = math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor)
+        return self._perform_beam_search_batch(encoded=encoded, encoded_length=encoded_length, lm=lm)
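Both decoding paths now derive the encoded length from the raw input length instead of threading it through as an argument. The diff does not show `math_util.get_reduced_length` itself; assuming it performs the usual ceiling division by the time-reduction factor, it would behave like:

```python
import tensorflow as tf

def get_reduced_length(length, reduction_factor):
    # Assumed semantics: ceil(length / reduction_factor), cast back to int32.
    return tf.cast(tf.math.ceil(tf.divide(length, reduction_factor)), tf.int32)

print(get_reduced_length(tf.constant([100, 42]), 4).numpy())  # [25 11]
```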

tensorflow_asr/models/transducer/rnn_transducer.py

Lines changed: 11 additions & 19 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 """ http://arxiv.org/abs/1811.06621 """

+from typing import Dict
 import tensorflow as tf

 from ..layers.subsampling import TimeReduction

@@ -256,11 +257,7 @@ def encoder_inference(self, features: tf.Tensor, states: tf.Tensor):
     # -------------------------------- GREEDY -------------------------------------

     @tf.function
-    def recognize(self,
-                  features: tf.Tensor,
-                  input_length: tf.Tensor,
-                  parallel_iterations: int = 10,
-                  swap_memory: bool = True):
+    def recognize(self, inputs: Dict[str, tf.Tensor]):
         """
         RNN Transducer Greedy decoding
         Args:

@@ -269,10 +266,10 @@ def recognize(self,
         Returns:
             tf.Tensor: a batch of decoded transcripts
         """
-        batch_size, _, _, _ = shape_util.shape_list(features)
-        encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state(batch_size))
-        return self._perform_greedy_batch(encoded, input_length,
-                                          parallel_iterations=parallel_iterations, swap_memory=swap_memory)
+        batch_size, _, _, _ = shape_util.shape_list(inputs["inputs"])
+        encoded, _ = self.encoder.recognize(inputs["inputs"], self.encoder.get_initial_state(batch_size))
+        encoded_length = math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor)
+        return self._perform_greedy_batch(encoded=encoded, encoded_length=encoded_length)

     def recognize_tflite(self, signal, predicted, encoder_states, prediction_states):
         """

@@ -321,12 +318,7 @@ def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, prediction_states):
     # -------------------------------- BEAM SEARCH -------------------------------------

     @tf.function
-    def recognize_beam(self,
-                       features: tf.Tensor,
-                       input_length: tf.Tensor,
-                       lm: bool = False,
-                       parallel_iterations: int = 10,
-                       swap_memory: bool = True):
+    def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False):
         """
         RNN Transducer Beam Search
         Args:

@@ -336,10 +328,10 @@ def recognize_beam(self,
         Returns:
             tf.Tensor: a batch of decoded transcripts
         """
-        batch_size, _, _, _ = shape_util.shape_list(features)
-        encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state(batch_size))
-        return self._perform_beam_search_batch(encoded, input_length, lm,
-                                               parallel_iterations=parallel_iterations, swap_memory=swap_memory)
+        batch_size, _, _, _ = shape_util.shape_list(inputs["inputs"])
+        encoded, _ = self.encoder.recognize(inputs["inputs"], self.encoder.get_initial_state(batch_size))
+        encoded_length = math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor)
+        return self._perform_beam_search_batch(encoded=encoded, encoded_length=encoded_length, lm=lm)

     # -------------------------------- TFLITE -------------------------------------
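Together with the matching changes in the CTC and ContextNet files above, every model now exposes the same decoding signature, which is what lets `predict_step` in base_model.py stay model-agnostic. An illustrative summary of the shared contract (not an actual class in the repo):

```python
from typing import Dict
import tensorflow as tf

class RecognizerInterface:
    """Decoding interface shared by the CTC and transducer models after this commit."""

    def recognize(self, inputs: Dict[str, tf.Tensor]) -> tf.Tensor:
        """Greedy decoding; expects keys 'inputs' and 'inputs_length'."""
        raise NotImplementedError

    def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False) -> tf.Tensor:
        """Beam-search decoding over the same input dict."""
        raise NotImplementedError
```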