@@ -15,96 +15,99 @@
 import os
 import math
 import argparse
-from tensorflow_asr.utils import setup_environment, setup_strategy
+from tensorflow_asr.utils import env_util
 
-setup_environment()
+env_util.setup_environment()
 import tensorflow as tf
 
 DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
 
 tf.keras.backend.clear_session()
 
-parser = argparse.ArgumentParser(prog="ContextNet Training")
+parser = argparse.ArgumentParser(prog="Contextnet Training")
 
 parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
 
-parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
-
 parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")
 
+parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
+
+parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
+
 parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")
 
 parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")
 
 parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance")
 
-parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata")
+parser.add_argument("--metadata", type=str, default=None, help="Path to file containing metadata")
+
+parser.add_argument("--static_length", default=False, action="store_true", help="Use static lengths")
 
 parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
 
 parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
 
-parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
-
-parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
-
 args = parser.parse_args()
 
 tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
 
-strategy = setup_strategy(args.devices)
+strategy = env_util.setup_strategy(args.devices)
 
 from tensorflow_asr.configs.config import Config
-from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras, ASRSliceDatasetKeras
-from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
-from tensorflow_asr.models.keras.contextnet import ContextNet
+from tensorflow_asr.datasets import asr_dataset
+from tensorflow_asr.featurizers import speech_featurizers, text_featurizers
+from tensorflow_asr.models.transducer.contextnet import ContextNet
 from tensorflow_asr.optimizers.schedules import TransformerSchedule
 
 config = Config(args.config)
-speech_featurizer = TFSpeechFeaturizer(config.speech_config)
+speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config)
 
-if args.subwords and os.path.exists(args.subwords):
+if args.sentence_piece:
+    print("Loading SentencePiece model ...")
+    text_featurizer = text_featurizers.SentencePieceFeaturizer(config.decoder_config)
+elif args.subwords:
     print("Loading subwords ...")
-    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
+    text_featurizer = text_featurizers.SubwordFeaturizer(config.decoder_config)
 else:
-    print("Generating subwords ...")
-    text_featurizer = SubwordFeaturizer.build_from_corpus(
-        config.decoder_config,
-        corpus_files=args.subwords_corpus
-    )
-    text_featurizer.save_to_file(args.subwords)
+    print("Use characters ...")
+    text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config)
 
 if args.tfrecords:
-    train_dataset = ASRTFRecordDatasetKeras(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+    train_dataset = asr_dataset.ASRTFRecordDataset(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.train_dataset_config),
         indefinite=True
     )
-    eval_dataset = ASRTFRecordDatasetKeras(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+    eval_dataset = asr_dataset.ASRTFRecordDataset(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.eval_dataset_config),
         indefinite=True
     )
-    # Update metadata calculated from both train and eval datasets
-    train_dataset.load_metadata(args.metadata_prefix)
-    eval_dataset.load_metadata(args.metadata_prefix)
-    # Use dynamic length
-    speech_featurizer.reset_length()
-    text_featurizer.reset_length()
 else:
-    train_dataset = ASRSliceDatasetKeras(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+    train_dataset = asr_dataset.ASRSliceDataset(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.train_dataset_config),
         indefinite=True
     )
-    eval_dataset = ASRSliceDatasetKeras(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+    eval_dataset = asr_dataset.ASRSliceDataset(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.eval_dataset_config),
         indefinite=True
     )
 
-global_batch_size = config.learning_config.running_config.batch_size
+train_dataset.load_metadata(args.metadata)
+eval_dataset.load_metadata(args.metadata)
+
+if not args.static_length:
+    speech_featurizer.reset_length()
+    text_featurizer.reset_length()
+
+global_batch_size = args.tbs or config.learning_config.running_config.batch_size
 global_batch_size *= strategy.num_replicas_in_sync
 
 train_data_loader = train_dataset.create(global_batch_size)
@@ -114,17 +117,15 @@
     # build model
     contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
     contextnet._build(speech_featurizer.shape)
-    contextnet.summary(line_length=120)
+    contextnet.summary(line_length=100)
 
     optimizer = tf.keras.optimizers.Adam(
         TransformerSchedule(
             d_model=contextnet.dmodel,
-            warmup_steps=config.learning_config.optimizer_config["warmup_steps"],
+            warmup_steps=config.learning_config.optimizer_config.pop("warmup_steps", 10000),
             max_lr=(0.05 / math.sqrt(contextnet.dmodel))
         ),
-        beta_1=config.learning_config.optimizer_config["beta1"],
-        beta_2=config.learning_config.optimizer_config["beta2"],
-        epsilon=config.learning_config.optimizer_config["epsilon"]
+        **config.learning_config.optimizer_config
     )
 
     contextnet.compile(
@@ -141,7 +142,10 @@
 ]
 
 contextnet.fit(
-    train_data_loader, epochs=config.learning_config.running_config.num_epochs,
-    validation_data=eval_data_loader, callbacks=callbacks,
-    steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
+    train_data_loader,
+    epochs=config.learning_config.running_config.num_epochs,
+    validation_data=eval_data_loader,
+    callbacks=callbacks,
+    steps_per_epoch=train_dataset.total_steps,
+    validation_steps=eval_dataset.total_steps
 )
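
Note on the optimizer change above: the leftover entries of config.learning_config.optimizer_config are now splatted straight into tf.keras.optimizers.Adam, which is why warmup_steps is pop()-ed first. The TransformerSchedule consumes it, and Adam would otherwise raise a TypeError on an unexpected keyword. It also means the config keys must now be spelled exactly like Adam's keyword arguments (beta_1, beta_2, epsilon) rather than the old beta1/beta2 lookups. A minimal sketch of the mechanism, with hypothetical config values:

    import tensorflow as tf

    # hypothetical optimizer_config contents; real values come from config.yml
    optimizer_config = {"warmup_steps": 40000, "beta_1": 0.9, "beta_2": 0.98, "epsilon": 1e-9}

    # pop() removes warmup_steps so the remaining dict holds only valid Adam kwargs;
    # the 10000 default mirrors the fallback used in the diff
    warmup_steps = optimizer_config.pop("warmup_steps", 10000)

    # a plain learning rate stands in for the TransformerSchedule here
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, **optimizer_config)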
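Similarly, the global batch size is now resolved from the --tbs flag before the config value, then scaled by the replica count. A worked example with hypothetical numbers (one caveat: since this uses Python's "or", an explicit --tbs 0 would silently fall back to the config value):

    args_tbs = 4                # --tbs 4: per-replica batch size from the CLI
    config_batch_size = 8       # learning_config.running_config.batch_size
    num_replicas_in_sync = 2    # e.g. two GPUs under the distribution strategy

    # the CLI flag wins when given; the per-replica size is then made global
    global_batch_size = (args_tbs or config_batch_size) * num_replicas_in_sync
    assert global_batch_size == 8   # 4 per replica x 2 replicas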