Skip to content

Commit 8c3b653

Browse files
committed
feat: add conformer small no decay result
1 parent 692199c commit 8c3b653

File tree

8 files changed

+135
-40
lines changed

8 files changed

+135
-40
lines changed

examples/configs/librispeech/data.yml.j2

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,28 @@ data_config:
3030
metadata: {{metadata}}
3131
indefinite: True
3232

33-
test_dataset_config:
34-
enabled: True
35-
sample_rate: 16000
36-
data_paths:
37-
- {{datadir}}/test-clean/transcripts.tsv
38-
tfrecords_dir: {{datadir}}/tfrecords
39-
shuffle: False
40-
cache: False
41-
buffer_size: null
42-
drop_remainder: False
43-
stage: test
44-
indefinite: False
33+
test_dataset_configs:
34+
- name: test-clean
35+
enabled: True
36+
sample_rate: 16000
37+
data_paths:
38+
- {{datadir}}/test-clean/transcripts.tsv
39+
tfrecords_dir: {{datadir}}/tfrecords
40+
shuffle: False
41+
cache: False
42+
buffer_size: null
43+
drop_remainder: False
44+
stage: test
45+
indefinite: False
46+
- name: test-other
47+
enabled: True
48+
sample_rate: 16000
49+
data_paths:
50+
- {{datadir}}/test-other/transcripts.tsv
51+
tfrecords_dir: {{datadir}}/tfrecords
52+
shuffle: False
53+
cache: False
54+
buffer_size: null
55+
drop_remainder: False
56+
stage: test
57+
indefinite: False

