Skip to content

Commit 2a40a6c

Browse files
committed
fix(model): incorrect metrics
1 parent fed73be commit 2a40a6c

File tree

6 files changed

+51
-74
lines changed

6 files changed

+51
-74
lines changed

README.md

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ TensorFlowASR implements some automatic speech recognition architectures such as
3838
- [Baselines](#baselines)
3939
- [Publications](#publications)
4040
- [Installation](#installation)
41+
- [Installing from source (recommended)](#installing-from-source-recommended)
4142
- [Installing via PyPi](#installing-via-pypi)
42-
- [Installing from source](#installing-from-source)
4343
- [Running in a container](#running-in-a-container)
4444
- [Setup training and testing](#setup-training-and-testing)
4545
- [TFLite Convertion](#tflite-convertion)
@@ -59,42 +59,33 @@ TensorFlowASR implements some automatic speech recognition architectures such as
5959

6060
### Baselines
6161

62-
- **CTCModel** (End2end models using CTC Loss for training, currently supported DeepSpeech2, Jasper)
6362
- **Transducer Models** (End2end models using RNNT Loss for training, currently supported Conformer, ContextNet, Streaming Transducer)
63+
- **CTCModel** (End2end models using CTC Loss for training, currently supported DeepSpeech2, Jasper)
6464

6565
### Publications
6666

67-
- **Deep Speech 2** (Reference: [https://arxiv.org/abs/1512.02595](https://arxiv.org/abs/1512.02595))
68-
See [examples/deepspeech2](./examples/deepspeech2)
69-
- **Jasper** (Reference: [https://arxiv.org/abs/1904.03288](https://arxiv.org/abs/1904.03288))
70-
See [examples/jasper](./examples/jasper)
7167
- **Conformer Transducer** (Reference: [https://arxiv.org/abs/2005.08100](https://arxiv.org/abs/2005.08100))
7268
See [examples/conformer](./examples/conformer)
7369
- **Streaming Transducer** (Reference: [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621))
7470
See [examples/streaming_transducer](./examples/streaming_transducer)
7571
- **ContextNet** (Reference: [http://arxiv.org/abs/2005.03191](http://arxiv.org/abs/2005.03191))
7672
See [examples/contextnet](./examples/contextnet)
73+
- **Deep Speech 2** (Reference: [https://arxiv.org/abs/1512.02595](https://arxiv.org/abs/1512.02595))
74+
See [examples/deepspeech2](./examples/deepspeech2)
75+
- **Jasper** (Reference: [https://arxiv.org/abs/1904.03288](https://arxiv.org/abs/1904.03288))
76+
See [examples/jasper](./examples/jasper)
7777

7878
## Installation
7979

8080
For training and testing, you should use `git clone` for installing necessary packages from other authors (`ctc_decoders`, `rnnt_loss`, etc.)
8181

82-
### Installing via PyPi
83-
84-
For tensorflow 2.3.x, run `pip3 install -U 'TensorFlowASR[tf2.3]'` or `pip3 install -U 'TensorFlowASR[tf2.3-gpu]'`
85-
86-
For tensorflow 2.4.x, run `pip3 install -U 'TensorFlowASR[tf2.4]'` or `pip3 install -U 'TensorFlowASR[tf2.4-gpu]'`
87-
88-
For tensorflow 2.5.x, run `pip3 install -U 'TensorFlowASR[tf2.5]'` or `pip3 install -U 'TensorFlowASR[tf2.5-gpu]'`
89-
90-
For tensorflow 2.6.x, run `pip3 install -U 'TensorFlowASR[tf2.6]'` or `pip3 install -U 'TensorFlowASR[tf2.6-gpu]'`
91-
92-
### Installing from source
82+
### Installing from source (recommended)
9383

9484
```bash
9585
git clone https://github.com/TensorSpeech/TensorFlowASR.git
9686
cd TensorFlowASR
97-
pip3 install -e '.[tf2.6]' # see other options in setup.py file
87+
# Tensorflow 2.x (with 2.x >= 2.3)
88+
pip3 install -e ".[tf2.x]" # or ".[tf2.x-gpu]"
9889
```
9990

10091
For anaconda3:
@@ -105,9 +96,18 @@ conda activate tfasr
10596
pip install -U tensorflow-gpu # upgrade to latest version of tensorflow
10697
git clone https://github.com/TensorSpeech/TensorFlowASR.git
10798
cd TensorFlowASR
108-
pip3 install -e '.[tf2.3]' # or '.[tf2.3-gpu]' or '.[tf2.4]' or '.[tf2.4-gpu]' or '.[tf2.5]' or '.[tf2.5-gpu]'
99+
# Tensorflow 2.x (with 2.x >= 2.3)
100+
pip3 install -e ".[tf2.x]" # or ".[tf2.x-gpu]"
109101
```
110102

103+
### Installing via PyPi
104+
105+
```bash
106+
# Tensorflow 2.x (with 2.x >= 2.3)
107+
pip3 install -U "TensorFlowASR[tf2.x]" # or pip3 install -U "TensorFlowASR[tf2.x-gpu]"
108+
```
109+
110+
111111
### Running in a container
112112

113113
```bash

examples/conformer/config.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ decoder_config:
3131
beam_width: 0
3232
norm_score: True
3333
corpus_files:
34-
- /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
34+
- /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
3535

3636
model_config:
3737
name: conformer
@@ -75,8 +75,8 @@ learning_config:
7575
num_masks: 1
7676
mask_factor: 27
7777
data_paths:
78-
- /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
79-
tfrecords_dir: /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/tfrecords_1030
78+
- /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
79+
tfrecords_dir: /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/tfrecords_1030
8080
shuffle: True
8181
cache: True
8282
buffer_size: 100
@@ -86,7 +86,7 @@ learning_config:
8686
eval_dataset_config:
8787
use_tf: True
8888
data_paths: null
89-
tfrecords_dir: /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/tfrecords_1030
89+
tfrecords_dir: /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/tfrecords_1030
9090
shuffle: False
9191
cache: True
9292
buffer_size: 100
@@ -113,13 +113,13 @@ learning_config:
113113
batch_size: 2
114114
num_epochs: 50
115115
checkpoint:
116-
filepath: /mnt/e/Models/local/conformer/checkpoints/{epoch:02d}.h5
116+
filepath: /mnt/Miscellanea/Models/local/conformer/checkpoints/{epoch:02d}.h5
117117
save_best_only: True
118118
save_weights_only: True
119119
save_freq: epoch
120-
states_dir: /mnt/e/Models/local/conformer/states
120+
states_dir: /mnt/Miscellanea/Models/local/conformer/states
121121
tensorboard:
122-
log_dir: /mnt/e/Models/local/conformer/tensorboard
122+
log_dir: /mnt/Miscellanea/Models/local/conformer/tensorboard
123123
histogram_freq: 1
124124
write_graph: True
125125
write_images: True

examples/conformer/saved_model.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,38 +27,23 @@
2727
parser = argparse.ArgumentParser(prog="Conformer Testing")
2828

2929
parser.add_argument(
30-
"--config",
31-
type=str,
32-
default=DEFAULT_YAML,
33-
help="The file path of model configuration file",
30+
"--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file",
3431
)
3532

3633
parser.add_argument(
37-
"--h5",
38-
type=str,
39-
default=None,
40-
help="Path to saved h5 weights",
34+
"--h5", type=str, default=None, help="Path to saved h5 weights",
4135
)
4236

4337
parser.add_argument(
44-
"--sentence_piece",
45-
default=False,
46-
action="store_true",
47-
help="Whether to use `SentencePiece` model",
38+
"--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model",
4839
)
4940

5041
parser.add_argument(
51-
"--subwords",
52-
default=False,
53-
action="store_true",
54-
help="Use subwords",
42+
"--subwords", default=False, action="store_true", help="Use subwords",
5543
)
5644

5745
parser.add_argument(
58-
"--output_dir",
59-
type=str,
60-
default=None,
61-
help="Output directory for saved model",
46+
"--output_dir", type=str, default=None, help="Output directory for saved model",
6247
)
6348

6449
args = parser.parse_args()

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
SoundFile==0.10.3.post1
22
tensorflow_datasets==4.4.0
33
nltk==3.6.4
4-
numpy==1.19.5
4+
numpy>=1.19.5
55
sentencepiece==0.1.96
66
tqdm==4.62.1
77
librosa==0.8.1

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
setuptools.setup(
3232
name="TensorFlowASR",
33-
version="1.0.3",
33+
version="1.0.2",
3434
author="Huy Le Nguyen",
3535
author_email="[email protected]",
3636
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/models/base_model.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,50 +41,42 @@ def save(
4141
)
4242

4343
def save_weights(
44-
self,
45-
filepath,
46-
overwrite=True,
47-
save_format=None,
48-
options=None,
44+
self, filepath, overwrite=True, save_format=None, options=None,
4945
):
5046
with file_util.save_file(filepath) as path:
5147
super().save_weights(filepath=path, overwrite=overwrite, save_format=save_format, options=options)
5248

5349
def load_weights(
54-
self,
55-
filepath,
56-
by_name=False,
57-
skip_mismatch=False,
58-
options=None,
50+
self, filepath, by_name=False, skip_mismatch=False, options=None,
5951
):
6052
with file_util.read_file(filepath) as path:
6153
super().load_weights(filepath=path, by_name=by_name, skip_mismatch=skip_mismatch, options=options)
6254

55+
@property
56+
def metrics(self):
57+
if not hasattr(self, "_tfasr_metrics"):
58+
self._tfasr_metrics = {}
59+
return list(self._tfasr_metrics.values())
60+
6361
def add_metric(
64-
self,
65-
metric: tf.keras.metrics.Metric,
62+
self, metric: tf.keras.metrics.Metric,
6663
):
67-
if not hasattr(self, "_metrics"):
68-
self._metrics = {}
69-
self._metrics[metric.name] = metric
64+
if not hasattr(self, "_tfasr_metrics"):
65+
self._tfasr_metrics = {}
66+
self._tfasr_metrics[metric.name] = metric
7067

7168
def make(self, *args, **kwargs):
7269
"""Custom function for building model (uses self.build so cannot overwrite that function)"""
7370
raise NotImplementedError()
7471

7572
def compile(
76-
self,
77-
loss,
78-
optimizer,
79-
run_eagerly=None,
80-
**kwargs,
73+
self, loss, optimizer, run_eagerly=None, **kwargs,
8174
):
8275
self.use_loss_scale = False
8376
if not env_util.has_devices("TPU"):
8477
optimizer = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), "dynamic")
8578
self.use_loss_scale = True
86-
loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
87-
self.add_metric(loss_metric)
79+
self.add_metric(metric=tf.keras.metrics.Mean(name="loss", dtype=tf.float32))
8880
super().compile(optimizer=optimizer, loss=loss, run_eagerly=run_eagerly, **kwargs)
8981

9082
# -------------------------------- STEP FUNCTIONS -------------------------------------
@@ -110,8 +102,8 @@ def train_step(self, batch):
110102
else:
111103
gradients = tape.gradient(loss, self.trainable_weights)
112104
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
113-
self._metrics["loss"].update_state(loss)
114-
return {m.name: m.result() for m in self._metrics.values()}
105+
self._tfasr_metrics["loss"].update_state(loss)
106+
return {m.name: m.result() for m in self.metrics}
115107

116108
def test_step(self, batch):
117109
"""
@@ -125,8 +117,8 @@ def test_step(self, batch):
125117
inputs, y_true = batch
126118
y_pred = self(inputs, training=False)
127119
loss = self.loss(y_true, y_pred)
128-
self._metrics["loss"].update_state(loss)
129-
return {m.name: m.result() for m in self._metrics.values()}
120+
self._tfasr_metrics["loss"].update_state(loss)
121+
return {m.name: m.result() for m in self.metrics}
130122

131123
def predict_step(self, batch):
132124
"""

0 commit comments

Comments
 (0)