
Commit 6dd6ff3

Merge pull request #229 from TensorSpeech/fix/refactor
Bug fixings
2 parents: d4e1fea + c87fca6 (commit 6dd6ff3)

36 files changed: +437 additions, -527 deletions

README.md
Lines changed: 19 additions & 19 deletions

@@ -38,8 +38,8 @@ TensorFlowASR implements some automatic speech recognition architectures such as
 - [Baselines](#baselines)
 - [Publications](#publications)
 - [Installation](#installation)
+- [Installing from source (recommended)](#installing-from-source-recommended)
 - [Installing via PyPi](#installing-via-pypi)
-- [Installing from source](#installing-from-source)
 - [Running in a container](#running-in-a-container)
 - [Setup training and testing](#setup-training-and-testing)
 - [TFLite Convertion](#tflite-convertion)

@@ -59,42 +59,33 @@ TensorFlowASR implements some automatic speech recognition architectures such as

 ### Baselines

-- **CTCModel** (End2end models using CTC Loss for training, currently supported DeepSpeech2, Jasper)
 - **Transducer Models** (End2end models using RNNT Loss for training, currently supported Conformer, ContextNet, Streaming Transducer)
+- **CTCModel** (End2end models using CTC Loss for training, currently supported DeepSpeech2, Jasper)

 ### Publications

-- **Deep Speech 2** (Reference: [https://arxiv.org/abs/1512.02595](https://arxiv.org/abs/1512.02595))
-  See [examples/deepspeech2](./examples/deepspeech2)
-- **Jasper** (Reference: [https://arxiv.org/abs/1904.03288](https://arxiv.org/abs/1904.03288))
-  See [examples/jasper](./examples/jasper)
 - **Conformer Transducer** (Reference: [https://arxiv.org/abs/2005.08100](https://arxiv.org/abs/2005.08100))
   See [examples/conformer](./examples/conformer)
 - **Streaming Transducer** (Reference: [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621))
   See [examples/streaming_transducer](./examples/streaming_transducer)
 - **ContextNet** (Reference: [http://arxiv.org/abs/2005.03191](http://arxiv.org/abs/2005.03191))
   See [examples/contextnet](./examples/contextnet)
+- **Deep Speech 2** (Reference: [https://arxiv.org/abs/1512.02595](https://arxiv.org/abs/1512.02595))
+  See [examples/deepspeech2](./examples/deepspeech2)
+- **Jasper** (Reference: [https://arxiv.org/abs/1904.03288](https://arxiv.org/abs/1904.03288))
+  See [examples/jasper](./examples/jasper)

 ## Installation

 For training and testing, you should use `git clone` for installing necessary packages from other authors (`ctc_decoders`, `rnnt_loss`, etc.)

-### Installing via PyPi
-
-For tensorflow 2.3.x, run `pip3 install -U 'TensorFlowASR[tf2.3]'` or `pip3 install -U 'TensorFlowASR[tf2.3-gpu]'`
-
-For tensorflow 2.4.x, run `pip3 install -U 'TensorFlowASR[tf2.4]'` or `pip3 install -U 'TensorFlowASR[tf2.4-gpu]'`
-
-For tensorflow 2.5.x, run `pip3 install -U 'TensorFlowASR[tf2.5]'` or `pip3 install -U 'TensorFlowASR[tf2.5-gpu]'`
-
-For tensorflow 2.6.x, run `pip3 install -U 'TensorFlowASR[tf2.6]'` or `pip3 install -U 'TensorFlowASR[tf2.6-gpu]'`
-
-### Installing from source
+### Installing from source (recommended)

 ```bash
 git clone https://github.com/TensorSpeech/TensorFlowASR.git
 cd TensorFlowASR
-pip3 install -e '.[tf2.6]' # see other options in setup.py file
+# Tensorflow 2.x (with 2.x >= 2.3)
+pip3 install -e ".[tf2.x]" # or ".[tf2.x-gpu]"
 ```

 For anaconda3:

@@ -105,9 +96,18 @@ conda activate tfasr
 pip install -U tensorflow-gpu # upgrade to latest version of tensorflow
 git clone https://github.com/TensorSpeech/TensorFlowASR.git
 cd TensorFlowASR
-pip3 install '.[tf2.3]' # or '.[tf2.3-gpu]' or '.[tf2.4]' or '.[tf2.4-gpu]' or '.[tf2.5]' or '.[tf2.5-gpu]'
+# Tensorflow 2.x (with 2.x >= 2.3)
+pip3 install -e ".[tf2.x]" # or ".[tf2.x-gpu]"
 ```

+### Installing via PyPi
+
+```bash
+# Tensorflow 2.x (with 2.x >= 2.3)
+pip3 install -U "TensorFlowASR[tf2.x]" # or pip3 install -U "TensorFlowASR[tf2.x-gpu]"
+```
+
+
 ### Running in a container

 ```bash
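
For a quick check that an install from either route above succeeded, importing the package is usually enough; a minimal sketch, assuming the distribution installs the `tensorflow_asr` module (as in the repository layout) alongside TensorFlow 2.3 or newer:

```python
# Minimal post-install sanity check (sketch; module name `tensorflow_asr` is assumed).
import tensorflow as tf
import tensorflow_asr  # noqa: F401  # import succeeds only if installation worked

print("TensorFlow version:", tf.__version__)  # expected to be 2.x with 2.x >= 2.3
```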

examples/conformer/config.yml
Lines changed: 8 additions & 8 deletions

@@ -31,7 +31,7 @@ decoder_config:
 beam_width: 0
 norm_score: True
 corpus_files:
-- /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
+- /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv

 model_config:
 name: conformer

@@ -75,8 +75,8 @@ learning_config:
 num_masks: 1
 mask_factor: 27
 data_paths:
-- /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
-tfrecords_dir: /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/tfrecords_1030
+- /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
+tfrecords_dir: /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/tfrecords_1030
 shuffle: True
 cache: True
 buffer_size: 100

@@ -86,7 +86,7 @@ learning_config:
 eval_dataset_config:
 use_tf: True
 data_paths: null
-tfrecords_dir: /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/tfrecords_1030
+tfrecords_dir: /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/tfrecords_1030
 shuffle: False
 cache: True
 buffer_size: 100

@@ -113,13 +113,13 @@ learning_config:
 batch_size: 2
 num_epochs: 50
 checkpoint:
-filepath: /mnt/e/Models/local/conformer/checkpoints/{epoch:02d}.h5
-save_best_only: True
+filepath: /mnt/Miscellanea/Models/local/conformer/checkpoints/{epoch:02d}.h5
+save_best_only: False
 save_weights_only: True
 save_freq: epoch
-states_dir: /mnt/e/Models/local/conformer/states
+states_dir: /mnt/Miscellanea/Models/local/conformer/states
 tensorboard:
-log_dir: /mnt/e/Models/local/conformer/tensorboard
+log_dir: /mnt/Miscellanea/Models/local/conformer/tensorboard
 histogram_freq: 1
 write_graph: True
 write_images: True
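
The `save_best_only: False` change above (repeated in the other example configs below) means a weights file is written every epoch instead of only when the monitored metric improves. As a rough illustration only, not the project's actual config loader, this is how such a `checkpoint`/`tensorboard` block typically maps onto standard Keras callbacks:

```python
import tensorflow as tf

# Illustrative sketch: map the YAML `checkpoint` and `tensorboard` blocks above
# onto the Keras callbacks they correspond to. Paths are the example values
# from the config and would be adjusted per machine.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="/mnt/Miscellanea/Models/local/conformer/checkpoints/{epoch:02d}.h5",
    save_best_only=False,    # changed in this commit: keep every epoch's checkpoint
    save_weights_only=True,  # store weights only, not the full model
    save_freq="epoch",
)
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir="/mnt/Miscellanea/Models/local/conformer/tensorboard",
    histogram_freq=1,
    write_graph=True,
    write_images=True,
)
# These would then be passed to model.fit(..., callbacks=[checkpoint_cb, tensorboard_cb]).
```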

examples/conformer/saved_model.py
Lines changed: 18 additions & 52 deletions

@@ -26,40 +26,15 @@

 parser = argparse.ArgumentParser(prog="Conformer Testing")

-parser.add_argument(
-    "--config",
-    type=str,
-    default=DEFAULT_YAML,
-    help="The file path of model configuration file",
-)
-
-parser.add_argument(
-    "--h5",
-    type=str,
-    default=None,
-    help="Path to saved h5 weights",
-)
-
-parser.add_argument(
-    "--sentence_piece",
-    default=False,
-    action="store_true",
-    help="Whether to use `SentencePiece` model",
-)
-
-parser.add_argument(
-    "--subwords",
-    default=False,
-    action="store_true",
-    help="Use subwords",
-)
-
-parser.add_argument(
-    "--output_dir",
-    type=str,
-    default=None,
-    help="Output directory for saved model",
-)
+parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
+
+parser.add_argument("--h5", type=str, default=None, help="Path to saved h5 weights")
+
+parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
+
+parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
+
+parser.add_argument("--output_dir", type=str, default=None, help="Output directory for saved model")

 args = parser.parse_args()

@@ -94,23 +69,14 @@
 conformer.add_featurizers(speech_featurizer, text_featurizer)


-class aModule(tf.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.model = model
-
-    @tf.function(
-        input_signature=[
-            {
-                "inputs": tf.TensorSpec(shape=[None, None, 80, 1], dtype=tf.float32, name="inputs"),
-                "inputs_length": tf.TensorSpec(shape=[None], dtype=tf.int32, name="inputs_length"),
-            }
-        ]
-    )
-    def pred(self, input_batch):
-        result = self.model.recognize(input_batch)
-        return {"ASR": result}
+# TODO: Support saved model conversion
+# class ConformerModule(tf.Module):
+#     def __init__(self, model: Conformer, name=None):
+#         super().__init__(name=name)
+#         self.model = model
+#         self.pred = model.make_tflite_function()


-module = aModule(conformer)
-tf.saved_model.save(module, args.output_dir, signatures={"serving_default": module.pred})
+# model = ConformerModule(model=conformer)
+# tf.saved_model.save(model, args.output_dir)
+conformer.save(args.output_dir, include_optimizer=False, save_format="tf")
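
With this change the script exports a plain TensorFlow SavedModel via `conformer.save(...)` instead of wrapping the model in a custom `tf.Module`. A minimal sketch of reloading that export for inspection (the path below is a placeholder for whatever was passed as `--output_dir`):

```python
import tensorflow as tf

# Reload the directory written by `conformer.save(...)` above.
# "/path/to/saved_model" stands in for the --output_dir argument.
loaded = tf.saved_model.load("/path/to/saved_model")

# The restored object exposes the export's concrete functions/signatures;
# inspect them before calling, since their names depend on how Keras exported the model.
print(list(loaded.signatures.keys()))
```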

examples/contextnet/config.yml
Lines changed: 1 addition & 1 deletion

@@ -247,7 +247,7 @@ learning_config:
 num_epochs: 20
 checkpoint:
 filepath: /mnt/e/Models/local/contextnet/checkpoints/{epoch:02d}.h5
-save_best_only: True
+save_best_only: False
 save_weights_only: True
 save_freq: epoch
 states_dir: /mnt/e/Models/local/contextnet/states

examples/deepspeech2/config.yml
Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@ learning_config:
 num_epochs: 20
 checkpoint:
 filepath: /mnt/e/Models/local/deepspeech2/checkpoints/{epoch:02d}.h5
-save_best_only: True
+save_best_only: False
 save_weights_only: True
 save_freq: epoch
 states_dir: /mnt/e/Models/local/deepspeech2/states

examples/jasper/config.yml
Lines changed: 1 addition & 1 deletion

@@ -98,7 +98,7 @@ learning_config:
 num_epochs: 20
 checkpoint:
 filepath: /mnt/e/Models/local/jasper/checkpoints/{epoch:02d}.h5
-save_best_only: True
+save_best_only: False
 save_weights_only: True
 save_freq: epoch
 states_dir: /mnt/e/Models/local/jasper/states

examples/rnn_transducer/config.yml
Lines changed: 1 addition & 1 deletion

@@ -106,7 +106,7 @@ learning_config:
 num_epochs: 20
 checkpoint:
 filepath: /mnt/e/Models/local/rnn_transducer/checkpoints/{epoch:02d}.h5
-save_best_only: True
+save_best_only: False
 save_weights_only: True
 save_freq: epoch
 states_dir: /mnt/e/Models/local/rnn_transducer/states

notebooks/conformer.ipynb
Lines changed: 29 additions & 29 deletions

@@ -1,30 +1,4 @@
 {
-"metadata": {
-"language_info": {
-"codemirror_mode": {
-"name": "ipython",
-"version": 3
-},
-"file_extension": ".py",
-"mimetype": "text/x-python",
-"name": "python",
-"nbconvert_exporter": "python",
-"pygments_lexer": "ipython3",
-"version": "3.8.8-final"
-},
-"orig_nbformat": 2,
-"kernelspec": {
-"name": "python388jvsc74a57bd045f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f",
-"display_name": "Python 3.8.8 64-bit ('tfo': conda)"
-},
-"metadata": {
-"interpreter": {
-"hash": "45f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f"
-}
-}
-},
-"nbformat": 4,
-"nbformat_minor": 2,
 "cells": [
 {
 "cell_type": "code",

@@ -137,7 +111,7 @@
 " \"num_epochs\": 50,\n",
 " \"checkpoint\": {\n",
 " \"filepath\": \"/mnt/e/Models/local/conformer/checkpoints/{epoch:02d}.h5\",\n",
-" \"save_best_only\": True,\n",
+" \"save_best_only\": False,\n",
 " \"save_weights_only\": True,\n",
 " \"save_freq\": \"epoch\",\n",
 " },\n",

@@ -265,5 +239,31 @@
 "outputs": [],
 "source": []
 }
-]
-}
+],
+"metadata": {
+"kernelspec": {
+"display_name": "Python 3.8.8 64-bit ('tfo': conda)",
+"name": "python388jvsc74a57bd045f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f"
+},
+"language_info": {
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.8.8-final"
+},
+"metadata": {
+"interpreter": {
+"hash": "45f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f"
+}
+},
+"orig_nbformat": 2
+},
+"nbformat": 4,
+"nbformat_minor": 2
+}

notebooks/contextnet.ipynb
Lines changed: 29 additions & 29 deletions

@@ -1,30 +1,4 @@
 {
-"metadata": {
-"language_info": {
-"codemirror_mode": {
-"name": "ipython",
-"version": 3
-},
-"file_extension": ".py",
-"mimetype": "text/x-python",
-"name": "python",
-"nbconvert_exporter": "python",
-"pygments_lexer": "ipython3",
-"version": "3.8.8-final"
-},
-"orig_nbformat": 2,
-"kernelspec": {
-"name": "python388jvsc74a57bd045f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f",
-"display_name": "Python 3.8.8 64-bit ('tfo': conda)"
-},
-"metadata": {
-"interpreter": {
-"hash": "45f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f"
-}
-}
-},
-"nbformat": 4,
-"nbformat_minor": 2,
 "cells": [
 {
 "cell_type": "code",

@@ -308,7 +282,7 @@
 " \"num_epochs\": 20,\n",
 " \"checkpoint\": {\n",
 " \"filepath\": \"/mnt/e/Models/local/contextnet/checkpoints/{epoch:02d}.h5\",\n",
-" \"save_best_only\": True,\n",
+" \"save_best_only\": False,\n",
 " \"save_weights_only\": True,\n",
 " \"save_freq\": \"epoch\",\n",
 " },\n",

@@ -429,5 +403,31 @@
 ")"
 ]
 }
-]
-}
+],
+"metadata": {
+"kernelspec": {
+"display_name": "Python 3.8.8 64-bit ('tfo': conda)",
+"name": "python388jvsc74a57bd045f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f"
+},
+"language_info": {
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.8.8-final"
+},
+"metadata": {
+"interpreter": {
+"hash": "45f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f"
+}
+},
+"orig_nbformat": 2
+},
+"nbformat": 4,
+"nbformat_minor": 2
+}
