Skip to content

Commit 6729e44

Browse files
committed
🚀 Use vocabulary config for characters and subwords
1 parent fc1fd9f commit 6729e44

File tree

11 files changed

+310
-77
lines changed

11 files changed

+310
-77
lines changed

examples/conformer/test_subword_conformer.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343
parser.add_argument("--cpu", default=False, action="store_true",
4444
help="Whether to only use cpu")
4545

46-
parser.add_argument("--subwords_prefix", type=str, default=None,
47-
help="Prefix of file that stores generated subwords")
46+
parser.add_argument("--subwords", type=str, default=None,
47+
help="Path to file that stores generated subwords")
4848

4949
parser.add_argument("--output_name", type=str, default="test",
5050
help="Result filename name prefix")
@@ -65,12 +65,11 @@
6565
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
6666
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
6767

68-
if args.subwords_prefix and os.path.exists(f"{args.subwords_prefix}.subwords"):
68+
if args.subwords and os.path.exists(args.subwords):
6969
print("Loading subwords ...")
70-
text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"],
71-
args.subwords_prefix)
70+
text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords)
7271
else:
73-
raise ValueError("subwords_prefix must be set")
72+
raise ValueError("subwords must be set")
7473

7574
tf.random.set_seed(0)
7675
assert args.saved

examples/conformer/tflite_subword_conformer.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
parser.add_argument("--saved", type=str, default=None,
3737
help="Path to saved model")
3838

39-
parser.add_argument("--subwords_prefix", type=str, default=None,
40-
help="Prefix of file that stores generated subwords")
39+
parser.add_argument("--subwords", type=str, default=None,
40+
help="Path to file that stores generated subwords")
4141

4242
parser.add_argument("output", type=str, default=None,
4343
help="TFLite file path to be exported")
@@ -49,12 +49,11 @@
4949
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
5050
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
5151

52-
if args.subwords_prefix and os.path.exists(f"{args.subwords_prefix}.subwords"):
52+
if args.subwords and os.path.exists(args.subwords):
5353
print("Loading subwords ...")
54-
text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"],
55-
args.subwords_prefix)
54+
text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords)
5655
else:
57-
raise ValueError("subwords_prefix must be set")
56+
raise ValueError("subwords must be set")
5857

