Skip to content

Commit 63c732b

Browse files
authored
[scripts,egs] Support ivector training in pytorch model (#3969)
1 parent 261eee1 commit 63c732b

File tree

10 files changed

+293
-59
lines changed

10 files changed

+293
-59
lines changed

egs/aishell/s10/chain/egs_dataset.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,16 @@ def get_egs_dataloader(egs_dir_or_scp,
4141
sampler = torch.utils.data.distributed.DistributedSampler(
4242
dataset, num_replicas=world_size, rank=local_rank, shuffle=True)
4343
dataloader = DataLoader(dataset,
44-
batch_size=batch_size,
45-
collate_fn=collate_fn,
46-
sampler=sampler)
44+
batch_size=batch_size,
45+
collate_fn=collate_fn,
46+
sampler=sampler)
4747
else:
48-
base_sampler = torch.utils.data.RandomSampler(dataset)
49-
sampler = torch.utils.data.BatchSampler(base_sampler, batch_size, False)
50-
dataloader = DataLoader(dataset,
51-
batch_sampler=sampler,
52-
collate_fn=collate_fn)
48+
base_sampler = torch.utils.data.RandomSampler(dataset)
49+
sampler = torch.utils.data.BatchSampler(
50+
base_sampler, batch_size, False)
51+
dataloader = DataLoader(dataset,
52+
batch_sampler=sampler,
53+
collate_fn=collate_fn)
5354
return dataloader
5455

5556

@@ -146,18 +147,21 @@ def __call__(self, batch):
146147

147148
batch_size = supervision.num_sequences
148149

149-
frames_per_sequence = (supervision.frames_per_sequence * \
150-
self.frame_subsampling_factor) + \
151-
self.egs_left_context + self.egs_right_context
150+
frames_per_sequence = (supervision.frames_per_sequence *
151+
self.frame_subsampling_factor) + \
152+
self.egs_left_context + self.egs_right_context
152153

153-
# TODO(fangjun): support ivector
154-
assert len(eg.inputs) == 1
155154
assert eg.inputs[0].name == 'input'
156155

157156
_feats = kaldi.FloatMatrix()
158157
eg.inputs[0].features.GetMatrix(_feats)
159158
feats = _feats.numpy()
160159

160+
if len(eg.inputs) > 1:
161+
_ivectors = kaldi.FloatMatrix()
162+
eg.inputs[1].features.GetMatrix(_ivectors)
163+
ivectors = _ivectors.numpy()
164+
161165
assert feats.shape[0] == batch_size * frames_per_sequence
162166

163167
feat_list = []
@@ -173,6 +177,11 @@ def __call__(self, batch):
173177
end_index -= 1 # remove the rightmost frame added for frame shift
174178
feat = feats[start_index:end_index:, :]
175179
feat = splice_feats(feat)
180+
if len(eg.inputs) > 1:
181+
repeat_ivector = torch.from_numpy(
182+
ivectors[i]).repeat(feat.shape[0], 1)
183+
feat = torch.cat(
184+
(torch.from_numpy(feat), repeat_ivector), dim=1).numpy()
176185
feat_list.append(feat)
177186

178187
batched_feat = np.stack(feat_list, axis=0)
@@ -182,7 +191,11 @@ def __call__(self, batch):
182191
# the first -2 is from extra left/right context
183192
# the second -2 is from lda feats splicing
184193
assert batched_feat.shape[1] == frames_per_sequence - 4
185-
assert batched_feat.shape[2] == feats.shape[-1] * 3
194+
if len(eg.inputs) > 1:
195+
assert batched_feat.shape[2] == feats.shape[-1] * \
196+
3 + ivectors.shape[-1]
197+
else:
198+
assert batched_feat.shape[2] == feats.shape[-1] * 3
186199

187200
torch_feat = torch.from_numpy(batched_feat).float()
188201
feature_list.append(torch_feat)
@@ -222,8 +235,8 @@ def _test_nnet_chain_example_dataset():
222235
for b in dataloader:
223236
key_list, feature_list, supervision_list = b
224237
assert feature_list[0].shape == (128, 204, 129) \
225-
or feature_list[0].shape == (128, 144, 129) \
226-
or feature_list[0].shape == (128, 165, 129)
238+
or feature_list[0].shape == (128, 144, 129) \
239+
or feature_list[0].shape == (128, 165, 129)
227240
assert supervision_list[0].weight == 1
228241
supervision_list[0].num_sequences == 128 # minibach size is 128
229242

egs/aishell/s10/chain/feat_dataset.py

Lines changed: 73 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# Apache 2.0
55

66
import os
7-
7+
import math
88
import numpy as np
99
import torch
1010

@@ -22,13 +22,18 @@
2222
def get_feat_dataloader(feats_scp,
2323
model_left_context,
2424
model_right_context,
25+
frames_per_chunk=51,
26+
ivector_scp=None,
27+
ivector_period=10,
2528
batch_size=16,
2629
num_workers=10):
27-
dataset = FeatDataset(feats_scp=feats_scp)
30+
dataset = FeatDataset(feats_scp=feats_scp, ivector_scp=ivector_scp)
2831

2932
collate_fn = FeatDatasetCollateFunc(model_left_context=model_left_context,
3033
model_right_context=model_right_context,
31-
frame_subsampling_factor=3)
34+
frame_subsampling_factor=3,
35+
frames_per_chunk=frames_per_chunk,
36+
ivector_period=ivector_period)
3237

