1+ {
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python388jvsc74a57bd045f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f",
   "display_name": "Python 3.8.8 64-bit ('tfo': conda)"
  },
  "metadata": {
   "interpreter": {
    "hash": "45f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f"
   }
  }
 },
26+ "nbformat" : 4 ,
27+ "nbformat_minor" : 2 ,
28+ "cells" : [
29+ {
30+ "cell_type" : " code" ,
31+ "execution_count" : null ,
32+ "metadata" : {},
33+ "outputs" : [],
34+ "source" : [
35+ " config = {\n " ,
36+ " \" speech_config\" : {\n " ,
37+ " \" sample_rate\" : 16000,\n " ,
38+ " \" frame_ms\" : 25,\n " ,
39+ " \" stride_ms\" : 10,\n " ,
40+ " \" num_feature_bins\" : 80,\n " ,
41+ " \" feature_type\" : \" log_mel_spectrogram\" ,\n " ,
42+ " \" preemphasis\" : 0.97,\n " ,
43+ " \" normalize_signal\" : True,\n " ,
44+ " \" normalize_feature\" : True,\n " ,
45+ " \" normalize_per_feature\" : False,\n " ,
46+ " },\n " ,
47+ " \" decoder_config\" : {\n " ,
48+ " \" vocabulary\" : None,\n " ,
49+ " \" target_vocab_size\" : 1000,\n " ,
50+ " \" max_subword_length\" : 10,\n " ,
51+ " \" blank_at_zero\" : True,\n " ,
52+ " \" beam_width\" : 0,\n " ,
53+ " \" norm_score\" : True,\n " ,
54+ " \" corpus_files\" : None,\n " ,
55+ " },\n " ,
56+ " \" model_config\" : {\n " ,
57+ " \" name\" : \" conformer\" ,\n " ,
58+ " \" encoder_subsampling\" : {\n " ,
59+ " \" type\" : \" conv2d\" ,\n " ,
60+ " \" filters\" : 144,\n " ,
61+ " \" kernel_size\" : 3,\n " ,
62+ " \" strides\" : 2,\n " ,
63+ " },\n " ,
64+ " \" encoder_positional_encoding\" : \" sinusoid_concat\" ,\n " ,
65+ " \" encoder_dmodel\" : 144,\n " ,
66+ " \" encoder_num_blocks\" : 16,\n " ,
67+ " \" encoder_head_size\" : 36,\n " ,
68+ " \" encoder_num_heads\" : 4,\n " ,
69+ " \" encoder_mha_type\" : \" relmha\" ,\n " ,
70+ " \" encoder_kernel_size\" : 32,\n " ,
71+ " \" encoder_fc_factor\" : 0.5,\n " ,
72+ " \" encoder_dropout\" : 0.1,\n " ,
73+ " \" prediction_embed_dim\" : 320,\n " ,
74+ " \" prediction_embed_dropout\" : 0,\n " ,
75+ " \" prediction_num_rnns\" : 1,\n " ,
76+ " \" prediction_rnn_units\" : 320,\n " ,
77+ " \" prediction_rnn_type\" : \" lstm\" ,\n " ,
78+ " \" prediction_rnn_implementation\" : 2,\n " ,
79+ " \" prediction_layer_norm\" : True,\n " ,
80+ " \" prediction_projection_units\" : 0,\n " ,
81+ " \" joint_dim\" : 320,\n " ,
82+ " \" prejoint_linear\" : True,\n " ,
83+ " \" joint_activation\" : \" tanh\" ,\n " ,
84+ " \" joint_mode\" : \" add\" ,\n " ,
85+ " },\n " ,
86+ " \" learning_config\" : {\n " ,
87+ " \" train_dataset_config\" : {\n " ,
88+ " \" use_tf\" : True,\n " ,
89+ " \" augmentation_config\" : {\n " ,
90+ " \" feature_augment\" : {\n " ,
91+ " \" time_masking\" : {\n " ,
92+ " \" num_masks\" : 10,\n " ,
93+ " \" mask_factor\" : 100,\n " ,
94+ " \" p_upperbound\" : 0.05,\n " ,
95+ " },\n " ,
96+ " \" freq_masking\" : {\" num_masks\" : 1, \" mask_factor\" : 27},\n " ,
97+ " }\n " ,
98+ " },\n " ,
99+ " \" data_paths\" : [\n " ,
100+ " \" /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv\"\n " ,
101+ " ],\n " ,
102+ " \" tfrecords_dir\" : None,\n " ,
103+ " \" shuffle\" : True,\n " ,
104+ " \" cache\" : True,\n " ,
105+ " \" buffer_size\" : 100,\n " ,
106+ " \" drop_remainder\" : True,\n " ,
107+ " \" stage\" : \" train\" ,\n " ,
108+ " },\n " ,
109+ " \" eval_dataset_config\" : {\n " ,
110+ " \" use_tf\" : True,\n " ,
111+ " \" data_paths\" : None,\n " ,
112+ " \" tfrecords_dir\" : None,\n " ,
113+ " \" shuffle\" : False,\n " ,
114+ " \" cache\" : True,\n " ,
115+ " \" buffer_size\" : 100,\n " ,
116+ " \" drop_remainder\" : True,\n " ,
117+ " \" stage\" : \" eval\" ,\n " ,
118+ " },\n " ,
119+ " \" test_dataset_config\" : {\n " ,
120+ " \" use_tf\" : True,\n " ,
121+ " \" data_paths\" : None,\n " ,
122+ " \" tfrecords_dir\" : None,\n " ,
123+ " \" shuffle\" : False,\n " ,
124+ " \" cache\" : True,\n " ,
125+ " \" buffer_size\" : 100,\n " ,
126+ " \" drop_remainder\" : True,\n " ,
127+ " \" stage\" : \" test\" ,\n " ,
128+ " },\n " ,
129+ " \" optimizer_config\" : {\n " ,
130+ " \" warmup_steps\" : 40000,\n " ,
131+ " \" beta_1\" : 0.9,\n " ,
132+ " \" beta_2\" : 0.98,\n " ,
133+ " \" epsilon\" : 1e-09,\n " ,
134+ " },\n " ,
135+ " \" running_config\" : {\n " ,
136+ " \" batch_size\" : 2,\n " ,
137+ " \" num_epochs\" : 50,\n " ,
138+ " \" checkpoint\" : {\n " ,
139+ " \" filepath\" : \" /mnt/e/Models/local/conformer/checkpoints/{epoch:02d}.h5\" ,\n " ,
140+ " \" save_best_only\" : True,\n " ,
141+ " \" save_weights_only\" : True,\n " ,
142+ " \" save_freq\" : \" epoch\" ,\n " ,
143+ " },\n " ,
144+ " \" states_dir\" : \" /mnt/e/Models/local/conformer/states\" ,\n " ,
145+ " \" tensorboard\" : {\n " ,
146+ " \" log_dir\" : \" /mnt/e/Models/local/conformer/tensorboard\" ,\n " ,
147+ " \" histogram_freq\" : 1,\n " ,
148+ " \" write_graph\" : True,\n " ,
149+ " \" write_images\" : True,\n " ,
150+ " \" update_freq\" : \" epoch\" ,\n " ,
151+ " \" profile_batch\" : 2,\n " ,
152+ " },\n " ,
153+ " },\n " ,
154+ " },\n " ,
155+ " }"
156+ ]
157+ },
158+ {
159+ "cell_type" : " code" ,
160+ "execution_count" : null ,
161+ "metadata" : {},
162+ "outputs" : [],
163+ "source" : [
164+ " metadata = {\n " ,
165+ " \" train\" : {\" max_input_length\" : 2974, \" max_label_length\" : 194, \" num_entries\" : 281241},\n " ,
166+ " \" eval\" : {\" max_input_length\" : 3516, \" max_label_length\" : 186, \" num_entries\" : 5567},\n " ,
167+ " }"
168+ ]
169+ },
170+ {
171+ "cell_type" : " code" ,
172+ "execution_count" : null ,
173+ "metadata" : {},
174+ "outputs" : [],
175+ "source" : [
176+ " import os\n " ,
177+ " import math\n " ,
178+ " import argparse\n " ,
179+ " from tensorflow_asr.utils import env_util\n " ,
180+ " \n " ,
181+ " env_util.setup_environment()\n " ,
182+ " import tensorflow as tf\n " ,
183+ " \n " ,
184+ " tf.keras.backend.clear_session()\n " ,
185+ " tf.config.optimizer.set_experimental_options({\" auto_mixed_precision\" : True})\n " ,
186+ " strategy = env_util.setup_strategy([0])\n " ,
187+ " \n " ,
188+ " from tensorflow_asr.configs.config import Config\n " ,
189+ " from tensorflow_asr.datasets import asr_dataset\n " ,
190+ " from tensorflow_asr.featurizers import speech_featurizers, text_featurizers\n " ,
191+ " from tensorflow_asr.models.transducer.conformer import Conformer\n " ,
192+ " from tensorflow_asr.optimizers.schedules import TransformerSchedule\n " ,
193+ " \n " ,
194+ " config = Config(config)\n " ,
195+ " speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config)\n " ,
196+ " \n " ,
197+ " text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config)\n " ,
198+ " \n " ,
199+ " train_dataset = asr_dataset.ASRSliceDataset(\n " ,
200+ " speech_featurizer=speech_featurizer,\n " ,
201+ " text_featurizer=text_featurizer,\n " ,
202+ " **vars(config.learning_config.train_dataset_config),\n " ,
203+ " indefinite=True\n " ,
204+ " )\n " ,
205+ " eval_dataset = asr_dataset.ASRSliceDataset(\n " ,
206+ " speech_featurizer=speech_featurizer,\n " ,
207+ " text_featurizer=text_featurizer,\n " ,
208+ " **vars(config.learning_config.eval_dataset_config),\n " ,
209+ " indefinite=True\n " ,
210+ " )\n " ,
211+ " \n " ,
212+ " train_dataset.load_metadata(metadata)\n " ,
213+ " eval_dataset.load_metadata(metadata)\n " ,
214+ " speech_featurizer.reset_length()\n " ,
215+ " text_featurizer.reset_length()\n " ,
216+ " \n " ,
217+ " global_batch_size = config.learning_config.running_config.batch_size\n " ,
218+ " global_batch_size *= strategy.num_replicas_in_sync\n " ,
219+ " \n " ,
220+ " train_data_loader = train_dataset.create(global_batch_size)\n " ,
221+ " eval_data_loader = eval_dataset.create(global_batch_size)\n " ,
222+ " \n " ,
223+ " with strategy.scope():\n " ,
224+ " # build model\n " ,
225+ " conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)\n " ,
226+ " conformer._build(speech_featurizer.shape)\n " ,
227+ " conformer.summary(line_length=100)\n " ,
228+ " \n " ,
229+ " optimizer = tf.keras.optimizers.Adam(\n " ,
230+ " TransformerSchedule(\n " ,
231+ " d_model=conformer.dmodel,\n " ,
232+ " warmup_steps=config.learning_config.optimizer_config.pop(\" warmup_steps\" , 10000),\n " ,
233+ " max_lr=(0.05 / math.sqrt(conformer.dmodel))\n " ,
234+ " ),\n " ,
235+ " **config.learning_config.optimizer_config\n " ,
236+ " )\n " ,
237+ " \n " ,
238+ " conformer.compile(\n " ,
239+ " optimizer=optimizer,\n " ,
240+ " experimental_steps_per_execution=10,\n " ,
241+ " global_batch_size=global_batch_size,\n " ,
242+ " blank=text_featurizer.blank\n " ,
243+ " )\n " ,
244+ " \n " ,
245+ " callbacks = [\n " ,
246+ " tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),\n " ,
247+ " tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),\n " ,
248+ " tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)\n " ,
249+ " ]\n " ,
250+ " \n " ,
251+ " conformer.fit(\n " ,
252+ " train_data_loader,\n " ,
253+ " epochs=config.learning_config.running_config.num_epochs,\n " ,
254+ " validation_data=eval_data_loader,\n " ,
255+ " callbacks=callbacks,\n " ,
256+ " steps_per_epoch=train_dataset.total_steps,\n " ,
257+ " validation_steps=eval_dataset.total_steps\n " ,
258+ " )"
259+ ]
260+ },
261+ {
262+ "cell_type" : " code" ,
263+ "execution_count" : null ,
264+ "metadata" : {},
265+ "outputs" : [],
266+ "source" : []
267+ }
268+ ]
}