Skip to content

Commit fd26c7d

Browse files
committed
⚡ Update configurations
1 parent 1d1ecc4 commit fd26c7d

33 files changed

+384
-397
lines changed

examples/conformer/save_conformer_from_weights.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -46,23 +46,20 @@
4646

4747
setup_devices([args.device], cpu=args.cpu)
4848

49-
from tensorflow_asr.configs.user_config import UserConfig
49+
from tensorflow_asr.configs.config import Config
5050
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
5151
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
5252
from tensorflow_asr.models.conformer import Conformer
5353

54-
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
55-
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
56-
text_featurizer = CharFeaturizer(config["decoder_config"])
54+
config = Config(args.config, learning=True)
55+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
56+
text_featurizer = CharFeaturizer(config.decoder_config)
5757

5858
tf.random.set_seed(0)
5959
assert args.saved
6060

6161
# build model
62-
conformer = Conformer(
63-
vocabulary_size=text_featurizer.num_classes,
64-
**config["model_config"]
65-
)
62+
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
6663
conformer._build(speech_featurizer.shape)
6764
conformer.load_weights(args.saved, by_name=True)
6865
conformer.summary(line_length=150)

examples/conformer/test_conformer.py

Lines changed: 9 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -52,48 +52,45 @@
5252

5353
setup_devices([args.device], cpu=args.cpu)
5454

55-
from tensorflow_asr.configs.user_config import UserConfig
55+
from tensorflow_asr.configs.config import Config
5656
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
5757
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
5858
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
5959
from tensorflow_asr.runners.base_runners import BaseTester
6060
from tensorflow_asr.models.conformer import Conformer
6161

62-
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
63-
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
64-
text_featurizer = CharFeaturizer(config["decoder_config"])
62+
config = Config(args.config, learning=True)
63+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
64+
text_featurizer = CharFeaturizer(config.decoder_config)
6565

6666
tf.random.set_seed(0)
6767
assert args.saved
6868

6969
if args.tfrecords:
7070
test_dataset = ASRTFRecordDataset(
71-
data_paths=config["learning_config"]["dataset_config"]["test_paths"],
72-
tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
71+
data_paths=config.learning_config.dataset_config.test_paths,
72+
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
7373
speech_featurizer=speech_featurizer,
7474
text_featurizer=text_featurizer,
7575
stage="test", shuffle=False
7676
)
7777
else:
7878
test_dataset = ASRSliceDataset(
79-
data_paths=config["learning_config"]["dataset_config"]["test_paths"],
79+
data_paths=config.learning_config.dataset_config.test_paths,
8080
speech_featurizer=speech_featurizer,
8181
text_featurizer=text_featurizer,
8282
stage="test", shuffle=False
8383
)
8484

8585
# build model
86-
conformer = Conformer(
87-
vocabulary_size=text_featurizer.num_classes,
88-
**config["model_config"]
89-
)
86+
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
9087
conformer._build(speech_featurizer.shape)
9188
conformer.load_weights(args.saved, by_name=True)
9289
conformer.summary(line_length=120)
9390
conformer.add_featurizers(speech_featurizer, text_featurizer)
9491

9592
conformer_tester = BaseTester(
96-
config=config["learning_config"]["running_config"],
93+
config=config.learning_config.running_config,
9794
output_name=args.output_name
9895
)
9996
conformer_tester.compile(conformer)

examples/conformer/test_subword_conformer.py

Lines changed: 9 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -55,19 +55,19 @@
5555

5656
setup_devices([args.device], cpu=args.cpu)
5757

58-
from tensorflow_asr.configs.user_config import UserConfig
58+
from tensorflow_asr.configs.config import Config
5959
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
6060
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
6161
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
6262
from tensorflow_asr.runners.base_runners import BaseTester
6363
from tensorflow_asr.models.conformer import Conformer
6464

65-
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
66-
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
65+
config = Config(args.config, learning=True)
66+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
6767

6868
if args.subwords and os.path.exists(args.subwords):
6969
print("Loading subwords ...")
70-
text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords)
70+
text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
7171
else:
7272
raise ValueError("subwords must be set")
7373

@@ -76,32 +76,29 @@
7676

7777
if args.tfrecords:
7878
test_dataset = ASRTFRecordDataset(
79-
data_paths=config["learning_config"]["dataset_config"]["test_paths"],
80-
tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
79+
data_paths=config.learning_config.dataset_config.test_paths,
80+
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
8181
speech_featurizer=speech_featurizer,
8282
text_featurizer=text_featurizer,
8383
stage="test", shuffle=False
8484
)
8585
else:
8686
test_dataset = ASRSliceDataset(
87-
data_paths=config["learning_config"]["dataset_config"]["test_paths"],
87+
data_paths=config.learning_config.dataset_config.test_paths,
8888
speech_featurizer=speech_featurizer,
8989
text_featurizer=text_featurizer,
9090
stage="test", shuffle=False
9191
)
9292

