Commit db66008

fix: tflite, initial states, results
1 parent 0c19911 commit db66008

15 files changed: +113 -71 lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
@@ -125,6 +125,7 @@ disable=too-few-public-methods,
 abstract-method,
 too-many-ancestors,
 import-outside-toplevel,
+too-many-positional-arguments,
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option

README.md

Lines changed: 2 additions & 1 deletion
@@ -147,7 +147,7 @@ See [tflite_convertion](./docs/tutorials/tflite.md)
 
 ## Pretrained Models
 
-Go to [drive](https://drive.google.com/drive/folders/1BD0AK30n8hc-yR28C5FW3LqzZxtLOQfl?usp=sharing)
+See the results on each example folder, e.g. [./examples/models//transducer/conformer/results/sentencepiece/README.md](./examples/models//transducer/conformer/results/sentencepiece/README.md)
 
 ## Corpus Sources
 
@@ -165,6 +165,7 @@ Go to [drive](https://drive.google.com/drive/folders/1BD0AK30n8hc-yR28C5FW3LqzZx
 | Vivos | [https://ailab.hcmus.edu.vn/vivos](https://www.kaggle.com/datasets/kynthesis/vivos-vietnamese-speech-corpus-for-asr) | 15h |
 | InfoRe Technology 1 | [InfoRe1 (passwd: BroughtToYouByInfoRe)](https://files.huylenguyen.com/datasets/infore/25hours.zip) | 25h |
 | InfoRe Technology 2 (used in VLSP2019) | [InfoRe2 (passwd: BroughtToYouByInfoRe)](https://files.huylenguyen.com/datasets/infore/audiobooks.zip) | 415h |
+| VieitBud500 | [https://huggingface.co/datasets/linhtran92/viet_bud500](https://huggingface.co/datasets/linhtran92/viet_bud500) | 500h |
 
 ## How to contribute

Lines changed: 43 additions & 20 deletions
@@ -1,18 +1,20 @@
 - [\[English\] LibriSpeech](#english-librispeech)
 - [I. Small + SentencePiece 256](#i-small--sentencepiece-256)
+- [II. Small + Streaming + SentencePiece 256](#ii-small--streaming--sentencepiece-256)
 
 # [English] LibriSpeech
 
 ## I. Small + SentencePiece 256
 
-| Category | Description |
-| :---------------- | :--------------------------------------------------------- |
-| Config | [small.yml.j2](../../small.yml.j2) |
-| Tensorflow | **2.18.0** |
-| Device | Google Cloud TPUs v4-8 |
-| Mixed Precision | strict |
-| Global Batch Size | 8 * 4 * 8 = 256 (as 4 TPUs, 8 Gradient Accumulation Steps) |
-| Max Epochs | 450 |
+| Category | Description |
+| :---------------- | :--------------------------------------------------------------------------------------- |
+| Config | [small.yml.j2](../../small.yml.j2) |
+| Tensorflow | **2.18.0** |
+| Device | Google Cloud TPUs v4-8 |
+| Mixed Precision | strict |
+| Global Batch Size | 8 * 4 * 8 = 256 (as 4 TPUs, 8 Gradient Accumulation Steps) |
+| Max Epochs | 450 |
+| Pretrained | [Link](https://www.kaggle.com/models/lordh9072/tfasr-conformer-ctc/tensorFlow2/v3-small) |
 
 **Config:**

@@ -30,17 +32,18 @@
 | 170 | test-clean | greedy | 0.0967171 | 0.031954 | 0.0958403 | 0.168307 | 0.831693 |
 | 170 | test-other | greedy | 0.201612 | 0.0812955 | 0.197415 | 0.330207 | 0.669793 |
 
-<!--
+
 ## II. Small + Streaming + SentencePiece 256
 
-| Category | Description |
-| :---------------- | :--------------------------------------------------------- |
-| Config | [small-streaming.yml.j2](../../small-streaming.yml.j2) |
-| Tensorflow | **2.18.0** |
-| Device | Google Cloud TPUs v4-8 |
-| Mixed Precision | strict |
-| Global Batch Size | 8 * 4 * 8 = 256 (as 4 TPUs, 8 Gradient Accumulation Steps) |
-| Max Epochs | 450 |
+| Category | Description |
+| :---------------- | :------------------------------------------------------------------------------------------------- |
+| Config | [small-streaming.yml.j2](../../small-streaming.yml.j2) |
+| Tensorflow | **2.18.0** |
+| Device | Google Cloud TPUs v4-8 |
+| Mixed Precision | strict |
+| Global Batch Size | 8 * 4 * 8 = 256 (as 4 TPUs, 8 Gradient Accumulation Steps) |
+| Max Epochs | 450 |
+| Pretrained | [Link](https://www.kaggle.com/models/lordh9072/tfasr-conformer-ctc/tensorFlow2/v3-small-streaming) |
 
 **Config:**
 
@@ -51,8 +54,28 @@
 {{config}}
 ```
 
+**Tensorboard:**
+
+<table>
+<tr>
+<td align="center">
+<img src="./figs/librispeech-small-streaming-epoch-loss.jpg" width="200px"><br>
+<sub><strong>Epoch Loss</strong></sub>
+</td>
+<td align="center">
+<img src="./figs/librispeech-small-streaming-batch-loss.jpg" width="200px"><br>
+<sub><strong>Batch Loss</strong></sub>
+</td>
+<td align="center">
+<img src="./figs/librispeech-small-streaming-lr.jpg " width="200px"><br>
+<sub><strong>Learning Rate</strong></sub>
+</td>
+</tr>
+</table>
+
 **Results:**
 
-| Epoch | Dataset | decoding | wer | cer | mer | wil | wip |
-| :---- | :------ | :------- | :--- | :--- | :--- | :--- | :--- |
--->
+| Epoch | Dataset | decoding | wer | cer | mer | wil | wip |
+| :---- | :--------- | :------- | :-------- | :-------- | :-------- | :------ | :------ |
+| 60 | test-clean | greedy | 0.0848106 | 0.0286257 | 0.0841686 | 0.14896 | 0.85104 |
+| 60 | test-other | greedy | 0.217221 | 0.0913044 | 0.213409 | 0.3555 | 0.6445 |

Binary files changed (3 images, not rendered): 70.2 KB, 64.2 KB, 70.6 KB

examples/models/transducer/conformer/results/sentencepiece/README.md

Lines changed: 32 additions & 31 deletions
@@ -10,14 +10,15 @@
 
 ## I. Small + SentencePiece 1k
 
-| Category | Description |
-| :---------------- | :--------------------------------------------------------- |
-| Config | [small.yml.j2](../../small.yml.j2) |
-| Tensorflow | **2.18.0** |
-| Device | Google Cloud TPUs v4-8 |
-| Mixed Precision | strict |
-| Global Batch Size | 4 * 4 * 8 = 128 (as 4 TPUs, 8 Gradient Accumulation Steps) |
-| Max Epochs | 300 |
+| Category | Description |
+| :---------------- | :---------------------------------------------------------------------------------------------- |
+| Config | [small.yml.j2](../../small.yml.j2) |
+| Tensorflow | **2.18.0** |
+| Device | Google Cloud TPUs v4-8 |
+| Mixed Precision | strict |
+| Global Batch Size | 4 * 4 * 8 = 128 (as 4 TPUs, 8 Gradient Accumulation Steps) |
+| Max Epochs | 300 |
+| Pretrained | [Link](https://www.kaggle.com/models/lordh9072/tfasr-conformer-transducer/tensorFlow2/v3-small) |
 
 **Config:**

@@ -37,14 +38,15 @@
 
 ## II. Small + Streaming + SentencePiece 1k
 
-| Category | Description |
-| :---------------- | :--------------------------------------------------------- |
-| Config | [small-streaming.yml.j2](../../small-streaming.yml.j2) |
-| Tensorflow | **2.18.0** |
-| Device | Google Cloud TPUs v4-8 |
-| Mixed Precision | strict |
-| Global Batch Size | 4 * 4 * 8 = 128 (as 4 TPUs, 8 Gradient Accumulation Steps) |
-| Max Epochs | 300 |
+| Category | Description |
+| :---------------- | :-------------------------------------------------------------------------------------------------------- |
+| Config | [small-streaming.yml.j2](../../small-streaming.yml.j2) |
+| Tensorflow | **2.18.0** |
+| Device | Google Cloud TPUs v4-8 |
+| Mixed Precision | strict |
+| Global Batch Size | 4 * 4 * 8 = 128 (as 4 TPUs, 8 Gradient Accumulation Steps) |
+| Max Epochs | 300 |
+| Pretrained | [Link](https://www.kaggle.com/models/lordh9072/tfasr-conformer-transducer/tensorFlow2/v3-small-streaming) |
 
 **Config:**

@@ -57,25 +59,26 @@
 
 **Results:**
 
-| Epoch | Dataset | decoding | wer | cer | mer | wil | wip |
-| :---- | :--------- | :------- | :------- | :-------- | :------- | :------- | :------- |
-| 45 | test-clean | greedy | 0.110564 | 0.0460022 | 0.109064 | 0.186109 | 0.813891 |
-| 45 | test-other | greedy | 0.267772 | 0.139369 | 0.260952 | 0.417361 | 0.582639 |
+| Epoch | Dataset | decoding | wer | cer | mer | wil | wip |
+| :---- | :--------- | :------- | :-------- | :-------- | :-------- | :------- | :------- |
+| 45 | test-clean | greedy | 0.0797322 | 0.0312862 | 0.0790049 | 0.137228 | 0.862772 |
+| 45 | test-other | greedy | 0.211872 | 0.104173 | 0.207305 | 0.341269 | 0.658731 |
 
 <!-- ----------------------------------------------------- VN ------------------------------------------------------ -->
 
 # [Vietnamese] VietBud500
 
 ## I. Small + Streaming + SentencePiece 1k
 
-| Category | Description |
-| :---------------- | :--------------------------------------------------------- |
-| Config | [small-streaming.yml.j2](../../small-streaming.yml.j2) |
-| Tensorflow | **2.18.0** |
-| Device | Google Cloud TPUs v4-8 |
-| Mixed Precision | strict |
-| Global Batch Size | 8 * 4 * 8 = 256 (as 4 TPUs, 8 Gradient Accumulation Steps) |
-| Max Epochs | 300 |
+| Category | Description |
+| :---------------- | :---------------------------------------------------------------------------------------------------------------- |
+| Config | [small-streaming.yml.j2](../../small-streaming.yml.j2) |
+| Tensorflow | **2.18.0** |
+| Device | Google Cloud TPUs v4-8 |
+| Mixed Precision | strict |
+| Global Batch Size | 8 * 4 * 8 = 256 (as 4 TPUs, 8 Gradient Accumulation Steps) |
+| Max Epochs | 300 |
+| Pretrained | [Link](https://www.kaggle.com/models/lordh9072/tfasr-vietbud500-conformer-transducer/tensorFlow2/small-streaming) |
 
 **Config:**

@@ -109,6 +112,4 @@
 
 | Epoch | decoding | wer | cer | mer | wil | wip |
 | :---- | :------- | :------- | :------- | :------ | :------- | :------- |
-| 52 | greedy | 0.053723 | 0.034548 | 0.05362 | 0.086421 | 0.913579 |
-
-**Pretrained Model**: [Link](https://www.kaggle.com/models/lordh9072/tfasr-vietbud500-conformer-transducer/tensorFlow2/small-streaming)
+| 52 | greedy | 0.053723 | 0.034548 | 0.05362 | 0.086421 | 0.913579 |

tensorflow_asr/models/base_model.py

Lines changed: 4 additions & 4 deletions
@@ -317,10 +317,10 @@ def get_initial_tokens(self, batch_size=1):
         return tf.ones([batch_size, 1], dtype=tf.int32) * self.tokenizer.blank
 
     def get_initial_encoder_states(self, batch_size=1):
-        return None
+        return []
 
     def get_initial_decoder_states(self, batch_size=1):
-        return None
+        return []
 
     def recognize(self, inputs: schemas.PredictInput, **kwargs) -> schemas.PredictOutput:
         """Greedy decoding function that used in self.predict_step"""
@@ -351,8 +351,8 @@ def tflite_func(inputs: schemas.PredictInput):
             inputs=tf.TensorSpec([batch_size, None], dtype=tf.float32),
             inputs_length=tf.TensorSpec([batch_size], dtype=tf.int32),
             previous_tokens=tf.TensorSpec.from_tensor(self.get_initial_tokens(batch_size)),
-            previous_encoder_states=tf.TensorSpec.from_tensor(self.get_initial_encoder_states(batch_size)),
-            previous_decoder_states=tf.TensorSpec.from_tensor(self.get_initial_decoder_states(batch_size)),
+            previous_encoder_states=tf.nest.map_structure(tf.TensorSpec.from_tensor, self.get_initial_encoder_states(batch_size)),
+            previous_decoder_states=tf.nest.map_structure(tf.TensorSpec.from_tensor, self.get_initial_decoder_states(batch_size)),
         )
 
         return tf.function(
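
Reviewer note: the hunk above swaps a direct `tf.TensorSpec.from_tensor(...)` call for `tf.nest.map_structure(tf.TensorSpec.from_tensor, ...)`, and the initial-state getters now return `[]` instead of `None`. The sketch below is a pure-Python stand-in, not the TensorFlow implementation. `FakeTensor`, `FakeSpec`, `spec_from_tensor` and this simplified `map_structure` are hypothetical names introduced only to illustrate the idea: a structure-mapping helper visits zero leaves for an empty list, while `None` is treated as a leaf and handed to the strict converter.

```python
from typing import Any, Callable


class FakeTensor:
    """Hypothetical stand-in for a tensor; carries only a shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)


class FakeSpec:
    """Hypothetical stand-in for a tensor spec; records only a shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)


def spec_from_tensor(tensor: Any) -> FakeSpec:
    """Strict converter (assumption): rejects anything that is not tensor-like."""
    if not hasattr(tensor, "shape"):
        raise ValueError(f"expected a tensor-like object, got {tensor!r}")
    return FakeSpec(tensor.shape)


def map_structure(fn: Callable[[Any], Any], structure: Any) -> Any:
    """Simplified structure map: recurse into lists/tuples, apply fn to leaves."""
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, item) for item in structure)
    return fn(structure)


# A model without streaming state: the empty list maps to an empty list.
empty_specs = map_structure(spec_from_tensor, [])

# A model with per-layer states: the nesting is preserved.
nested_states = [[FakeTensor((1, 4))], [FakeTensor((1, 8))]]
nested_specs = map_structure(spec_from_tensor, nested_states)
nested_shapes = [[spec.shape for spec in layer] for layer in nested_specs]

# None is a leaf, so the strict converter is called on it and fails.
try:
    map_structure(spec_from_tensor, None)
    none_accepted = True
except ValueError:
    none_accepted = False
```

Under these assumptions, returning `[]` gives stateless models a valid (empty) signature entry, whereas `None` would reach the converter and raise.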

tensorflow_asr/models/ctc/base_ctc.py

Lines changed: 2 additions & 2 deletions
@@ -92,10 +92,10 @@ def call_next(
         return outputs, outputs_length, next_encoder_states, next_decoder_states
 
     def get_initial_encoder_states(self, batch_size=1):
-        return None
+        return []
 
     def get_initial_decoder_states(self, batch_size=1):
-        return None
+        return []
 
     # -------------------------------- GREEDY -------------------------------------

tensorflow_asr/models/encoders/conformer.py

Lines changed: 6 additions & 3 deletions
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" http://arxiv.org/abs/2005.08100 """
+"""http://arxiv.org/abs/2005.08100"""
 
 from tensorflow_asr import keras, tf
 from tensorflow_asr.models.activations.glu import GLU
@@ -21,6 +21,7 @@
 from tensorflow_asr.models.layers.multihead_attention import MultiHeadAttention, MultiHeadRelativeAttention
 from tensorflow_asr.models.layers.positional_encoding import RelativeSinusoidalPositionalEncoding, SinusoidalPositionalEncoding
 from tensorflow_asr.models.layers.residual import Residual
+from tensorflow_asr.utils import data_util
 
 L2 = keras.regularizers.l2(1e-6)

@@ -664,7 +665,9 @@ def __init__(
         self.content_attention_bias, self.positional_attention_bias = None, None
 
     def get_initial_state(self, batch_size: int):
-        return [block.get_initial_state(batch_size) for block in self.conformer_blocks]
+        states = [block.get_initial_state(batch_size) for block in self.conformer_blocks]
+        states = [s for s in states if s is not None]
+        return states
 
     def call(
         self,
@@ -684,7 +687,7 @@ def call(
                 (outputs, relative_position_encoding),
                 content_attention_bias=self.content_attention_bias,
                 positional_attention_bias=self.positional_attention_bias,
-                initial_state=None if initial_state is None else initial_state[i],
+                initial_state=data_util.get(initial_state, i, None),
                 training=training,
                 use_causal_mask=self._use_attention_causal_mask,
                 use_auto_mask=self._use_attention_auto_mask,
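
Reviewer note: the two conformer.py hunks above work together. `get_initial_state` now drops `None` entries, so a non-streaming encoder yields `[]`, and the per-block lookup goes through `data_util.get(initial_state, i, None)` instead of indexing directly. The body of `data_util.get` is not part of this diff. The sketch below assumes it behaves like a bounds-safe lookup with a default. `safe_get` and `collect_states` are hypothetical names for illustration, not the library's functions.

```python
from typing import Any, List, Optional, Sequence


def safe_get(items: Optional[Sequence[Any]], index: int, default: Any = None) -> Any:
    """Hypothetical helper: items[index], or default if items is None or too short."""
    if items is None:
        return default
    if index < 0 or index >= len(items):
        return default
    return items[index]


def collect_states(block_states: List[Optional[Any]]) -> List[Any]:
    """Mirror of the filtering step in the diff: keep only non-None states."""
    return [state for state in block_states if state is not None]


# Three stateless blocks: every per-block state is None, so the result is [].
stateless = collect_states([None, None, None])

# Indexing [] directly would raise IndexError; the safe lookup returns None.
per_block_stateless = [safe_get(stateless, i) for i in range(3)]

# Streaming blocks: states survive the filter and are found by index.
streaming = collect_states(["state-0", "state-1"])
per_block_streaming = [safe_get(streaming, i) for i in range(2)]

# The previous calling convention (initial_state=None) keeps working.
from_none = safe_get(None, 0)
```

With a lookup of this kind, the call site accepts `None`, an empty list, or a full list of states without a separate branch for each case.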