5958
# build model
6059
conformer = Conformer(
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2020 Huy Le Nguyen (@usimarit)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import math
17+
import argparse
18+
from tensorflow_asr.utils import setup_environment, setup_strategy
19+
20+
setup_environment()
21+
import tensorflow as tf
22+
23+
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
24+
25+
tf.keras.backend.clear_session()
26+
27+
parser = argparse.ArgumentParser(prog="Conformer Training")
28+
29+
parser.add_argument("--config", type=str, default=DEFAULT_YAML,
30+
help="The file path of model configuration file")
31+
32+
parser.add_argument("--max_ckpts", type=int, default=10,
33+
help="Max number of checkpoints to keep")
34+
35+
parser.add_argument("--tfrecords", default=False, action="store_true",
36+
help="Whether to use tfrecords")
37+
38+
parser.add_argument("--tbs", type=int, default=None,
39+
help="Train batch size per replica")
40+
41+
parser.add_argument("--ebs", type=int, default=None,
42+
help="Evaluation batch size per replica")
43+
44+
parser.add_argument("--devices", type=int, nargs="*", default=[0],
45+
help="Devices' ids to apply distributed training")
46+
47+
parser.add_argument("--mxp", default=False, action="store_true",
48+
help="Enable mixed precision")
49+
50+
parser.add_argument("--cache", default=False, action="store_true",
51+
help="Enable caching for dataset")
52+
53+
args = parser.parse_args()
54+
55+
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
56+
57+
strategy = setup_strategy(args.devices)
58+
59+
from tensorflow_asr.configs.user_config import UserConfig
60+
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
61+
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
62+
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
63+
from tensorflow_asr.runners.transducer_runners import TransducerTrainerGA
64+
from tensorflow_asr.models.conformer import Conformer
65+
from tensorflow_asr.optimizers.schedules import TransformerSchedule
66+
67+
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
68+
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
69+
text_featurizer = CharFeaturizer(config["decoder_config"])
70+
71+
if args.tfrecords:
72+
train_dataset = ASRTFRecordDataset(
73+
data_paths=config["learning_config"]["dataset_config"]["train_paths"],
74+
tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
75+
speech_featurizer=speech_featurizer,
76+
text_featurizer=text_featurizer,
77+
augmentations=config["learning_config"]["augmentations"],
78+
stage="train", cache=args.cache, shuffle=True
79+
)
80+
eval_dataset = ASRTFRecordDataset(
81+
data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
82+
tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
83+
speech_featurizer=speech_featurizer,
84+
text_featurizer=text_featurizer,
85+
stage="eval", cache=args.cache, shuffle=True
86+
)
87+
else:
88+
train_dataset = ASRSliceDataset(
89+
data_paths=config["learning_config"]["dataset_config"]["train_paths"],
90+
speech_featurizer=speech_featurizer,
91+
text_featurizer=text_featurizer,
92+
augmentations=config["learning_config"]["augmentations"],
93+
stage="train", cache=args.cache, shuffle=True
94+
)
95+
eval_dataset = ASRSliceDataset(
96+
data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
97+
speech_featurizer=speech_featurizer,
98+
text_featurizer=text_featurizer,
99+
stage="eval", cache=args.cache, shuffle=True
100+
)
101+
102+
conformer_trainer = TransducerTrainerGA(
103+
config=config["learning_config"]["running_config"],
104+
text_featurizer=text_featurizer, strategy=strategy
105+
)
106+
107+
with conformer_trainer.strategy.scope():
108+
# build model
109+
conformer = Conformer(
110+
**config["model_config"],
111+
vocabulary_size=text_featurizer.num_classes
112+
)
113+
conformer._build(speech_featurizer.shape)
114+
conformer.summary(line_length=120)
115+
116+
optimizer_config = config["learning_config"]["optimizer_config"]
117+
optimizer = tf.keras.optimizers.Adam(
118+
TransformerSchedule(
119+
d_model=config["model_config"]["dmodel"],
120+
warmup_steps=optimizer_config["warmup_steps"],
121+
max_lr=(0.05 / math.sqrt(config["model_config"]["dmodel"]))
122+
),
123+
beta_1=optimizer_config["beta1"],
124+
beta_2=optimizer_config["beta2"],
125+
epsilon=optimizer_config["epsilon"]
126+
)
127+
128+
conformer_trainer.compile(model=conformer, optimizer=optimizer,
129+
max_to_keep=args.max_ckpts)
130+
131+
conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright 2020 Huy Le Nguyen (@usimarit)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import math
17+
import argparse
18+
from tensorflow_asr.utils import setup_environment, setup_strategy
19+
20+
setup_environment()
21+
import tensorflow as tf
22+
23+
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
24+
25+
tf.keras.backend.clear_session()
26+
27+
parser = argparse.ArgumentParser(prog="Conformer Training")
28+
29+
parser.add_argument("--config", type=str, default=DEFAULT_YAML,
30+
help="The file path of model configuration file")
31+
32+
parser.add_argument("--max_ckpts", type=int, default=10,
33+
help="Max number of checkpoints to keep")
34+
35+
parser.add_argument("--tfrecords", default=False, action="store_true",
36+
help="Whether to use tfrecords")
37+
38+
parser.add_argument("--tbs", type=int, default=None,
39+
help="Train batch size per replica")
40+
41+
parser.add_argument("--ebs", type=int, default=None,
42+
help="Evaluation batch size per replica")
43+
44+
parser.add_argument("--devices", type=int, nargs="*", default=[0],
45+
help="Devices' ids to apply distributed training")
46+
47+
parser.add_argument("--mxp", default=False, action="store_true",
48+
help="Enable mixed precision")
49+
50+
parser.add_argument("--cache", default=False, action="store_true",
51+
help="Enable caching for dataset")
52+
53+
parser.add_argument("--subwords", type=str, default=None,
54+
help="Path to file that stores generated subwords")
55+
56+
parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[],
57+
help="Transcript files for generating subwords")
58+
59+
args = parser.parse_args()
60+
61+
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
62+
63+
strategy = setup_strategy(args.devices)
64+
65+
from tensorflow_asr.configs.user_config import UserConfig
66+
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
67+
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
68+
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
69+
from tensorflow_asr.runners.transducer_runners import TransducerTrainerGA
70+
from tensorflow_asr.models.conformer import Conformer
71+
from tensorflow_asr.optimizers.schedules import TransformerSchedule
72+
73+
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
74+
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
75+
76+
if args.subwords and os.path.exists(args.subwords):
77+
print("Loading subwords ...")
78+
text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords)
79+
else:
80+
print("Generating subwords ...")
81+
text_featurizer = SubwordFeaturizer.build_from_corpus(
82+
config["decoder_config"],
83+
corpus_files=args.subwords_corpus
84+
)
85+
text_featurizer.subwords.save_to_file(args.subwords_prefix)
86+
87+
if args.tfrecords:
88+
train_dataset = ASRTFRecordDataset(
89+
data_paths=config["learning_config"]["dataset_config"]["train_paths"],
90+
tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
91+
speech_featurizer=speech_featurizer,
92+
text_featurizer=text_featurizer,
93+
augmentations=config["learning_config"]["augmentations"],
94+
stage="train", cache=args.cache, shuffle=True
95+
)
96+
eval_dataset = ASRTFRecordDataset(
97+
data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
98+
tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
99+
speech_featurizer=speech_featurizer,
100+
text_featurizer=text_featurizer,
101+
stage="eval", cache=args.cache, shuffle=True
102+
)
103+
else:
104+
train_dataset = ASRSliceDataset(
105+
data_paths=config["learning_config"]["dataset_config"]["train_paths"],
106+
speech_featurizer=speech_featurizer,
107+
text_featurizer=text_featurizer,
108+
augmentations=config["learning_config"]["augmentations"],
109+
stage="train", cache=args.cache, shuffle=True
110+
)
111+
eval_dataset = ASRSliceDataset(
112+
data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
113+
speech_featurizer=speech_featurizer,
114+
text_featurizer=text_featurizer,
115+
stage="eval", cache=args.cache, shuffle=True
116+
)
117+
118+
conformer_trainer = TransducerTrainerGA(
119+
config=config["learning_config"]["running_config"],
120+
text_featurizer=text_featurizer, strategy=strategy
121+
)
122+
123+
with conformer_trainer.strategy.scope():
124+
# build model
125+
conformer = Conformer(
126+
**config["model_config"],
127+
vocabulary_size=text_featurizer.num_classes
128+
)
129+
conformer._build(speech_featurizer.shape)
130+
conformer.summary(line_length=120)
131+
132+
optimizer_config = config["learning_config"]["optimizer_config"]
133+
optimizer = tf.keras.optimizers.Adam(
134+
TransformerSchedule(
135+
d_model=config["model_config"]["dmodel"],
136+
warmup_steps=optimizer_config["warmup_steps"],
137+
max_lr=(0.05 / math.sqrt(config["model_config"]["dmodel"]))
138+
),
139+
beta_1=optimizer_config["beta1"],
140+
beta_2=optimizer_config["beta2"],
141+
epsilon=optimizer_config["epsilon"]
142+
)
143+
144+
conformer_trainer.compile(model=conformer, optimizer=optimizer,
145+
max_to_keep=args.max_ckpts)
146+
147+
conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)