3338
dataloader = DataLoader(dataset,
3439
batch_size=batch_size,
@@ -55,21 +60,40 @@ def _add_model_left_right_context(x, left_context, right_context):
5560

5661
class FeatDataset(Dataset):
5762

58-
def __init__(self, feats_scp):
63+
def __init__(self, feats_scp, ivector_scp=None):
5964
assert os.path.isfile(feats_scp)
65+
if ivector_scp:
66+
assert os.path.isfile(ivector_scp)
6067

6168
self.feats_scp = feats_scp
6269

63-
# items is a list of [key, rxfilename]
64-
items = list()
70+
# items is a dict of [uttid, feat_rxfilename, None]
71+
# or [uttid, feat_rxfilename, ivector_rxfilename] if ivector_scp is not None
72+
items = dict()
6573

6674
with open(feats_scp, 'r') as f:
6775
for line in f:
6876
split = line.split()
6977
assert len(split) == 2
70-
items.append(split)
71-
72-
self.items = items
78+
uttid, rxfilename = split
79+
assert uttid not in items
80+
items[uttid] = [uttid, rxfilename, None]
81+
self.ivector_scp = None
82+
if ivector_scp:
83+
self.ivector_scp = ivector_scp
84+
expected_count = len(items)
85+
n = 0
86+
with open(ivector_scp, 'r') as f:
87+
for line in f:
88+
uttid_rxfilename = line.split()
89+
assert len(uttid_rxfilename) == 2
90+
uttid, rxfilename = uttid_rxfilename
91+
assert uttid in items
92+
items[uttid][-1] = rxfilename
93+
n += 1
94+
assert n == expected_count
95+
96+
self.items = list(items.values())
7397

7498
self.num_items = len(self.items)
7599

@@ -81,6 +105,8 @@ def __getitem__(self, i):
81105

82106
def __str__(self):
83107
s = 'feats scp: {}\n'.format(self.feats_scp)
108+
if self.ivector_scp:
109+
s += 'ivector_scp scp: {}\n'.format(self.ivector_scp)
84110
s += 'num utt: {}\n'.format(self.num_items)
85111
return s
86112

@@ -90,26 +116,37 @@ class FeatDatasetCollateFunc:
90116
def __init__(self,
91117
model_left_context,
92118
model_right_context,
93-
frame_subsampling_factor=3):
119+
frame_subsampling_factor=3,
120+
frames_per_chunk=51,
121+
ivector_period=10):
94122
'''
95123
We need `frame_subsampling_factor` because we want to know
96124
the number of output frames of different waves in the same batch
97125
'''
98126
self.model_left_context = model_left_context
99127
self.model_right_context = model_right_context
100128
self.frame_subsampling_factor = frame_subsampling_factor
129+
self.frames_per_chunk = frames_per_chunk
130+
self.ivector_period = ivector_period
101131