9393
# build model
94-
conformer = Conformer(
95-
vocabulary_size=text_featurizer.num_classes,
96-
**config["model_config"]
97-
)
94+
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
9895
conformer._build(speech_featurizer.shape)
9996
conformer.load_weights(args.saved, by_name=True)
10097
conformer.summary(line_length=120)
10198
conformer.add_featurizers(speech_featurizer, text_featurizer)
10299

103100
conformer_tester = BaseTester(
104-
config=config["learning_config"]["running_config"],
101+
config=config.learning_config.running_config,
105102
output_name=args.output_name
106103
)
107104
conformer_tester.compile(conformer)

examples/conformer/tflite_conformer.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
setup_environment()
2020
import tensorflow as tf
2121

22-
from tensorflow_asr.configs.user_config import UserConfig
22+
from tensorflow_asr.configs.config import Config
2323
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
2424
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
2525
from tensorflow_asr.models.conformer import Conformer
@@ -43,15 +43,12 @@
4343

4444
assert args.saved and args.output
4545

46-
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
47-
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
48-
text_featurizer = CharFeaturizer(config["decoder_config"])
46+
config = Config(args.config, learning=True)
47+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
48+
text_featurizer = CharFeaturizer(config.decoder_config)
4949

5050
# build model
51-
conformer = Conformer(
52-
**config["model_config"],
53-
vocabulary_size=text_featurizer.num_classes
54-
)
51+
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
5552
conformer._build(speech_featurizer.shape)
5653
conformer.load_weights(args.saved)
5754
conformer.summary(line_length=150)

examples/conformer/tflite_subword_conformer.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
setup_environment()
2020
import tensorflow as tf
2121

22-
from tensorflow_asr.configs.user_config import UserConfig
22+
from tensorflow_asr.configs.config import Config
2323
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
2424
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
2525
from tensorflow_asr.models.conformer import Conformer
@@ -46,20 +46,17 @@
4646

4747
assert args.saved and args.output
4848

49-
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
50-
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
49+
config = Config(args.config, learning=True)
50+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
5151

5252
if args.subwords and os.path.exists(args.subwords):
5353
print("Loading subwords ...")
54-
text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords)
54+
text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
5555
else:
5656
raise ValueError("subwords must be set")
5757

5858
# build model
59-
conformer = Conformer(
60-
**config["model_config"],
61-
vocabulary_size=text_featurizer.num_classes
62-
)
59+
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
6360
conformer._build(speech_featurizer.shape)
6461
conformer.load_weights(args.saved)
6562
conformer.summary(line_length=150)

examples/conformer/train_conformer.py

Lines changed: 17 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -56,69 +56,66 @@
5656

5757
strategy = setup_strategy(args.devices)
5858

59-
from tensorflow_asr.configs.user_config import UserConfig
59+
from tensorflow_asr.configs.config import Config
6060
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
6161
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
6262
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
6363
from tensorflow_asr.runners.transducer_runners import TransducerTrainer
6464
from tensorflow_asr.models.conformer import Conformer
6565
from tensorflow_asr.optimizers.schedules import TransformerSchedule
6666

67-
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
68-
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
69-
text_featurizer = CharFeaturizer(config["decoder_config"])
67+
config = Config(args.config, learning=True)
68+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
69+
text_featurizer = CharFeaturizer(config.decoder_config)
7070

7171
if args.tfrecords:
7272
train_dataset = ASRTFRecordDataset(
73-
data_paths=config["learning_config"]["dataset_config"]["train_paths"],
74-
tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
73+
data_paths=config.learning_config.dataset_config.train_paths,
74+
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
7575
speech_featurizer=speech_featurizer,
7676
text_featurizer=text_featurizer,
77-
augmentations=config["learning_config"]["augmentations"],
77+
augmentations=config.learning_config.augmentations,
7878
stage="train", cache=args.cache, shuffle=True
7979
)
8080
eval_dataset = ASRTFRecordDataset(
81-
data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
82-
tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
81+
data_paths=config.learning_config.dataset_config.eval_paths,
82+
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
8383
speech_featurizer=speech_featurizer,
8484
text_featurizer=text_featurizer,
8585
stage="eval", cache=args.cache, shuffle=True
8686
)
8787
else:
8888
train_dataset = ASRSliceDataset(
89-
data_paths=config["learning_config"]["dataset_config"]["train_paths"],
89+
data_paths=config.learning_config.dataset_config.train_paths,
9090
speech_featurizer=speech_featurizer,
9191
text_featurizer=text_featurizer,
92-
augmentations=config["learning_config"]["augmentations"],
92+
augmentations=config.learning_config.augmentations,
9393
stage="train", cache=args.cache, shuffle=True
9494
)
9595
eval_dataset = ASRSliceDataset(
96-
data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
96+
data_paths=config.learning_config.dataset_config.eval_paths,
9797
speech_featurizer=speech_featurizer,
9898
text_featurizer=text_featurizer,
9999
stage="eval", cache=args.cache, shuffle=True
100100
)
101101

