Skip to content

Commit 4842078

Browse files
authored
Merge pull request #98 from TensorSpeech/fix/dataset
Add buffer size and tfrecords shards options
2 parents 488b14e + fdd5722 commit 4842078

File tree

15 files changed

+276
-246
lines changed

15 files changed

+276
-246
lines changed

examples/conformer/train_conformer.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,25 @@
2626

2727
parser = argparse.ArgumentParser(prog="Conformer Training")
2828

29-
parser.add_argument("--config", type=str, default=DEFAULT_YAML,
30-
help="The file path of model configuration file")
29+
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
3130

32-
parser.add_argument("--max_ckpts", type=int, default=10,
33-
help="Max number of checkpoints to keep")
31+
parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
3432

35-
parser.add_argument("--tfrecords", default=False, action="store_true",
36-
help="Whether to use tfrecords")
33+
parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")
3734

38-
parser.add_argument("--tbs", type=int, default=None,
39-
help="Train batch size per replica")
35+
parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards")
4036

41-
parser.add_argument("--ebs", type=int, default=None,
42-
help="Evaluation batch size per replica")
37+
parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")
4338

44-
parser.add_argument("--devices", type=int, nargs="*", default=[0],
45-
help="Devices' ids to apply distributed training")
39+
parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")
4640

47-
parser.add_argument("--mxp", default=False, action="store_true",
48-
help="Enable mixed precision")
41+
parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
4942

50-
parser.add_argument("--cache", default=False, action="store_true",
51-
help="Enable caching for dataset")
43+
parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
44+
45+
parser.add_argument("--cache", default=False, action="store_true", help="Enable caching for dataset")
46+
47+
parser.add_argument("--bfs", type=int, default=100, help="Buffer size for shuffling")
5248

5349
args = parser.parse_args()
5450

@@ -75,28 +71,34 @@
7571
speech_featurizer=speech_featurizer,
7672
text_featurizer=text_featurizer,
7773
augmentations=config.learning_config.augmentations,
78-
stage="train", cache=args.cache, shuffle=True
74+
tfrecords_shards=args.tfrecords_shards,
75+
stage="train", cache=args.cache,
76+
shuffle=True, buffer_size=args.bfs,
7977
)
8078
eval_dataset = ASRTFRecordDataset(
8179
data_paths=config.learning_config.dataset_config.eval_paths,
8280
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
81+
tfrecords_shards=args.tfrecords_shards,
8382
speech_featurizer=speech_featurizer,
8483
text_featurizer=text_featurizer,
85-
stage="eval", cache=args.cache, shuffle=True
84+
stage="eval", cache=args.cache,
85+
shuffle=True, buffer_size=args.bfs,
8686
)
8787
else:
8888
train_dataset = ASRSliceDataset(
8989
data_paths=config.learning_config.dataset_config.train_paths,
9090
speech_featurizer=speech_featurizer,
9191
text_featurizer=text_featurizer,
9292
augmentations=config.learning_config.augmentations,
93-
stage="train", cache=args.cache, shuffle=True
93+
stage="train", cache=args.cache,
94+
shuffle=True, buffer_size=args.bfs,
9495
)
9596
eval_dataset = ASRSliceDataset(
9697
data_paths=config.learning_config.dataset_config.eval_paths,
9798
speech_featurizer=speech_featurizer,
9899
text_featurizer=text_featurizer,
99-
stage="eval", cache=args.cache, shuffle=True
100+
stage="eval", cache=args.cache,
101+
shuffle=True, buffer_size=args.bfs,
100102
)
101103

