Commit e318348

fix(model): support saved model and tflite conversion
1 parent: ba6ab51

File tree: 9 files changed (+150, -42 lines)


README.md

Lines changed: 2 additions & 2 deletions

@@ -120,9 +120,9 @@ docker-compose up -d
 
 - For _training, testing and using_ **CTC Models**, run `./scripts/install_ctc_decoders.sh`
 
-- For _training_ **Transducer Models** with RNNT Loss from [warp-transducer](https://github.com/HawkAaron/warp-transducer), run `export CUDA_HOME=/usr/local/cuda && ./scripts/install_rnnt_loss.sh` (**Note**: only `export CUDA_HOME` when you have CUDA)
+- For _training_ **Transducer Models** with RNNT Loss in TF, make sure that [warp-transducer](https://github.com/HawkAaron/warp-transducer) **is not installed** (by simply run `pip3 uninstall warprnnt-tensorflow`) (**Recommended**)
 
-- For _training_ **Transducer Models** with RNNT Loss in TF, make sure that [warp-transducer](https://github.com/HawkAaron/warp-transducer) **is not installed** (by simply run `pip3 uninstall warprnnt-tensorflow`)
+- For _training_ **Transducer Models** with RNNT Loss from [warp-transducer](https://github.com/HawkAaron/warp-transducer), run `export CUDA_HOME=/usr/local/cuda && ./scripts/install_rnnt_loss.sh` (**Note**: only `export CUDA_HOME` when you have CUDA)
 
 - For _mixed precision training_, use flag `--mxp` when running python scripts from [examples](./examples)
 

examples/conformer/saved_model.py renamed to examples/conformer/inference/gen_saved_model.py

Lines changed: 21 additions & 11 deletions

@@ -69,14 +69,24 @@
 conformer.add_featurizers(speech_featurizer, text_featurizer)
 
 
-# TODO: Support saved model conversion
-# class ConformerModule(tf.Module):
-#     def __init__(self, model: Conformer, name=None):
-#         super().__init__(name=name)
-#         self.model = model
-#         self.pred = model.make_tflite_function()
-
-
-# model = ConformerModule(model=conformer)
-# tf.saved_model.save(model, args.output_dir)
-conformer.save(args.output_dir, include_optimizer=False, save_format="tf")
+class ConformerModule(tf.Module):
+    def __init__(self, model: Conformer, name=None):
+        super().__init__(name=name)
+        self.model = model
+        self.num_rnns = config.model_config["prediction_num_rnns"]
+        self.rnn_units = config.model_config["prediction_rnn_units"]
+        self.rnn_nstates = 2 if config.model_config["prediction_rnn_type"] == "lstm" else 1
+
+    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
+    def pred(self, signal):
+        predicted = tf.constant(0, dtype=tf.int32)
+        states = tf.zeros([self.num_rnns, self.rnn_nstates, 1, self.rnn_units], dtype=tf.float32)
+        features = self.model.speech_featurizer.tf_extract(signal)
+        encoded = self.model.encoder_inference(features)
+        hypothesis = self.model._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states, tflite=False)
+        transcript = self.model.text_featurizer.indices2upoints(hypothesis.prediction)
+        return transcript
+
+
+module = ConformerModule(model=conformer)
+tf.saved_model.save(module, export_dir=args.output_dir, signatures=module.pred.get_concrete_function())
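
For reference, a minimal sketch (not part of this commit) of loading the exported module and invoking its `pred` signature; the paths are hypothetical, and the same flow appears in the runner script added later in this commit:

import tensorflow as tf
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio

# Hypothetical paths, for illustration only
module = tf.saved_model.load(export_dir="/path/to/exported_conformer")
signal = read_raw_audio("/path/to/audio.wav")  # 1-D float32 waveform
transcript = module.pred(signal)               # Unicode code points
print("".join(chr(u) for u in transcript))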

examples/conformer/tflite.py renamed to examples/conformer/inference/gen_tflite_model.py

Lines changed: 5 additions & 10 deletions

@@ -27,30 +27,25 @@
 DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
 
 tf.keras.backend.clear_session()
+tf.compat.v1.enable_control_flow_v2()
 
 parser = argparse.ArgumentParser(prog="Conformer TFLite")
 
 parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
 
-parser.add_argument("--saved", type=str, default=None, help="Path to saved model")
+parser.add_argument("--h5", type=str, default=None, help="Path to saved model")
 
-parser.add_argument("--subwords", action="store_true", help="Use subwords")
-
-parser.add_argument("--vocabulary", type=str, default=None, required=False,
-                    help="Path to vocabulary. Overrides path in config, if given.")
+parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
 
 parser.add_argument("output", type=str, default=None, help="TFLite file path to be exported")
 
 args = parser.parse_args()
 
-assert args.saved and args.output
+assert args.h5 and args.output
 
 config = Config(args.config)
 speech_featurizer = TFSpeechFeaturizer(config.speech_config)
 
-if args.vocabulary is not None:
-    config.decoder_config["vocabulary"] = args.vocabulary
-
 if args.subwords:
     text_featurizer = SubwordFeaturizer(config.decoder_config)
 else:
@@ -59,7 +54,7 @@
 # build model
 conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
 conformer.make(speech_featurizer.shape)
-conformer.load_weights(args.saved, by_name=True)
+conformer.load_weights(args.h5, by_name=True)
 conformer.summary(line_length=100)
 conformer.add_featurizers(speech_featurizer, text_featurizer)
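
The hunks shown above stop before the conversion itself. A hedged sketch of how the remainder of the script presumably produces the TFLite file, assuming the model's `make_tflite_function()` (the helper referenced in the commented-out code this commit removes from `gen_saved_model.py`) is the conversion entry point; the converter settings here are illustrative, not taken from this diff:

concrete_func = conformer.make_tflite_function().get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.experimental_new_converter = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # standard TFLite kernels
    tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TF ops where needed
]
tflite_model = converter.convert()

with open(args.output, "wb") as f:
    f.write(tflite_model)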

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+from tensorflow_asr.utils import env_util
+
+logger = env_util.setup_environment()
+import tensorflow as tf
+
+DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
+
+tf.keras.backend.clear_session()
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--saved_model", type=str, default=None, help="The file path of saved model")
+
+parser.add_argument("filename", type=str, default=None, help="Audio file path")
+
+args = parser.parse_args()
+
+from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
+
+module = tf.saved_model.load(export_dir=args.saved_model)
+
+signal = read_raw_audio(args.filename)
+transcript = module.pred(signal)
+
+print("Transcript: ", "".join([chr(u) for u in transcript]))

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import tensorflow as tf
+
+from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("filename", metavar="FILENAME", help="Audio file to be played back")
+
+parser.add_argument("--tflite", type=str, default=None, help="Path to conformer tflite")
+
+parser.add_argument("--blank", type=int, default=0, help="Blank index")
+
+parser.add_argument("--num_rnns", type=int, default=1, help="Number of RNN layers in prediction network")
+
+parser.add_argument("--nstates", type=int, default=2, help="Number of RNN states in prediction network")
+
+parser.add_argument("--statesize", type=int, default=320, help="Size of RNN state in prediction network")
+
+args = parser.parse_args()
+
+tflitemodel = tf.lite.Interpreter(model_path=args.tflite)
+
+signal = read_raw_audio(args.filename)
+
+input_details = tflitemodel.get_input_details()
+output_details = tflitemodel.get_output_details()
+tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
+tflitemodel.allocate_tensors()
+tflitemodel.set_tensor(input_details[0]["index"], signal)
+tflitemodel.set_tensor(input_details[1]["index"], tf.constant(args.blank, dtype=tf.int32))
+tflitemodel.set_tensor(input_details[2]["index"], tf.zeros([args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32))
+tflitemodel.invoke()
+hyp = tflitemodel.get_tensor(output_details[0]["index"])
+
+print("".join([chr(u) for u in hyp]))

examples/demonstration/conformer.py

Lines changed: 6 additions & 4 deletions

@@ -14,7 +14,7 @@
 
 import os
 import argparse
-from tensorflow_asr.utils import env_util, math_util
+from tensorflow_asr.utils import env_util
 
 logger = env_util.setup_environment()
 import tensorflow as tf
@@ -79,12 +79,14 @@
     logger.info(f"Transcript: {transcript[0].numpy().decode('UTF-8')}")
 elif args.timestamp:
     transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp(
-        signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state())
+        signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state()
+    )
     logger.info(f"Transcript: {transcript}")
     logger.info(f"Start time: {stime}")
     logger.info(f"End time: {etime}")
 else:
     code_points, _, _ = conformer.recognize_tflite(
-        signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state())
-    transcript = tf.strings.unicode_encode(code_points, 'UTF-8').numpy().decode('UTF-8')
+        signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state()
+    )
+    transcript = tf.strings.unicode_encode(code_points, "UTF-8").numpy().decode("UTF-8")
     logger.info(f"Transcript: {transcript}")

setup.py

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ def parse_requirements(lines: List[str]):
 
 setup(
     name="TensorFlowASR",
-    version="1.0.3",
+    version="1.0.2",
     author="Huy Le Nguyen",
     author_email="[email protected]",
     description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/models/layers/embedding.py

Lines changed: 6 additions & 8 deletions

@@ -32,10 +32,7 @@ def __init__(
         self.regularizer = tf.keras.regularizers.get(regularizer)
         self.initializer = tf.keras.initializers.get(initializer)
 
-    def build(
-        self,
-        input_shape,
-    ):
+    def build(self, input_shape):
         self.embeddings = self.add_weight(
             name="embeddings",
             dtype=tf.float32,
@@ -47,10 +44,11 @@ def build(
         )
         self.built = True
 
-    def call(
-        self,
-        inputs,
-    ):
+    def call(self, inputs):
+        outputs = tf.cast(inputs, dtype=tf.int32)
+        return tf.nn.embedding_lookup(self.embeddings, outputs)
+
+    def recognize_tflite(self, inputs):
         outputs = tf.cast(tf.expand_dims(inputs, axis=-1), dtype=tf.int32)
         return tf.gather_nd(self.embeddings, outputs)
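
The two lookup paths return the same vectors for a single predicted token; the split presumably exists because the `tf.gather_nd` form converts more cleanly in the TFLite graph. A minimal standalone sketch (not part of the commit, with made-up sizes) illustrating the equivalence:

import tensorflow as tf

vocab_size, embed_dim = 8, 4
embeddings = tf.random.normal([vocab_size, embed_dim])
ids = tf.constant([[3]], dtype=tf.int32)  # shape [1, 1], one predicted token

# Training/eval path (Embedding.call): batched embedding lookup
a = tf.nn.embedding_lookup(embeddings, ids)                 # [1, 1, embed_dim]

# TFLite inference path (Embedding.recognize_tflite): gather_nd over index tuples
b = tf.gather_nd(embeddings, tf.expand_dims(ids, axis=-1))  # [1, 1, embed_dim]

assert bool(tf.reduce_all(tf.equal(a, b)))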

tensorflow_asr/models/transducer/base_transducer.py

Lines changed: 17 additions & 6 deletions

@@ -104,7 +104,7 @@ def call(self, inputs, training=False, **kwargs):
             outputs = rnn["projection"](outputs, training=training)
         return outputs
 
-    def recognize(self, inputs, states):
+    def recognize(self, inputs, states, tflite: bool = False):
         """Recognize function for prediction network
 
         Args:
@@ -115,7 +115,10 @@ def recognize(self, inputs, states):
             tf.Tensor: outputs with shape [1, 1, P]
             tf.Tensor: new states with shape [num_lstms, 2, 1, P]
         """
-        outputs = self.embed(inputs, training=False)
+        if tflite:
+            outputs = self.embed.recognize_tflite(inputs)
+        else:
+            outputs = self.embed(inputs, training=False)
         outputs = self.do(outputs, training=False)
         new_states = []
         for i, rnn in enumerate(self.rnns):
@@ -390,6 +393,7 @@ def decoder_inference(
         encoded: tf.Tensor,
         predicted: tf.Tensor,
         states: tf.Tensor,
+        tflite: bool = False,
     ):
         """Infer function for decoder
 
@@ -404,7 +408,7 @@ def decoder_inference(
         with tf.name_scope(f"{self.name}_decoder"):
             encoded = tf.reshape(encoded, [1, 1, -1])  # [E] => [1, 1, E]
             predicted = tf.reshape(predicted, [1, 1])  # [] => [1, 1]
-            y, new_states = self.predict_net.recognize(predicted, states)  # [1, 1, P], states
+            y, new_states = self.predict_net.recognize(predicted, states, tflite=tflite)  # [1, 1, P], states
             ytu = tf.nn.log_softmax(self.joint_net([encoded, y], training=False))  # [1, 1, V]
             ytu = tf.reshape(ytu, shape=[-1])  # [1, 1, V] => [V]
             return ytu, new_states
@@ -455,7 +459,7 @@ def recognize_tflite(
        """
        features = self.speech_featurizer.tf_extract(signal)
        encoded = self.encoder_inference(features)
-       hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states)
+       hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states, tflite=True)
        transcript = self.text_featurizer.indices2upoints(hypothesis.prediction)
        return transcript, hypothesis.index, hypothesis.states
 
@@ -467,7 +471,7 @@ def recognize_tflite_with_timestamp(
    ):
        features = self.speech_featurizer.tf_extract(signal)
        encoded = self.encoder_inference(features)
-       hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states)
+       hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states, tflite=True)
        indices = self.text_featurizer.normalize_indices(hypothesis.prediction)
        upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1))  # [None, max_subword_length]
 
@@ -540,6 +544,7 @@ def _perform_greedy(
         states: tf.Tensor,
         parallel_iterations: int = 10,
         swap_memory: bool = False,
+        tflite: bool = False,
     ):
         with tf.name_scope(f"{self.name}_greedy"):
             time = tf.constant(0, dtype=tf.int32)
@@ -566,6 +571,7 @@ def body(_time, _hypothesis):
                     encoded=tf.gather_nd(encoded, tf.reshape(_time, shape=[1])),
                     predicted=_hypothesis.index,
                     states=_hypothesis.states,
+                    tflite=tflite,
                 )
                 _predict = tf.argmax(ytu, axis=-1, output_type=tf.int32)  # => argmax []
 
@@ -605,6 +611,7 @@ def _perform_greedy_v2(
         states: tf.Tensor,
         parallel_iterations: int = 10,
         swap_memory: bool = False,
+        tflite: bool = False,
     ):
         """Ref: https://arxiv.org/pdf/1801.00841.pdf"""
         with tf.name_scope(f"{self.name}_greedy_v2"):
@@ -632,6 +639,7 @@ def body(_time, _hypothesis):
                     encoded=tf.gather_nd(encoded, tf.reshape(_time, shape=[1])),
                     predicted=_hypothesis.index,
                     states=_hypothesis.states,
+                    tflite=tflite,
                 )
                 _predict = tf.argmax(ytu, axis=-1, output_type=tf.int32)  # => argmax []
 
@@ -736,6 +744,7 @@ def _perform_beam_search(
         lm: bool = False,
         parallel_iterations: int = 10,
         swap_memory: bool = True,
+        tflite: bool = False,
     ):
         with tf.name_scope(f"{self.name}_beam_search"):
             beam_width = tf.cond(
@@ -834,7 +843,9 @@ def beam_body(beam, beam_width, A, A_i, B):
                     )
                     A_i = tf.cond(tf.equal(A_i, 0), true_fn=lambda: A_i, false_fn=lambda: A_i - 1)
 
-                    ytu, new_states = self.decoder_inference(encoded=encoded_t, predicted=y_hat_index, states=y_hat_states)
+                    ytu, new_states = self.decoder_inference(
+                        encoded=encoded_t, predicted=y_hat_index, states=y_hat_states, tflite=tflite
+                    )
 
                     def predict_condition(pred, A, A_i, B):
                         return tf.less(pred, self.text_featurizer.num_classes)