102102
conformer_trainer = TransducerTrainer(
103-
config=config["learning_config"]["running_config"],
103+
config=config.learning_config.running_config,
104104
text_featurizer=text_featurizer, strategy=strategy
105105
)
106106

107107
with conformer_trainer.strategy.scope():
108108
# build model
109-
conformer = Conformer(
110-
**config["model_config"],
111-
vocabulary_size=text_featurizer.num_classes
112-
)
109+
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
113110
conformer._build(speech_featurizer.shape)
114111
conformer.summary(line_length=120)
115112

116-
optimizer_config = config["learning_config"]["optimizer_config"]
113+
optimizer_config = config.learning_config.optimizer_config
117114
optimizer = tf.keras.optimizers.Adam(
118115
TransformerSchedule(
119-
d_model=config["model_config"]["dmodel"],
116+
d_model=config.model_config["dmodel"],
120117
warmup_steps=optimizer_config["warmup_steps"],
121-
max_lr=(0.05 / math.sqrt(config["model_config"]["dmodel"]))
118+
max_lr=(0.05 / math.sqrt(config.model_config["dmodel"]))
122119
),
123120
beta_1=optimizer_config["beta1"],
124121
beta_2=optimizer_config["beta2"],

examples/conformer/train_ga_conformer.py

Lines changed: 17 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -56,69 +56,66 @@
5656

5757
strategy = setup_strategy(args.devices)
5858

59-
from tensorflow_asr.configs.user_config import UserConfig
59+
from tensorflow_asr.configs.config import Config
6060
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
6161
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
6262
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
6363
from tensorflow_asr.runners.transducer_runners import TransducerTrainerGA
6464
from tensorflow_asr.models.conformer import Conformer
6565
from tensorflow_asr.optimizers.schedules import TransformerSchedule
6666

67-
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
68-
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
69-
text_featurizer = CharFeaturizer(config["decoder_config"])
67+
config = Config(args.config, learning=True)
68+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
69+
text_featurizer = CharFeaturizer(config.decoder_config)
7070

7171
if args.tfrecords:
7272
train_dataset = ASRTFRecordDataset(
73-
data_paths=config["learning_config"]["dataset_config"]["train_paths"],
74-
tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
73+
data_paths=config.learning_config.dataset_config.train_paths,
74+
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
7575
speech_featurizer=speech_featurizer,
7676
text_featurizer=text_featurizer,
77-
augmentations=config["learning_config"]["augmentations"],
77+
augmentations=config.learning_config.augmentations,
7878
stage="train", cache=args.cache, shuffle=True
7979
)
8080
eval_dataset = ASRTFRecordDataset(
81-
data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
82-
tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
81+
data_paths=config.learning_config.dataset_config.eval_paths,
82+
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
8383
speech_featurizer=speech_featurizer,
8484
text_featurizer=text_featurizer,
8585
stage="eval", cache=args.cache, shuffle=True
8686
)
8787
else:
8888
train_dataset = ASRSliceDataset(
89-
data_paths=config["learning_config"]["dataset_config"]["train_paths"],
89+
data_paths=config.learning_config.dataset_config.train_paths,
9090
speech_featurizer=speech_featurizer,
9191
text_featurizer=text_featurizer,
92-
augmentations=config["learning_config"]["augmentations"],
92+
augmentations=config.learning_config.augmentations,
9393
stage="train", cache=args.cache, shuffle=True
9494
)
9595
eval_dataset = ASRSliceDataset(
96-
data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
96+
data_paths=config.learning_config.dataset_config.eval_paths,
9797
speech_featurizer=speech_featurizer,
9898
text_featurizer=text_featurizer,
9999
stage="eval", cache=args.cache, shuffle=True
100100
)
101101

102102
conformer_trainer = TransducerTrainerGA(
103-
config=config["learning_config"]["running_config"],
103+
config=config.learning_config.running_config,
104104
text_featurizer=text_featurizer, strategy=strategy
105105
)
106106

107107
with conformer_trainer.strategy.scope():
108108
# build model
109-
conformer = Conformer(
110-
**config["model_config"],
111-
vocabulary_size=text_featurizer.num_classes
112-
)
109+
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
113110
conformer._build(speech_featurizer.shape)
114111
conformer.summary(line_length=120)
115112

116-
optimizer_config = config["learning_config"]["optimizer_config"]
113+
optimizer_config = config.learning_config.optimizer_config
117114
optimizer = tf.keras.optimizers.Adam(
118115
TransformerSchedule(
119-
d_model=config["model_config"]["dmodel"],
116+
d_model=config.model_config["dmodel"],
120117
warmup_steps=optimizer_config["warmup_steps"],
121-
max_lr=(0.05 / math.sqrt(config["model_config"]["dmodel"]))
118+
max_lr=(0.05 / math.sqrt(config.model_config["dmodel"]))
122119
),
123120
beta_1=optimizer_config["beta1"],
124121
beta_2=optimizer_config["beta2"],

0 commit comments

Comments
 (0)