102104
conformer_trainer = TransducerTrainer(

examples/conformer/train_ga_conformer.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,32 +26,27 @@
2626

2727
parser = argparse.ArgumentParser(prog="Conformer Training")
2828

29-
parser.add_argument("--config", type=str, default=DEFAULT_YAML,
30-
help="The file path of model configuration file")
29+
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
3130

32-
parser.add_argument("--max_ckpts", type=int, default=10,
33-
help="Max number of checkpoints to keep")
31+
parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
3432

35-
parser.add_argument("--tfrecords", default=False, action="store_true",
36-
help="Whether to use tfrecords")
33+
parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")
3734

38-
parser.add_argument("--tbs", type=int, default=None,
39-
help="Train batch size per replica")
35+
parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards")
4036

41-
parser.add_argument("--ebs", type=int, default=None,
42-
help="Evaluation batch size per replica")
37+
parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")
4338

44-
parser.add_argument("--acs", type=int, default=None,
45-
help="Train accumulation steps")
39+
parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")
4640

47-
parser.add_argument("--devices", type=int, nargs="*", default=[0],
48-
help="Devices' ids to apply distributed training")
41+
parser.add_argument("--acs", type=int, default=None, help="Train accumulation steps")
4942

50-
parser.add_argument("--mxp", default=False, action="store_true",
51-
help="Enable mixed precision")
43+
parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
5244

53-
parser.add_argument("--cache", default=False, action="store_true",
54-
help="Enable caching for dataset")
45+
parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
46+
47+
parser.add_argument("--cache", default=False, action="store_true", help="Enable caching for dataset")
48+
49+
parser.add_argument("--bfs", type=int, default=100, help="Buffer size for shuffling")
5550

5651
args = parser.parse_args()
5752

@@ -78,28 +73,34 @@
7873
speech_featurizer=speech_featurizer,
7974
text_featurizer=text_featurizer,
8075
augmentations=config.learning_config.augmentations,
81-
stage="train", cache=args.cache, shuffle=True
76+
tfrecords_shards=args.tfrecords_shards,
77+
stage="train", cache=args.cache,
78+
shuffle=True, buffer_size=args.bfs,
8279
)
8380
eval_dataset = ASRTFRecordDataset(
8481
data_paths=config.learning_config.dataset_config.eval_paths,
8582
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
83+
tfrecords_shards=args.tfrecords_shards,
8684
speech_featurizer=speech_featurizer,
8785
text_featurizer=text_featurizer,
88-
stage="eval", cache=args.cache, shuffle=True
86+
stage="eval", cache=args.cache,
87+
shuffle=True, buffer_size=args.bfs,
8988
)
9089
else:
9190
train_dataset = ASRSliceDataset(
9291
data_paths=config.learning_config.dataset_config.train_paths,
9392
speech_featurizer=speech_featurizer,
9493
text_featurizer=text_featurizer,
9594
augmentations=config.learning_config.augmentations,
96-
stage="train", cache=args.cache, shuffle=True
95+
stage="train", cache=args.cache,
96+
shuffle=True, buffer_size=args.bfs,
9797
)
9898
eval_dataset = ASRSliceDataset(
9999
data_paths=config.learning_config.dataset_config.eval_paths,
100100
speech_featurizer=speech_featurizer,
101101
text_featurizer=text_featurizer,
102-
stage="eval", cache=args.cache, shuffle=True
102+
stage="eval", cache=args.cache,
103+
shuffle=True, buffer_size=args.bfs,
103104
)
104105

105106
conformer_trainer = TransducerTrainerGA(

examples/conformer/train_ga_subword_conformer.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,41 +26,33 @@
2626

2727
parser = argparse.ArgumentParser(prog="Conformer Training")
2828

29-
parser.add_argument("--config", type=str, default=DEFAULT_YAML,
30-
help="The file path of model configuration file")
29+
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
3130

32-
parser.add_argument("--max_ckpts", type=int, default=10,
33-
help="Max number of checkpoints to keep")
31+
parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
3432

35-
parser.add_argument("--tfrecords", default=False, action="store_true",
36-
help="Whether to use tfrecords")
33+
parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")
3734

38-
parser.add_argument("--tbs", type=int, default=None,
39-
help="Train batch size per replica")
35+
parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards")
4036

41-
parser.add_argument("--ebs", type=int, default=None,
42-
help="Evaluation batch size per replica")
37+
parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")
4338

44-
parser.add_argument("--acs", type=int, default=None,
45-
help="Train accumulation steps")
39+
parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")
4640

47-
parser.add_argument("--sentence_piece", default=False, action="store_true",
48-
help="Whether to use `SentencePiece` model")
41+
parser.add_argument("--acs", type=int, default=None, help="Train accumulation steps")
4942

50-
parser.add_argument("--devices", type=int, nargs="*", default=[0],
51-
help="Devices' ids to apply distributed training")
43+
parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
5244

53-
parser.add_argument("--mxp", default=False, action="store_true",
54-
help="Enable mixed precision")
45+
parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
5546

56-
parser.add_argument("--cache", default=False, action="store_true",
57-
help="Enable caching for dataset")
47+
parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
5848

59-
parser.add_argument("--subwords", type=str, default=None,
60-
help="Path to file that stores generated subwords")
49+
parser.add_argument("--cache", default=False, action="store_true", help="Enable caching for dataset")
6150

62-
parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[],
63-
help="Transcript files for generating subwords")
51+
parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
52+
53+
parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
54+
55+
parser.add_argument("--bfs", type=int, default=100, help="Buffer size for shuffling")
6456

6557
args = parser.parse_args()
6658

