Skip to content

Commit 3393571

Browse files
authored
Merge pull request #48 from TensorSpeech/dev/example
Update demonstration example
2 parents 1785335 + 47176ae commit 3393571

File tree

5 files changed

+173
-36
lines changed

5 files changed

+173
-36
lines changed

README.md

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,28 @@ TensorFlowASR implements some automatic speech recognition architectures such as
3030
- Support `transducer` tflite greedy decoding (conversion and invocation)
3131
- Distributed training using `tf.distribute.MirroredStrategy`
3232

33+
## Table of Contents
34+
<!-- TOC -->
35+
36+
- [What's New?](#whats-new)
37+
- [Table of Contents](#table-of-contents)
38+
- [:yum: Supported Models](#yum-supported-models)
39+
- [Installation](#installation)
40+
- [Installing via PyPi](#installing-via-pypi)
41+
- [Installing from source](#installing-from-source)
42+
- [Setup training and testing](#setup-training-and-testing)
43+
- [TFLite Convertion](#tflite-convertion)
44+
- [Features Extraction](#features-extraction)
45+
- [Augmentations](#augmentations)
46+
- [Training & Testing](#training--testing)
47+
- [Corpus Sources and Pretrained Models](#corpus-sources-and-pretrained-models)
48+
- [English](#english)
49+
- [Vietnamese](#vietnamese)
50+
- [German](#german)
51+
- [References & Credits](#references--credits)
52+
53+
<!-- /TOC -->
54+
3355
## :yum: Supported Models
3456

3557
- **CTCModel** (End2end models using CTC Loss for training)
@@ -43,26 +65,44 @@ TensorFlowASR implements some automatic speech recognition architectures such as
4365
- **Streaming Transducer** (Reference: [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621))
4466
See [examples/streaming_transducer](./examples/streaming_transducer)
4567

46-
## Setup Environment and Datasets
68+
## Installation
4769

48-
Install tensorflow: `pip3 install -U tensorflow` or `pip3 install tf-nightly` (for using tflite)
70+
For training and testing, you should use `git clone` for installing necessary packages from other authors (`ctc_decoders`, `rnnt_loss`, etc.)
4971

50-
Install packages (choose _one_ of these options):
72+
### Installing via PyPi
5173

52-
- Run `pip3 install -U TensorFlowASR`
53-
- Clone the repo and run `python3 setup.py install` in the repo's directory
74+
Run `pip3 install -U TensorFlowASR`
5475

55-
For **setting up datasets**, see [datasets](./tensorflow_asr/datasets/README.md)
76+
### Installing from source
5677

57-
- For _training, testing and using_ **CTC Models**, run `./scripts/install_ctc_decoders.sh`
78+
```bash
79+
git clone https://github.com/TensorSpeech/TensorFlowASR.git
80+
cd TensorFlowASR
81+
python3 setup.py install
82+
```
5883

59-
- For _training_ **Transducer Models**, export `CUDA_HOME` and run `./scripts/install_rnnt_loss.sh`
84+
For anaconda3:
85+
86+
```bash
87+
conda create -y -n tfasr tensorflow-gpu python=3.7 # tensorflow if using CPU
88+
conda activate tfasr
89+
pip install -U tensorflow-gpu # upgrade to latest version of tensorflow
90+
git clone https://github.com/TensorSpeech/TensorFlowASR.git
91+
cd TensorFlowASR
92+
python setup.py install
93+
```
94+
95+
## Setup training and testing
96+
97+
- For datasets, see [datasets](./tensorflow_asr/datasets/README.md)
98+
99+
- For _training, testing and using_ **CTC Models**, run `./scripts/install_ctc_decoders.sh`
60100

61-
- Method `tensorflow_asr.utils.setup_environment()` enable **mixed_precision** if available.
101+
- For _training_ **Transducer Models**, run `export CUDA_HOME=/usr/local/cuda && ./scripts/install_rnnt_loss.sh` (**Note**: only `export CUDA_HOME` when you have CUDA)
62102

63-
- To enable XLA, run `TF_XLA_FLAGS=--tf_xla_auto_jit=2 $python_train_script`
103+
- For _mixed precision training_, use flag `--mxp` when running python scripts from [examples](./examples)
64104

65-
Clean up: `python3 setup.py clean --all` (this will remove `/build` contents)
105+
- For _enabling XLA_, run `TF_XLA_FLAGS=--tf_xla_auto_jit=2 python3 $path_to_py_script`)
66106

67107
## TFLite Convertion
68108

examples/demonstration/conformer.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,44 +13,65 @@
1313
# limitations under the License.
1414

1515
import argparse
16-
import tensorflow as tf
16+
from tensorflow_asr.utils import setup_environment, setup_devices
1717

18-
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
18+
setup_environment()
19+
import tensorflow as tf
1920

2021
parser = argparse.ArgumentParser(prog="Conformer non streaming")
2122

2223
parser.add_argument("filename", metavar="FILENAME",
2324
help="audio file to be played back")
2425

25-
parser.add_argument("--tflite", type=str, default=None,
26-
help="Path to conformer tflite")
26+
parser.add_argument("--config", type=str, default=None,
27+
help="Path to conformer config yaml")
28+
29+
parser.add_argument("--saved", type=str, default=None,
30+
help="Path to conformer saved h5 weights")
2731

2832
parser.add_argument("--blank", type=int, default=0,
2933
help="Path to conformer tflite")
3034

35+
parser.add_argument("--num_rnns", type=int, default=1,
36+
help="Number of RNN layers in prediction network")
37+
38+
parser.add_argument("--nstates", type=int, default=2,
39+
help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)")
40+
3141
parser.add_argument("--statesize", type=int, default=320,
32-
help="Path to conformer tflite")
42+
help="Size of RNN state in prediction network")
43+
44+
parser.add_argument("--device", type=int, default=0,
45+
help="Device's id to run test on")
46+
47+
parser.add_argument("--cpu", default=False, action="store_true",
48+
help="Whether to only use cpu")
3349

3450
args = parser.parse_args()
3551

36-
tflitemodel = tf.lite.Interpreter(model_path=args.tflite)
52+
setup_devices([args.device], cpu=args.cpu)
53+
54+
from tensorflow_asr.configs.config import Config
55+
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
56+
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
57+
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
58+
from tensorflow_asr.models.conformer import Conformer
59+
60+
config = Config(args.config, learning=False)
61+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
62+
text_featurizer = CharFeaturizer(config.decoder_config)
63+
64+
# build model
65+
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
66+
conformer._build(speech_featurizer.shape)
67+
conformer.load_weights(args.saved, by_name=True)
68+
conformer.summary(line_length=120)
69+
conformer.add_featurizers(speech_featurizer, text_featurizer)
3770

3871
signal = read_raw_audio(args.filename)
72+
predicted = tf.constant(args.blank, dtype=tf.int32)
73+
states = tf.zeros([args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32)
3974

40-
input_details = tflitemodel.get_input_details()
41-
output_details = tflitemodel.get_output_details()
42-
tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
43-
tflitemodel.allocate_tensors()
44-
tflitemodel.set_tensor(input_details[0]["index"], signal)
45-
tflitemodel.set_tensor(
46-
input_details[1]["index"],
47-
tf.constant(args.blank, dtype=tf.int32)
48-
)
49-
tflitemodel.set_tensor(
50-
input_details[2]["index"],
51-
tf.zeros([1, 2, 1, args.statesize], dtype=tf.float32)
52-
)
53-
tflitemodel.invoke()
54-
hyp = tflitemodel.get_tensor(output_details[0]["index"])
75+
hyp, _, _ = conformer.recognize_tflite(signal, predicted, states)
5576

5677
print("".join([chr(u) for u in hyp]))

examples/demonstration/streaming_conformer.py renamed to examples/demonstration/streaming_tflite_conformer.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,18 @@ def int_or_str(text):
5757
parser.add_argument("--tflite", type=str, default=None,
5858
help="Path to conformer tflite")
5959

60+
parser.add_argument("--blank", type=int, default=0,
61+
help="Path to conformer tflite")
62+
63+
parser.add_argument("--num_rnns", type=int, default=1,
64+
help="Number of RNN layers in prediction network")
65+
66+
parser.add_argument("--nstates", type=int, default=2,
67+
help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)")
68+
69+
parser.add_argument("--statesize", type=int, default=320,
70+
help="Size of RNN state in prediction network")
71+
6072
args = parser.parse_args(remaining)
6173

6274
if args.blocksize == 0:
@@ -92,8 +104,8 @@ def recognize(signal, lastid, states):
92104
text = "".join([chr(u) for u in upoints])
93105
return text, lastid, states
94106

95-
lastid = np.zeros(shape=[], dtype=np.int32)
96-
states = np.zeros(shape=[1, 2, 1, 320], dtype=np.float32)
107+
lastid = args.blank * np.ones(shape=[], dtype=np.int32)
108+
states = np.zeros(shape=[args.num_rnns, args.nstates, 1, args.statesize], dtype=np.float32)
97109
transcript = ""
98110

99111
while True:
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2020 Huy Le Nguyen (@usimarit)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import tensorflow as tf
17+
18+
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
19+
20+
parser = argparse.ArgumentParser(prog="Conformer non streaming")
21+
22+
parser.add_argument("filename", metavar="FILENAME",
23+
help="Audio file to be played back")
24+
25+
parser.add_argument("--tflite", type=str, default=None,
26+
help="Path to conformer tflite")
27+
28+
parser.add_argument("--blank", type=int, default=0,
29+
help="Blank index")
30+
31+
parser.add_argument("--num_rnns", type=int, default=1,
32+
help="Number of RNN layers in prediction network")
33+
34+
parser.add_argument("--nstates", type=int, default=2,
35+
help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)")
36+
37+
parser.add_argument("--statesize", type=int, default=320,
38+
help="Size of RNN state in prediction network")
39+
40+
args = parser.parse_args()
41+
42+
tflitemodel = tf.lite.Interpreter(model_path=args.tflite)
43+
44+
signal = read_raw_audio(args.filename)
45+
46+
input_details = tflitemodel.get_input_details()
47+
output_details = tflitemodel.get_output_details()
48+
tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
49+
tflitemodel.allocate_tensors()
50+
tflitemodel.set_tensor(input_details[0]["index"], signal)
51+
tflitemodel.set_tensor(
52+
input_details[1]["index"],
53+
tf.constant(args.blank, dtype=tf.int32)
54+
)
55+
tflitemodel.set_tensor(
56+
input_details[2]["index"],
57+
tf.zeros([args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32)
58+
)
59+
tflitemodel.invoke()
60+
hyp = tflitemodel.get_tensor(output_details[0]["index"])
61+
62+
print("".join([chr(u) for u in hyp]))

tensorflow_asr/utils/metrics.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from typing import Tuple
1416
import numpy as np
1517
import tensorflow as tf
1618
from nltk.metrics import distance
1719
from .utils import bytes_to_string
1820

1921

20-
def wer(decode: np.ndarray, target: np.ndarray) -> (tf.Tensor, tf.Tensor):
22+
def wer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
2123
"""Word Error Rate
2224
2325
Args:
@@ -43,7 +45,7 @@ def wer(decode: np.ndarray, target: np.ndarray) -> (tf.Tensor, tf.Tensor):
4345
return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)
4446

4547

46-
def cer(decode: np.ndarray, target: np.ndarray) -> (tf.Tensor, tf.Tensor):
48+
def cer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
4749
"""Character Error Rate
4850
4951
Args:

0 commit comments

Comments
 (0)