Skip to content

Commit f54de55

Browse files
committed
fix: update inference scripts for jasper + ds2 + rnn transducer
1 parent 5801018 commit f54de55

File tree

17 files changed

+724
-1077
lines changed

17 files changed

+724
-1077
lines changed

examples/conformer/inference/run_saved_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,5 @@ def main(
4040
print("Transcript: ", "".join([chr(u) for u in transcript]))
4141

4242

43-
if __name__ == '__main__':
44-
fire.Fire(main)
43+
if __name__ == "__main__":
44+
fire.Fire(main)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2020 Huy Le Nguyen (@usimarit)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import fire
17+
from tensorflow_asr.utils import env_util
18+
19+
logger = env_util.setup_environment()
20+
import tensorflow as tf
21+
22+
from tensorflow_asr.configs.config import Config
23+
from tensorflow_asr.models.transducer.contextnet import ContextNet
24+
from tensorflow_asr.helpers import exec_helpers, featurizer_helpers
25+
26+
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
27+
28+
29+
def main(
30+
config: str = DEFAULT_YAML,
31+
h5: str = None,
32+
subwords: bool = False,
33+
sentence_piece: bool = False,
34+
output: str = None,
35+
):
36+
assert h5 and output
37+
tf.keras.backend.clear_session()
38+
tf.compat.v1.enable_control_flow_v2()
39+
40+
config = Config(config)
41+
speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
42+
config=config,
43+
subwords=subwords,
44+
sentence_piece=sentence_piece,
45+
)
46+
47+
contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
48+
contextnet.make(speech_featurizer.shape)
49+
contextnet.load_weights(h5, by_name=True)
50+
contextnet.summary(line_length=100)
51+
contextnet.add_featurizers(speech_featurizer, text_featurizer)
52+
53+
exec_helpers.convert_tflite(model=contextnet, output=output)
54+
55+
56+
if __name__ == "__main__":
57+
fire.Fire(main)

examples/contextnet/test.py

Lines changed: 46 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -13,97 +13,58 @@
1313
# limitations under the License.
1414

1515
import os
16-
from tqdm import tqdm
17-
import argparse
18-
from tensorflow_asr.utils import env_util, file_util
16+
import fire
17+
from tensorflow_asr.utils import env_util
1918

2019
logger = env_util.setup_environment()
2120
import tensorflow as tf
2221

23-
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
24-
25-
tf.keras.backend.clear_session()
26-
27-
parser = argparse.ArgumentParser(prog="Contextnet Testing")
28-
29-
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
30-
31-
parser.add_argument("--saved", type=str, default=None, help="Path to saved model")
32-
33-
parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
34-
35-
parser.add_argument("--bs", type=int, default=None, help="Test batch size")
36-
37-
parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
38-
39-
parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
40-
41-
parser.add_argument("--device", type=int, default=0, help="Device's id to run test on")
42-
43-
parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu")
44-
45-
parser.add_argument("--output", type=str, default="test.tsv", help="Result filepath")
46-
47-
args = parser.parse_args()
48-
49-
assert args.saved
50-
51-
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
52-
53-
env_util.setup_devices([args.device], cpu=args.cpu)
54-
5522
from tensorflow_asr.configs.config import Config
56-
from tensorflow_asr.datasets.asr_dataset import ASRSliceDataset
57-
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
58-
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer, CharFeaturizer
5923
from tensorflow_asr.models.transducer.contextnet import ContextNet
60-
from tensorflow_asr.utils import app_util
24+
from tensorflow_asr.helpers import exec_helpers, dataset_helpers, featurizer_helpers
6125

62-
config = Config(args.config)
63-
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
64-
65-
if args.sentence_piece:
66-
logger.info("Use SentencePiece ...")
67-
text_featurizer = SentencePieceFeaturizer(config.decoder_config)
68-
elif args.subwords:
69-
logger.info("Use subwords ...")
70-
text_featurizer = SubwordFeaturizer(config.decoder_config)
71-
else:
72-
logger.info("Use characters ...")
73-
text_featurizer = CharFeaturizer(config.decoder_config)
74-
75-
tf.random.set_seed(0)
76-
77-
test_dataset = ASRSliceDataset(
78-
speech_featurizer=speech_featurizer,
79-
text_featurizer=text_featurizer,
80-
**vars(config.learning_config.test_dataset_config)
81-
)
82-
83-
# build model
84-
contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
85-
contextnet.make(speech_featurizer.shape)
86-
contextnet.load_weights(args.saved, by_name=True)
87-
contextnet.summary(line_length=100)
88-
contextnet.add_featurizers(speech_featurizer, text_featurizer)
26+
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
8927

90-
batch_size = args.bs or config.learning_config.running_config.batch_size
91-
test_data_loader = test_dataset.create(batch_size)
9228

93-
with file_util.save_file(file_util.preprocess_paths(args.output)) as filepath:
94-
overwrite = True
95-
if tf.io.gfile.exists(filepath):
96-
overwrite = input(f"Overwrite existing result file {filepath} ? (y/n): ").lower() == "y"
97-
if overwrite:
98-
results = contextnet.predict(test_data_loader, verbose=1)
99-
logger.info(f"Saving result to {args.output} ...")
100-
with open(filepath, "w") as openfile:
101-
openfile.write("PATH\tDURATION\tGROUNDTRUTH\tGREEDY\tBEAMSEARCH\n")
102-
progbar = tqdm(total=test_dataset.total_steps, unit="batch")
103-
for i, pred in enumerate(results):
104-
groundtruth, greedy, beamsearch = [x.decode('utf-8') for x in pred]
105-
path, duration, _ = test_dataset.entries[i]
106-
openfile.write(f"{path}\t{duration}\t{groundtruth}\t{greedy}\t{beamsearch}\n")
107-
progbar.update(1)
108-
progbar.close()
109-
app_util.evaluate_results(filepath)
29+
def main(
30+
config: str = DEFAULT_YAML,
31+
saved: str = None,
32+
mxp: bool = False,
33+
bs: int = None,
34+
sentence_piece: bool = False,
35+
subwords: bool = False,
36+
device: int = 0,
37+
cpu: bool = False,
38+
output: str = "test.tsv",
39+
):
40+
assert saved and output
41+
tf.random.set_seed(0)
42+
tf.keras.backend.clear_session()
43+
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": mxp})
44+
env_util.setup_devices([device], cpu=cpu)
45+
46+
config = Config(config)
47+
48+
speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
49+
config=config,
50+
subwords=subwords,
51+
sentence_piece=sentence_piece,
52+
)
53+
54+
contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
55+
contextnet.make(speech_featurizer.shape)
56+
contextnet.load_weights(saved, by_name=True)
57+
contextnet.summary(line_length=100)
58+
contextnet.add_featurizers(speech_featurizer, text_featurizer)
59+
60+
test_dataset = dataset_helpers.prepare_testing_datasets(
61+
config=config, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer
62+
)
63+
batch_size = bs or config.learning_config.running_config.batch_size
64+
test_data_loader = test_dataset.create(batch_size)
65+
66+
exec_helpers.run_testing(model=contextnet, test_dataset=test_dataset, test_data_loader=test_data_loader, output=output)
67+
68+
69+
if __name__ == "__main__":
70+
fire.Fire(main)

examples/contextnet/tflite.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

0 commit comments

Comments
 (0)