Skip to content

Commit 937ef5a

Browse files
authored
Merge pull request #233 from TensorSpeech/fix/models
Support saved model conversion for transducer models
2 parents a76569e + 196c68b commit 937ef5a

File tree

10 files changed

+218
-66
lines changed

10 files changed

+218
-66
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ docker-compose up -d
120120

121121
- For _training, testing and using_ **CTC Models**, run `./scripts/install_ctc_decoders.sh`
122122

123-
- For _training_ **Transducer Models** with RNNT Loss from [warp-transducer](https://github.com/HawkAaron/warp-transducer), run `export CUDA_HOME=/usr/local/cuda && ./scripts/install_rnnt_loss.sh` (**Note**: only `export CUDA_HOME` when you have CUDA)
123+
- For _training_ **Transducer Models** with RNNT Loss in TF, make sure that [warp-transducer](https://github.com/HawkAaron/warp-transducer) **is not installed** (by simply run `pip3 uninstall warprnnt-tensorflow`) (**Recommended**)
124124

125-
- For _training_ **Transducer Models** with RNNT Loss in TF, make sure that [warp-transducer](https://github.com/HawkAaron/warp-transducer) **is not installed** (by simply run `pip3 uninstall warprnnt-tensorflow`)
125+
- For _training_ **Transducer Models** with RNNT Loss from [warp-transducer](https://github.com/HawkAaron/warp-transducer), run `export CUDA_HOME=/usr/local/cuda && ./scripts/install_rnnt_loss.sh` (**Note**: only `export CUDA_HOME` when you have CUDA)
126126

127127
- For _mixed precision training_, use flag `--mxp` when running python scripts from [examples](./examples)
128128

examples/conformer/saved_model.py renamed to examples/conformer/inference/gen_saved_model.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,24 @@
6969
conformer.add_featurizers(speech_featurizer, text_featurizer)
7070

7171

72-
# TODO: Support saved model conversion
73-
# class ConformerModule(tf.Module):
74-
# def __init__(self, model: Conformer, name=None):
75-
# super().__init__(name=name)
76-
# self.model = model
77-
# self.pred = model.make_tflite_function()
78-
79-
80-
# model = ConformerModule(model=conformer)
81-
# tf.saved_model.save(model, args.output_dir)
82-
conformer.save(args.output_dir, include_optimizer=False, save_format="tf")
72+
class ConformerModule(tf.Module):
73+
def __init__(self, model: Conformer, name=None):
74+
super().__init__(name=name)
75+
self.model = model
76+
self.num_rnns = config.model_config["prediction_num_rnns"]
77+
self.rnn_units = config.model_config["prediction_rnn_units"]
78+
self.rnn_nstates = 2 if config.model_config["prediction_rnn_type"] == "lstm" else 1
79+
80+
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
81+
def pred(self, signal):
82+
predicted = tf.constant(0, dtype=tf.int32)
83+
states = tf.zeros([self.num_rnns, self.rnn_nstates, 1, self.rnn_units], dtype=tf.float32)
84+
features = self.model.speech_featurizer.tf_extract(signal)
85+
encoded = self.model.encoder_inference(features)
86+
hypothesis = self.model._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states, tflite=False)
87+
transcript = self.model.text_featurizer.indices2upoints(hypothesis.prediction)
88+
return transcript
89+
90+
91+
module = ConformerModule(model=conformer)
92+
tf.saved_model.save(module, export_dir=args.output_dir, signatures=module.pred.get_concrete_function())

examples/conformer/tflite.py renamed to examples/conformer/inference/gen_tflite_model.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,30 +27,25 @@
2727
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
2828

2929
tf.keras.backend.clear_session()
30+
tf.compat.v1.enable_control_flow_v2()
3031

3132
parser = argparse.ArgumentParser(prog="Conformer TFLite")
3233

3334
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
3435

35-
parser.add_argument("--saved", type=str, default=None, help="Path to saved model")
36+
parser.add_argument("--h5", type=str, default=None, help="Path to saved model")
3637

37-
parser.add_argument("--subwords", action="store_true", help="Use subwords")
38-
39-
parser.add_argument("--vocabulary", type=str, default=None, required=False,
40-
help="Path to vocabulary. Overrides path in config, if given.")
38+
parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
4139

4240
parser.add_argument("output", type=str, default=None, help="TFLite file path to be exported")
4341

4442
args = parser.parse_args()
4543

46-
assert args.saved and args.output
44+
assert args.h5 and args.output
4745

4846
config = Config(args.config)
4947
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
5048

51-
if args.vocabulary is not None:
52-
config.decoder_config["vocabulary"] = args.vocabulary
53-
5449
if args.subwords:
5550
text_featurizer = SubwordFeaturizer(config.decoder_config)
5651
else:
@@ -59,7 +54,7 @@
5954
# build model
6055
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
6156
conformer.make(speech_featurizer.shape)
62-
conformer.load_weights(args.saved, by_name=True)
57+
conformer.load_weights(args.h5, by_name=True)
6358
conformer.summary(line_length=100)
6459
conformer.add_featurizers(speech_featurizer, text_featurizer)
6560

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2020 Huy Le Nguyen (@usimarit)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import os
17+
18+
from tensorflow_asr.utils import env_util
19+
20+
logger = env_util.setup_environment()
21+
import tensorflow as tf
22+
23+
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
24+
25+
tf.keras.backend.clear_session()
26+
27+
parser = argparse.ArgumentParser()
28+
29+
parser.add_argument("--saved_model", type=str, default=None, help="The file path of saved model")
30+
31+
parser.add_argument("filename", type=str, default=None, help="Audio file path")
32+
33+
args = parser.parse_args()
34+
35+
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
36+
37+
module = tf.saved_model.load(export_dir=args.saved_model)
38+
39+
signal = read_raw_audio(args.filename)
40+
transcript = module.pred(signal)
41+
42+
print("Transcript: ", "".join([chr(u) for u in transcript]))
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2020 Huy Le Nguyen (@usimarit)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import tensorflow as tf
17+
18+
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
19+
20+
parser = argparse.ArgumentParser()
21+
22+
parser.add_argument("filename", metavar="FILENAME", help="Audio file to be played back")
23+
24+
parser.add_argument("--tflite", type=str, default=None, help="Path to conformer tflite")
25+
26+
parser.add_argument("--blank", type=int, default=0, help="Blank index")
27+
28+
parser.add_argument("--num_rnns", type=int, default=1, help="Number of RNN layers in prediction network")
29+
30+
parser.add_argument("--nstates", type=int, default=2, help="Number of RNN states in prediction network")
31+
32+
parser.add_argument("--statesize", type=int, default=320, help="Size of RNN state in prediction network")
33+
34+
args = parser.parse_args()
35+
36+
tflitemodel = tf.lite.Interpreter(model_path=args.tflite)
37+
38+
signal = read_raw_audio(args.filename)
39+
40+
input_details = tflitemodel.get_input_details()
41+
output_details = tflitemodel.get_output_details()
42+
tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
43+
tflitemodel.allocate_tensors()
44+
tflitemodel.set_tensor(input_details[0]["index"], signal)
45+
tflitemodel.set_tensor(input_details[1]["index"], tf.constant(args.blank, dtype=tf.int32))
46+
tflitemodel.set_tensor(input_details[2]["index"], tf.zeros([args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32))
47+
tflitemodel.invoke()
48+
hyp = tflitemodel.get_tensor(output_details[0]["index"])
49+
50+
print("".join([chr(u) for u in hyp]))

examples/demonstration/conformer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import os
1616
import argparse
17-
from tensorflow_asr.utils import env_util, math_util
17+
from tensorflow_asr.utils import env_util
1818

1919
logger = env_util.setup_environment()
2020
import tensorflow as tf
@@ -79,12 +79,14 @@
7979
logger.info(f"Transcript: {transcript[0].numpy().decode('UTF-8')}")
8080
elif args.timestamp:
8181
transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp(
82-
signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state())
82+
signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state()
83+
)
8384
logger.info(f"Transcript: {transcript}")
8485
logger.info(f"Start time: {stime}")
8586
logger.info(f"End time: {etime}")
8687
else:
8788
code_points, _, _ = conformer.recognize_tflite(
88-
signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state())
89-
transcript = tf.strings.unicode_encode(code_points, 'UTF-8').numpy().decode('UTF-8')
89+
signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state()
90+
)
91+
transcript = tf.strings.unicode_encode(code_points, "UTF-8").numpy().decode("UTF-8")
9092
logger.info(f"Transcript: {transcript}")

requirements.txt

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,45 @@ librosa==0.8.1
88
PyYAML==5.4.1
99
Pillow==8.3.2
1010
black==21.7b0
11-
flake8==3.9.2
11+
flake8==3.9.2
12+
sounddevice==0.4.3
13+
14+
# extra=tf2.3
15+
tensorflow~=2.3.0
16+
tensorflow-text~=2.3.0
17+
tensorflow-io~=0.16.0
18+
19+
# extra=tf2.3-gpu
20+
tensorflow-gpu~=2.3.0
21+
tensorflow-text~=2.3.0
22+
tensorflow-io~=0.16.0
23+
24+
# extra=tf2.4
25+
tensorflow~=2.4.0
26+
tensorflow-text~=2.4.0
27+
tensorflow-io~=0.17.0
28+
29+
# extra=tf2.4-gpu
30+
tensorflow-gpu~=2.4.0
31+
tensorflow-text~=2.4.0
32+
tensorflow-io~=0.17.0
33+
34+
# extra=tf2.5
35+
tensorflow~=2.5.0
36+
tensorflow-text~=2.5.0
37+
tensorflow-io~=0.18.0
38+
39+
# extra=tf2.5-gpu
40+
tensorflow-gpu~=2.5.0
41+
tensorflow-text~=2.5.0
42+
tensorflow-io~=0.18.0
43+
44+
# extra=tf2.6
45+
tensorflow~=2.6.0
46+
tensorflow-text~=2.6.0
47+
tensorflow-io~=0.20.0
48+
49+
# extra=tf2.6-gpu
50+
tensorflow-gpu~=2.6.0
51+
tensorflow-text~=2.6.0
52+
tensorflow-io~=0.20.0

setup.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,34 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
15+
from setuptools import find_packages, setup
16+
from typing import List
17+
from collections import defaultdict
1618

17-
import setuptools
1819

19-
readme_path = os.path.join(os.path.dirname(__file__), "README.md")
20+
def parse_requirements(lines: List[str]):
21+
extras_requires = defaultdict(list)
22+
extra = "requires"
23+
for line in lines:
24+
line = line.strip()
25+
if line.startswith("# extra="):
26+
extra = line.split("=")[1].strip()
27+
continue
28+
if line and line[0] != "#":
29+
lib_package = line.split("#")[0].strip() # split comments
30+
extras_requires[extra].append(lib_package)
31+
install_requires = extras_requires.pop("requires")
32+
return install_requires, extras_requires
2033

21-
with open(readme_path, "r", encoding="utf-8") as fh:
22-
long_description = fh.read()
2334

24-
requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
35+
with open("requirements.txt", "r", encoding="utf-8") as fr:
36+
install_requires, extras_requires = parse_requirements(fr.readlines())
2537

26-
with open(requirements_path, "r") as fr:
27-
requirements = fr.read().splitlines()
28-
print(requirements)
38+
with open("README.md", "r", encoding="utf-8") as fh:
39+
long_description = fh.read()
2940

3041

31-
setuptools.setup(
42+
setup(
3243
name="TensorFlowASR",
3344
version="1.0.2",
3445
author="Huy Le Nguyen",
@@ -37,26 +48,18 @@
3748
long_description=long_description,
3849
long_description_content_type="text/markdown",
3950
url="https://github.com/TensorSpeech/TensorFlowASR",
40-
packages=setuptools.find_packages(include=["tensorflow_asr*"]),
41-
install_requires=requirements,
42-
extras_require={
43-
"tf2.3": ["tensorflow~=2.3.0", "tensorflow-text~=2.3.0", "tensorflow-io~=0.16.0"],
44-
"tf2.3-gpu": ["tensorflow-gpu~=2.3.0", "tensorflow-text~=2.3.0", "tensorflow-io~=0.16.0"],
45-
"tf2.4": ["tensorflow~=2.4.0", "tensorflow-text~=2.4.0", "tensorflow-io~=0.17.0"],
46-
"tf2.4-gpu": ["tensorflow-gpu~=2.4.0", "tensorflow-text~=2.4.0", "tensorflow-io~=0.17.0"],
47-
"tf2.5": ["tensorflow~=2.5.0", "tensorflow-text~=2.5.0", "tensorflow-io~=0.18.0"],
48-
"tf2.5-gpu": ["tensorflow-gpu~=2.5.0", "tensorflow-text~=2.5.0", "tensorflow-io~=0.18.0"],
49-
"tf2.6": ["tensorflow~=2.6.0", "tensorflow-text~=2.6.0rc0", "tensorflow-io~=0.20.0"],
50-
"tf2.6-gpu": ["tensorflow-gpu~=2.6.0", "tensorflow-text~=2.6.0rc0", "tensorflow-io~=0.20.0"],
51-
},
51+
packages=find_packages(include=("tensorflow_asr", "tensorflow_asr.*")),
52+
install_requires=install_requires,
53+
extras_require=extras_requires,
5254
classifiers=[
5355
"Programming Language :: Python :: 3.6",
5456
"Programming Language :: Python :: 3.7",
5557
"Programming Language :: Python :: 3.8",
58+
"Programming Language :: Python :: 3.9",
5659
"Intended Audience :: Science/Research",
5760
"Operating System :: POSIX :: Linux",
5861
"License :: OSI Approved :: Apache Software License",
5962
"Topic :: Software Development :: Libraries :: Python Modules",
6063
],
61-
python_requires=">=3.6",
64+
python_requires=">=3.6, <4",
6265
)

tensorflow_asr/models/layers/embedding.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,7 @@ def __init__(
3232
self.regularizer = tf.keras.regularizers.get(regularizer)
3333
self.initializer = tf.keras.initializers.get(initializer)
3434

35-
def build(
36-
self,
37-
input_shape,
38-
):
35+
def build(self, input_shape):
3936
self.embeddings = self.add_weight(
4037
name="embeddings",
4138
dtype=tf.float32,
@@ -47,12 +44,13 @@ def build(
4744
)
4845
self.built = True
4946

50-
def call(
51-
self,
52-
inputs,
53-
):
47+
def call(self, inputs):
48+
outputs = tf.cast(inputs, dtype=tf.int32)
49+
return tf.nn.embedding_lookup(self.embeddings, outputs)
50+
51+
def recognize_tflite(self, inputs):
5452
outputs = tf.cast(tf.expand_dims(inputs, axis=-1), dtype=tf.int32)
55-
return tf.gather_nd(self.embeddings, outputs)
53+
return tf.gather_nd(self.embeddings, outputs) # https://github.com/tensorflow/tensorflow/issues/42410
5654

5755
def get_config(self):
5856
conf = super(Embedding, self).get_config()

0 commit comments

Comments
 (0)