|
56 | 56 |
|
57 | 57 | strategy = setup_strategy(args.devices) |
58 | 58 |
|
59 | | -from tensorflow_asr.configs.user_config import UserConfig |
| 59 | +from tensorflow_asr.configs.config import Config |
60 | 60 | from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset |
61 | 61 | from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer |
62 | 62 | from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer |
63 | 63 | from tensorflow_asr.runners.transducer_runners import TransducerTrainerGA |
64 | 64 | from tensorflow_asr.models.conformer import Conformer |
65 | 65 | from tensorflow_asr.optimizers.schedules import TransformerSchedule |
66 | 66 |
|
67 | | -config = UserConfig(DEFAULT_YAML, args.config, learning=True) |
68 | | -speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) |
69 | | -text_featurizer = CharFeaturizer(config["decoder_config"]) |
| 67 | +config = Config(args.config, learning=True) |
| 68 | +speech_featurizer = TFSpeechFeaturizer(config.speech_config) |
| 69 | +text_featurizer = CharFeaturizer(config.decoder_config) |
70 | 70 |
|
71 | 71 | if args.tfrecords: |
72 | 72 | train_dataset = ASRTFRecordDataset( |
73 | | - data_paths=config["learning_config"]["dataset_config"]["train_paths"], |
74 | | - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], |
| 73 | + data_paths=config.learning_config.dataset_config.train_paths, |
| 74 | + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, |
75 | 75 | speech_featurizer=speech_featurizer, |
76 | 76 | text_featurizer=text_featurizer, |
77 | | - augmentations=config["learning_config"]["augmentations"], |
| 77 | + augmentations=config.learning_config.augmentations, |
78 | 78 | stage="train", cache=args.cache, shuffle=True |
79 | 79 | ) |
80 | 80 | eval_dataset = ASRTFRecordDataset( |
81 | | - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], |
82 | | - tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], |
| 81 | + data_paths=config.learning_config.dataset_config.eval_paths, |
| 82 | + tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir, |
83 | 83 | speech_featurizer=speech_featurizer, |
84 | 84 | text_featurizer=text_featurizer, |
85 | 85 | stage="eval", cache=args.cache, shuffle=True |
86 | 86 | ) |
87 | 87 | else: |
88 | 88 | train_dataset = ASRSliceDataset( |
89 | | - data_paths=config["learning_config"]["dataset_config"]["train_paths"], |
| 89 | + data_paths=config.learning_config.dataset_config.train_paths, |
90 | 90 | speech_featurizer=speech_featurizer, |
91 | 91 | text_featurizer=text_featurizer, |
92 | | - augmentations=config["learning_config"]["augmentations"], |
| 92 | + augmentations=config.learning_config.augmentations, |
93 | 93 | stage="train", cache=args.cache, shuffle=True |
94 | 94 | ) |
95 | 95 | eval_dataset = ASRSliceDataset( |
96 | | - data_paths=config["learning_config"]["dataset_config"]["eval_paths"], |
| 96 | + data_paths=config.learning_config.dataset_config.eval_paths, |
97 | 97 | speech_featurizer=speech_featurizer, |
98 | 98 | text_featurizer=text_featurizer, |
99 | 99 | stage="eval", cache=args.cache, shuffle=True |
100 | 100 | ) |
101 | 101 |
|
102 | 102 | conformer_trainer = TransducerTrainerGA( |
103 | | - config=config["learning_config"]["running_config"], |
| 103 | + config=config.learning_config.running_config, |
104 | 104 | text_featurizer=text_featurizer, strategy=strategy |
105 | 105 | ) |
106 | 106 |
|
107 | 107 | with conformer_trainer.strategy.scope(): |
108 | 108 | # build model |
109 | | - conformer = Conformer( |
110 | | - **config["model_config"], |
111 | | - vocabulary_size=text_featurizer.num_classes |
112 | | - ) |
| 109 | + conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) |
113 | 110 | conformer._build(speech_featurizer.shape) |
114 | 111 | conformer.summary(line_length=120) |
115 | 112 |
|
116 | | - optimizer_config = config["learning_config"]["optimizer_config"] |
| 113 | + optimizer_config = config.learning_config.optimizer_config |
117 | 114 | optimizer = tf.keras.optimizers.Adam( |
118 | 115 | TransformerSchedule( |
119 | | - d_model=config["model_config"]["dmodel"], |
| 116 | + d_model=config.model_config["dmodel"], |
120 | 117 | warmup_steps=optimizer_config["warmup_steps"], |
121 | | - max_lr=(0.05 / math.sqrt(config["model_config"]["dmodel"])) |
| 118 | + max_lr=(0.05 / math.sqrt(config.model_config["dmodel"])) |
122 | 119 | ), |
123 | 120 | beta_1=optimizer_config["beta1"], |
124 | 121 | beta_2=optimizer_config["beta2"], |
|
0 commit comments