Skip to content

Commit 8c32349

Browse files
committed
enable shuffle in dataloader.
1 parent 55e9750 commit 8c32349

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

egs/aishell/s10/chain/egs_dataset.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from common import splice_feats
1818

1919

20-
def get_egs_dataloader(egs_dir, egs_left_context, egs_right_context):
20+
def get_egs_dataloader(egs_dir,
21+
egs_left_context,
22+
egs_right_context,
23+
shuffle=True):
2124

2225
dataset = NnetChainExampleDataset(egs_dir=egs_dir)
2326
frame_subsampling_factor = 3
@@ -32,6 +35,7 @@ def get_egs_dataloader(egs_dir, egs_left_context, egs_right_context):
3235

3336
dataloader = DataLoader(dataset,
3437
batch_size=batch_size,
38+
shuffle=shuffle,
3539
num_workers=0,
3640
collate_fn=collate_fn)
3741
return dataloader

egs/aishell/s10/chain/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ def main():
181181

182182
dataloader = get_egs_dataloader(egs_dir=args.cegs_dir,
183183
egs_left_context=args.egs_left_context,
184-
egs_right_context=args.egs_right_context)
184+
egs_right_context=args.egs_right_context,
185+
shuffle=True)
185186

186187
optimizer = optim.Adam(model.parameters(),
187188
lr=learning_rate,

0 commit comments

Comments
 (0)