examples/conformer/train_subword_conformer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@
5050
parser.add_argument("--cache", default=False, action="store_true",
5151
help="Enable caching for dataset")
5252

53-
parser.add_argument("--subwords_prefix", type=str, default=None,
54-
help="Prefix of file that stores generated subwords")
53+
parser.add_argument("--subwords", type=str, default=None,
54+
help="Path to file that stores generated subwords")
5555

5656
parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[],
5757
help="Transcript files for generating subwords")
@@ -73,10 +73,9 @@
7373
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
7474
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
7575

76-
if args.subwords_prefix and os.path.exists(f"{args.subwords_prefix}.subwords"):
76+
if args.subwords and os.path.exists(args.subwords):
7777
print("Loading subwords ...")
78-
text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"],
79-
args.subwords_prefix)
78+
text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords)
8079
else:
8180
print("Generating subwords ...")
8281
text_featurizer = SubwordFeaturizer.build_from_corpus(

setup.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,14 @@
3737

3838
setuptools.setup(
3939
name="TensorFlowASR",
40-
version="0.2.5",
40+
version="0.2.6",
4141
author="Huy Le Nguyen",
4242
author_email="[email protected]",
4343
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
4444
long_description=long_description,
4545
long_description_content_type="text/markdown",
4646
url="https://github.com/TensorSpeech/TensorFlowASR",
4747
packages=setuptools.find_packages(include=["tensorflow_asr*"]),
48-
package_data={
49-
"tensorflow_asr": ["featurizers/*.txt"]
50-
},
5148
install_requires=requirements,
5249
classifiers=[
5350
"Programming Language :: Python :: 3.6",
Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +0,0 @@
1-
# Copyright 2020 Huy Le Nguyen (@usimarit)
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
import os
15-
16-
ENGLISH = os.path.abspath(os.path.join(os.path.dirname(__file__), "english.txt"))

0 commit comments

Comments
 (0)