examples/models/transducer/conformer/results/sentencepiece/README.md

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
- [2. Batch Loss](#2-batch-loss)
66
- [Training Learning Rate](#training-learning-rate)
77
- [Results](#results)
8+
- [SentencePiece 1k + Small + LibriSpeech + Without Weight Decay](#sentencepiece-1k--small--librispeech--without-weight-decay)
9+
- [Training Loss](#training-loss-1)
10+
- [1. Epoch Loss](#1-epoch-loss-1)
11+
- [2. Batch Loss](#2-batch-loss-1)
12+
- [Training Learning Rate](#training-learning-rate-1)
13+
- [Results](#results-1)
814

915

1016
# SentencePiece 1k + Small + LibriSpeech
@@ -17,7 +23,6 @@
1723
| Device | Google Colab TPUs |
1824
| Global Batch Size | 2 * 16 * 8 = 256 (as 8 TPUs) |
1925
| Max Epochs | 300 |
20-
| Training time | |
2126

2227

2328
### Training Loss
@@ -64,3 +69,60 @@ Pretrain Model here: [link]()
6469
},
6570
]
6671
```
72+
73+
# SentencePiece 1k + Small + LibriSpeech + Without Weight Decay
74+
75+
76+
| Category | Description |
77+
| :---------------- | :--------------------------------------------------- |
78+
| Config | [small-no-decay.yml.j2](../../small-no-decay.yml.j2) |
79+
| Tensorflow | **2.13.x** |
80+
| Device | Google Colab TPUs |
81+
| Global Batch Size | 2 * 16 * 8 = 256 (as 8 TPUs) |
82+
| Max Epochs | 300 |
83+
84+
85+
### Training Loss
86+
87+
#### 1. Epoch Loss
88+
89+
![Epoch Loss](./figs/conformer-small-no-decay-sp1k-epoch-loss.svg)
90+
91+
#### 2. Batch Loss
92+
93+
![Batch Loss](./figs/conformer-small-no-decay-sp1k-batch-loss.svg)
94+
95+
### Training Learning Rate
96+
97+
![Learning Rate](./figs/conformer-small-no-decay-sp1k-lr.svg)
98+
99+
100+
### Results
101+
102+
Pretrain Model here: [link]()
103+
104+
```json
105+
[
106+
{
107+
"epoch": 115,
108+
"test-clean": {
109+
"greedy": {
110+
"wer": 0.06327982349360925,
111+
"cer": 0.02412176322239193,
112+
"mer": 0.06283642132698737,
113+
"wil": 0.110402410864341,
114+
"wip": 0.889597589135659
115+
}
116+
},
117+
"test-other": {
118+
"greedy": {
119+
"wer": 0.15083201192136483,
120+
"cer": 0.07265414763270005,
121+
"mer": 0.14853347882527798,
122+
"wil": 0.25123406103539114,
123+
"wip": 0.7487659389646089
124+
}
125+
}
126+
},
127+
]
128+
```

examples/models/transducer/conformer/results/sentencepiece/figs/conformer-small-no-decay-sp1k-batch-loss.svg

Lines changed: 1 addition & 0 deletions
Loading

examples/models/transducer/conformer/results/sentencepiece/figs/conformer-small-no-decay-sp1k-epoch-loss.svg

Lines changed: 1 addition & 0 deletions
Loading

examples/models/transducer/conformer/results/sentencepiece/figs/conformer-small-no-decay-sp1k-lr.svg

Lines changed: 1 addition & 0 deletions
Loading

examples/test.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,17 @@ def main(
2929
config_path: str,
3030
dataset_type: str,
3131
datadir: str,
32+
outputdir: str,
3233
h5: str = None,
3334
mxp: str = "none",
3435
bs: int = 1,
3536
device: int = 0,
3637
cpu: bool = False,
3738
jit_compile: bool = False,
38-
output: str = "test.tsv",
3939
repodir: str = os.path.realpath(os.path.join(os.path.dirname(__file__), "..")),
4040
):
41-
assert h5 and output
42-
output = file_util.preprocess_paths(output)
41+
outputdir = file_util.preprocess_paths(outputdir, isdir=True)
42+
checkpoint_name = os.path.splitext(os.path.basename(h5))[0]
4343

4444
env_util.setup_seed()
4545
env_util.setup_devices([device], cpu=cpu)
@@ -50,34 +50,41 @@ def main(
5050

5151
tokenizer = tokenizers.get(config)
5252

53-
test_dataset = datasets.get(tokenizer=tokenizer, dataset_config=config.data_config.test_dataset_config, dataset_type=dataset_type)
54-
test_data_loader = test_dataset.create(batch_size)
55-
5653
model: BaseModel = tf.keras.models.model_from_config(config.model_config)
5754
model.tokenizer = tokenizer
5855
model.make(batch_size=batch_size)
5956
model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5), skip_mismatch=False)
6057
model.jit_compile = jit_compile
6158
model.summary()
6259

63-
overwrite = True
64-
if tf.io.gfile.exists(output):
65-
while overwrite not in ["yes", "no"]:
66-
overwrite = input(f"File {output} exists, overwrite? (yes/no): ").lower()
67-
overwrite = overwrite == "yes"
68-
69-
if overwrite:
70-
with file_util.save_file(output) as output_file_path:
71-
model.predict(
72-
test_data_loader,
73-
verbose=1,
74-
callbacks=[
75-
PredictLogger(test_dataset=test_dataset, output_file_path=output_file_path),
76-
],
77-
)
78-
79-
evaluation_outputs = app_util.evaluate_hypotheses(output)
80-
logger.info(json.dumps(evaluation_outputs, indent=2))
60+
for test_data_config in config.data_config.test_dataset_configs:
61+
if not test_data_config.name:
62+
raise ValueError("Test dataset name must be provided")
63+
logger.info(f"Testing dataset: {test_data_config.name}")
64+
65+
output = os.path.join(outputdir, f"{test_data_config.name}-{checkpoint_name}.tsv")
66+
67+
test_dataset = datasets.get(tokenizer=tokenizer, dataset_config=test_data_config, dataset_type=dataset_type)
68+
test_data_loader = test_dataset.create(batch_size)
69+
70+
overwrite = True
71+
if tf.io.gfile.exists(output):
72+
while overwrite not in ["yes", "no"]:
73+
overwrite = input(f"File {output} exists, overwrite? (yes/no): ").lower()
74+
overwrite = overwrite == "yes"
75+
76+
if overwrite:
77+
with file_util.save_file(output) as output_file_path:
78+
model.predict(
79+
test_data_loader,
80+
verbose=1,
81+
callbacks=[
82+
PredictLogger(test_dataset=test_dataset, output_file_path=output_file_path),
83+
],
84+
)
85+
86+
evaluation_outputs = app_util.evaluate_hypotheses(output)
87+
logger.info(json.dumps(evaluation_outputs, indent=2))
8188

8289

8390
if __name__ == "__main__":

tensorflow_asr/configs.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class DatasetConfig:
6666
def __init__(self, config: dict = None):
6767
if not config:
6868
config = {}
69+
self.name: str = config.pop("name", "")
6970
self.enabled: bool = config.pop("enabled", True)
7071
self.stage: str = config.pop("stage", None)
7172
self.data_paths = config.pop("data_paths", None)
@@ -87,7 +88,10 @@ def __init__(self, config: dict = None):
8788
config = {}
8889
self.train_dataset_config = DatasetConfig(config.pop("train_dataset_config", {}))
8990
self.eval_dataset_config = DatasetConfig(config.pop("eval_dataset_config", {}))
90-
self.test_dataset_config = DatasetConfig(config.pop("test_dataset_config", {}))
91+
self.test_dataset_configs = [DatasetConfig(conf) for conf in config.pop("test_dataset_configs", [])]
92+
_test_dataset_config = config.pop("test_dataset_config", None)
93+
if _test_dataset_config:
94+
self.test_dataset_configs.append(_test_dataset_config)
9195

9296

9397
class LearningConfig:
@@ -114,8 +118,7 @@ def __init__(self, data: Union[str, dict], training=True, **kwargs):
114118
self.decoder_config = DecoderConfig(config.pop("decoder_config", {}))
115119
self.model_config: dict = config.pop("model_config", {})
116120
self.data_config = DataConfig(config.pop("data_config", {}))
117-
_learning_config_dict = config.pop("learning_config", {})
118-
self.learning_config = LearningConfig(_learning_config_dict) if training else None
121+
self.learning_config = LearningConfig(config.pop("learning_config", {})) if training else None
119122
for k, v in config.items():
120123
setattr(self, k, v)
121124
logger.info(str(self))

tensorflow_asr/datasets.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def __init__(
160160
metadata: str = None,
161161
sample_rate: int = 16000,
162162
stage: str = "train",
163+
name: str = "",
163164
**kwargs,
164165
):
165166
self.data_paths = data_paths or []
@@ -175,6 +176,7 @@ def __init__(
175176
self.total_steps = None # for better training visualization
176177
self.metadata = metadata
177178
self.sample_rate = sample_rate
179+
self.name = name
178180

179181
def parse(self, *args, **kwargs):
180182
raise NotImplementedError()
@@ -199,6 +201,7 @@ def __init__(
199201
metadata: str = None,
200202
buffer_size: int = BUFFER_SIZE,
201203
sample_rate: int = 16000,
204+
name: str = "",
202205
**kwargs,
203206
):
204207
super().__init__(
@@ -212,6 +215,8 @@ def __init__(
212215
metadata=metadata,
213216
indefinite=indefinite,
214217
sample_rate=sample_rate,
218+
name=name,
219+
**kwargs,
215220
)
216221
self.entries = []
217222
self.tokenizer = tokenizer
@@ -413,6 +418,7 @@ def __init__(
413418
buffer_size: int = BUFFER_SIZE,
414419
compression_type: str = "GZIP",
415420
sample_rate: int = 16000,
421+
name: str = "",
416422
**kwargs,
417423
):
418424
super().__init__(
@@ -427,6 +433,7 @@ def __init__(
427433
metadata=metadata,
428434
indefinite=indefinite,
429435
sample_rate=sample_rate,
436+
name=name,
430437
**kwargs,
431438
)
432439
if not self.stage:

0 commit comments

Comments
 (0)