102132
def __call__(self, batch):
103133
'''
104134
batch is a list of [key, rxfilename]
105135
'''
106136
key_list = []
107137
feat_list = []
138+
ivector_list = []
139+
ivector_len_list = []
108140
output_len_list = []
141+
subsampled_frames_per_chunk = (self.frames_per_chunk //
142+
self.frame_subsampling_factor)
109143
for b in batch:
110-
key, rxfilename = b
144+
key, rxfilename, ivector_rxfilename = b
111145
key_list.append(key)
112146
feat = kaldi.read_mat(rxfilename).numpy()
147+
if ivector_rxfilename:
148+
ivector = kaldi.read_mat(
149+
ivector_rxfilename).numpy() # L // 10 * C
113150
feat_len = feat.shape[0]
114151
output_len = (feat_len + self.frame_subsampling_factor -
115152
1) // self.frame_subsampling_factor
@@ -118,12 +155,33 @@ def __call__(self, batch):
118155
feat = _add_model_left_right_context(feat, self.model_left_context,
119156
self.model_right_context)
120157
feat = splice_feats(feat)
121-
feat_list.append(feat)
122-
# no need to sort the feat by length
123158

124-
# the user should sort utterances by length offline
125-
# to avoid unnecessary padding
159+
# now we split feat to chunk, then we can do decode by chunk
160+
input_num_frames = (feat.shape[0] + 2
161+
- self.model_left_context - self.model_right_context)
162+
for i in range(0, output_len, subsampled_frames_per_chunk):
163+
# input len:418 -> output len:140 -> output chunk:[0, 17, 34, 51, 68, 85, 102, 119, 136]
164+
first_output = i * self.frame_subsampling_factor
165+
last_output = min(input_num_frames,
166+
first_output + (subsampled_frames_per_chunk-1) * self.frame_subsampling_factor)
167+
first_input = first_output
168+
last_input = last_output + self.model_left_context + self.model_right_context
169+
input_x = feat[first_input:last_input+1, :]
170+
if ivector_rxfilename:
171+
ivector_index = (
172+
first_output + last_output) // 2 // self.ivector_period
173+
input_ivector = ivector[ivector_index, :].reshape(1, -1)
174+
feat_list.append(np.concatenate((input_x,
175+
np.repeat(input_ivector, input_x.shape[0], axis=0)),
176+
axis=-1))
177+
else:
178+
feat_list.append(input_x)
179+
126180
padded_feat = pad_sequence(
127181
[torch.from_numpy(feat).float() for feat in feat_list],
128182
batch_first=True)
183+
184+
assert sum([math.ceil(l / subsampled_frames_per_chunk) for l in output_len_list]) \
185+
== padded_feat.shape[0]
186+
129187
return key_list, padded_feat, output_len_list

egs/aishell/s10/chain/inference.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
import os
88
import sys
9+
import math
910

1011
import torch
1112
from torch.utils.dlpack import to_dlpack
@@ -38,6 +39,7 @@ def main():
3839
model = get_chain_model(
3940
feat_dim=args.feat_dim,
4041
output_dim=args.output_dim,
42+
ivector_dim=args.ivector_dim,
4143
lda_mat_filename=args.lda_mat_filename,
4244
hidden_dim=args.hidden_dim,
4345
bottleneck_dim=args.bottleneck_dim,
@@ -64,22 +66,29 @@ def main():
6466

6567
dataloader = get_feat_dataloader(
6668
feats_scp=args.feats_scp,
69+
ivector_scp=args.ivector_scp,
6770
model_left_context=args.model_left_context,
6871
model_right_context=args.model_right_context,
69-
batch_size=32)
70-
72+
batch_size=32,
73+
num_workers=10)
74+
subsampling_factor = 3
75+
subsampled_frames_per_chunk = args.frames_per_chunk // subsampling_factor
7176
for batch_idx, batch in enumerate(dataloader):
7277
key_list, padded_feat, output_len_list = batch
7378
padded_feat = padded_feat.to(device)
7479
with torch.no_grad():
7580
nnet_output, _ = model(padded_feat)
7681