@@ -100,28 +92,34 @@
10092
speech_featurizer=speech_featurizer,
10193
text_featurizer=text_featurizer,
10294
augmentations=config.learning_config.augmentations,
103-
stage="train", cache=args.cache, shuffle=True
95+
tfrecords_shards=args.tfrecords_shards,
96+
stage="train", cache=args.cache,
97+
shuffle=True, buffer_size=args.bfs,
10498
)
10599
eval_dataset = ASRTFRecordDataset(
106100
data_paths=config.learning_config.dataset_config.eval_paths,
107101
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
102+
tfrecords_shards=args.tfrecords_shards,
108103
speech_featurizer=speech_featurizer,
109104
text_featurizer=text_featurizer,
110-
stage="eval", cache=args.cache, shuffle=True
105+
stage="eval", cache=args.cache,
106+
shuffle=True, buffer_size=args.bfs,
111107
)
112108
else:
113109
train_dataset = ASRSliceDataset(
114110
data_paths=config.learning_config.dataset_config.train_paths,
115111
speech_featurizer=speech_featurizer,
116112
text_featurizer=text_featurizer,
117113
augmentations=config.learning_config.augmentations,
118-
stage="train", cache=args.cache, shuffle=True
114+
stage="train", cache=args.cache,
115+
shuffle=True, buffer_size=args.bfs,
119116
)
120117
eval_dataset = ASRSliceDataset(
121118
data_paths=config.learning_config.dataset_config.eval_paths,
122119
speech_featurizer=speech_featurizer,
123120
text_featurizer=text_featurizer,
124-
stage="eval", cache=args.cache, shuffle=True
121+
stage="eval", cache=args.cache,
122+
shuffle=True, buffer_size=args.bfs,
125123
)
126124

127125
conformer_trainer = TransducerTrainerGA(

examples/conformer/train_subword_conformer.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,38 +26,31 @@
2626

2727
parser = argparse.ArgumentParser(prog="Conformer Training")
2828

29-
parser.add_argument("--config", type=str, default=DEFAULT_YAML,
30-
help="The file path of model configuration file")
29+
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
3130

32-
parser.add_argument("--max_ckpts", type=int, default=10,
33-
help="Max number of checkpoints to keep")
31+
parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
3432

35-
parser.add_argument("--tfrecords", default=False, action="store_true",
36-
help="Whether to use tfrecords")
33+
parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")
3734

38-
parser.add_argument("--sentence_piece", default=False, action="store_true",
39-
help="Whether to use `SentencePiece` model")
35+
parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards")
4036

41-
parser.add_argument("--tbs", type=int, default=None,
42-
help="Train batch size per replica")
37+
parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
4338

44-
parser.add_argument("--ebs", type=int, default=None,
45-
help="Evaluation batch size per replica")
39+
parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")
4640

47-
parser.add_argument("--devices", type=int, nargs="*", default=[0],
48-
help="Devices' ids to apply distributed training")
41+
parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")
4942

50-
parser.add_argument("--mxp", default=False, action="store_true",
51-
help="Enable mixed precision")
43+
parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
5244

53-
parser.add_argument("--cache", default=False, action="store_true",
54-
help="Enable caching for dataset")
45+
parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
5546

56-
parser.add_argument("--subwords", type=str, default=None,
57-
help="Path to file that stores generated subwords")
47+
parser.add_argument("--cache", default=False, action="store_true", help="Enable caching for dataset")
5848

59-
parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[],
60-
help="Transcript files for generating subwords")
49+
parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
50+
51+
parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
52+
53+
parser.add_argument("--bfs", type=int, default=100, help="Buffer size for shuffling")
6154

6255
args = parser.parse_args()
6356

@@ -97,28 +90,34 @@
9790
speech_featurizer=speech_featurizer,
9891
text_featurizer=text_featurizer,
9992
augmentations=config.learning_config.augmentations,
100-
stage="train", cache=args.cache, shuffle=True
93+
tfrecords_shards=args.tfrecords_shards,
94+
stage="train", cache=args.cache,
95+
shuffle=True, buffer_size=args.bfs,
10196
)
10297
eval_dataset = ASRTFRecordDataset(
10398
data_paths=config.learning_config.dataset_config.eval_paths,
10499
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
100+
tfrecords_shards=args.tfrecords_shards,
105101
speech_featurizer=speech_featurizer,
106102
text_featurizer=text_featurizer,
107-
stage="eval", cache=args.cache, shuffle=True
103+
stage="eval", cache=args.cache,
104+
shuffle=True, buffer_size=args.bfs,
108105
)
109106
else:
110107
train_dataset = ASRSliceDataset(
111108
data_paths=config.learning_config.dataset_config.train_paths,
112109
speech_featurizer=speech_featurizer,
113110
text_featurizer=text_featurizer,
114111
augmentations=config.learning_config.augmentations,
115-
stage="train", cache=args.cache, shuffle=True
112+
stage="train", cache=args.cache,
113+
shuffle=True, buffer_size=args.bfs,
116114
)
117115
eval_dataset = ASRSliceDataset(
118116
data_paths=config.learning_config.dataset_config.eval_paths,
119117
speech_featurizer=speech_featurizer,
120118
text_featurizer=text_featurizer,
121-
stage="eval", cache=args.cache, shuffle=True
119+
stage="eval", cache=args.cache,
120+
shuffle=True, buffer_size=args.bfs,
122121
)
123122

124123
conformer_trainer = TransducerTrainer(

0 commit comments

Comments (0)