7782
num = len(key_list)
83+
first = 0
7884
for i in range(num):
7985
key = key_list[i]
8086
output_len = output_len_list[i]
81-
value = nnet_output[i, :output_len, :]
87+
target_len = math.ceil(output_len / subsampled_frames_per_chunk)
88+
result = nnet_output[first:first + target_len, :, :].split(1, 0)
89+
value = torch.cat(result, dim=1)[0, :output_len, :]
8290
value = value.cpu()
91+
first += target_len
8392

8493
m = kaldi.SubMatrixFromDLPack(to_dlpack(value))
8594
m = Matrix(m)

egs/aishell/s10/chain/model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
def get_chain_model(feat_dim,
1919
output_dim,
20+
ivector_dim,
2021
hidden_dim,
2122
bottleneck_dim,
2223
prefinal_bottleneck_dim,
@@ -25,6 +26,7 @@ def get_chain_model(feat_dim,
2526
lda_mat_filename=None):
2627
model = ChainModel(feat_dim=feat_dim,
2728
output_dim=output_dim,
29+
ivector_dim=ivector_dim,
2830
lda_mat_filename=lda_mat_filename,
2931
hidden_dim=hidden_dim,
3032
bottleneck_dim=bottleneck_dim,
@@ -82,6 +84,7 @@ class ChainModel(nn.Module):
8284
def __init__(self,
8385
feat_dim,
8486
output_dim,
87+
ivector_dim=0,
8588
lda_mat_filename=None,
8689
hidden_dim=1024,
8790
bottleneck_dim=128,
@@ -97,8 +100,9 @@ def __init__(self,
97100
assert len(kernel_size_list) == len(subsampling_factor_list)
98101
num_layers = len(kernel_size_list)
99102

103+
input_dim = feat_dim * 3 + ivector_dim
100104
# tdnn1_affine requires [N, T, C]
101-
self.tdnn1_affine = nn.Linear(in_features=feat_dim * 3,
105+
self.tdnn1_affine = nn.Linear(in_features=input_dim,
102106
out_features=hidden_dim)
103107

104108
# tdnn1_batchnorm requires [N, C, T]
@@ -142,11 +146,11 @@ def __init__(self,
142146
if lda_mat_filename:
143147
logging.info('Use LDA from {}'.format(lda_mat_filename))
144148
self.lda_A, self.lda_b = load_lda_mat(lda_mat_filename)
145-
assert feat_dim * 3 == self.lda_A.shape[0]
149+
assert input_dim == self.lda_A.shape[0]
146150
self.has_LDA = True
147151
else:
148152
logging.info('replace LDA with BatchNorm')
149-
self.input_batch_norm = nn.BatchNorm1d(num_features=feat_dim * 3,
153+
self.input_batch_norm = nn.BatchNorm1d(num_features=input_dim,
150154
affine=False)
151155
self.has_LDA = False
152156

egs/aishell/s10/chain/options.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,23 @@ def _set_inference_args(parser):
2727
dest='feats_scp',
2828
help='feats.scp filename, required for inference',
2929
type=str)
30+
31+
parser.add_argument('--frames-per-chunk',
32+
dest='frames_per_chunk',
33+
help='frames per chunk',
34+
type=int,
35+
default=51)
36+
37+
parser.add_argument('--ivector-scp',
38+
dest='ivector_scp',
39+
help='ivector.scp filename, required for ivector inference',
40+
type=str)
41+
42+
parser.add_argument('--ivector-period',
43+
dest='ivector_period',
44+
help='ivector period',
45+
type=int,
46+
default=10)
3047

3148
parser.add_argument('--model-left-context',
3249
dest='model_left_context',
@@ -228,10 +245,17 @@ def get_args():
228245

229246
parser.add_argument('--feat-dim',
230247
dest='feat_dim',
231-
help='nn input dimension',
248+
help='nn input 0 dimension',
232249
required=True,
233250
type=int)
234251

252+
parser.add_argument('--ivector-dim',
253+
dest='ivector_dim',
254+
help='nn input 1 dimension',
255+
required=False,
256+
default=0,
257+
type=int)
258+
235259
parser.add_argument('--output-dim',
236260
dest='output_dim',
237261
help='nn output dimension',

0 commit comments

Comments
 (0)