44# Apache 2.0
55
66import os
7-
7+ import math
88import numpy as np
99import torch
1010
2222def get_feat_dataloader (feats_scp ,
2323 model_left_context ,
2424 model_right_context ,
25+ frames_per_chunk = 51 ,
26+ ivector_scp = None ,
27+ ivector_period = 10 ,
2528 batch_size = 16 ,
2629 num_workers = 10 ):
27- dataset = FeatDataset (feats_scp = feats_scp )
30+ dataset = FeatDataset (feats_scp = feats_scp , ivector_scp = ivector_scp )
2831
2932 collate_fn = FeatDatasetCollateFunc (model_left_context = model_left_context ,
3033 model_right_context = model_right_context ,
31- frame_subsampling_factor = 3 )
34+ frame_subsampling_factor = 3 ,
35+ frames_per_chunk = frames_per_chunk ,
36+ ivector_period = ivector_period )
3237
3338 dataloader = DataLoader (dataset ,
3439 batch_size = batch_size ,
@@ -55,21 +60,40 @@ def _add_model_left_right_context(x, left_context, right_context):
5560
5661class FeatDataset (Dataset ):
5762
58- def __init__ (self , feats_scp ):
63+ def __init__ (self , feats_scp , ivector_scp = None ):
5964 assert os .path .isfile (feats_scp )
65+ if ivector_scp :
66+ assert os .path .isfile (ivector_scp )
6067
6168 self .feats_scp = feats_scp
6269
63- # items is a list of [key, rxfilename]
64- items = list ()
70+ # items is a dict mapping uttid -> [uttid, feat_rxfilename, ivector_rxfilename];
71+ # the last entry stays None when ivector_scp is not given
72+ items = dict ()
6573
6674 with open (feats_scp , 'r' ) as f :
6775 for line in f :
6876 split = line .split ()
6977 assert len (split ) == 2
70- items .append (split )
71-
72- self .items = items
78+ uttid , rxfilename = split
79+ assert uttid not in items
80+ items [uttid ] = [uttid , rxfilename , None ]
81+ self .ivector_scp = None
82+ if ivector_scp :
83+ self .ivector_scp = ivector_scp
84+ expected_count = len (items )
85+ n = 0
86+ with open (ivector_scp , 'r' ) as f :
87+ for line in f :
88+ uttid_rxfilename = line .split ()
89+ assert len (uttid_rxfilename ) == 2
90+ uttid , rxfilename = uttid_rxfilename
91+ assert uttid in items
92+ items [uttid ][- 1 ] = rxfilename
93+ n += 1
94+ assert n == expected_count
95+
96+ self .items = list (items .values ())
7397
7498 self .num_items = len (self .items )
7599
@@ -81,6 +105,8 @@ def __getitem__(self, i):
81105
82106 def __str__ (self ):
83107 s = 'feats scp: {}\n ' .format (self .feats_scp )
108+ if self .ivector_scp :
109+ s += 'ivector_scp scp: {}\n ' .format (self .ivector_scp )
84110 s += 'num utt: {}\n ' .format (self .num_items )
85111 return s
86112
@@ -90,26 +116,37 @@ class FeatDatasetCollateFunc:
90116 def __init__ (self ,
91117 model_left_context ,
92118 model_right_context ,
93- frame_subsampling_factor = 3 ):
119+ frame_subsampling_factor = 3 ,
120+ frames_per_chunk = 51 ,
121+ ivector_period = 10 ):
94122 '''
95123 We need `frame_subsampling_factor` because we want to know
96124 the number of output frames of different waves in the same batch
97125 '''
98126 self .model_left_context = model_left_context
99127 self .model_right_context = model_right_context
100128 self .frame_subsampling_factor = frame_subsampling_factor
129+ self .frames_per_chunk = frames_per_chunk
130+ self .ivector_period = ivector_period
101131
102132 def __call__ (self , batch ):
103133 '''
104134 batch is a list of [key, rxfilename, ivector_rxfilename]; ivector_rxfilename is None when no ivector scp is used
105135 '''
106136 key_list = []
107137 feat_list = []
138+ ivector_list = []
139+ ivector_len_list = []
108140 output_len_list = []
141+ subsampled_frames_per_chunk = (self .frames_per_chunk //
142+ self .frame_subsampling_factor )
109143 for b in batch :
110- key , rxfilename = b
144+ key , rxfilename , ivector_rxfilename = b
111145 key_list .append (key )
112146 feat = kaldi .read_mat (rxfilename ).numpy ()
147+ if ivector_rxfilename :
148+ ivector = kaldi .read_mat (
149+ ivector_rxfilename ).numpy () # L // 10 * C
113150 feat_len = feat .shape [0 ]
114151 output_len = (feat_len + self .frame_subsampling_factor -
115152 1 ) // self .frame_subsampling_factor
@@ -118,12 +155,33 @@ def __call__(self, batch):
118155 feat = _add_model_left_right_context (feat , self .model_left_context ,
119156 self .model_right_context )
120157 feat = splice_feats (feat )
121- feat_list .append (feat )
122- # no need to sort the feat by length
123158
124- # the user should sort utterances by length offline
125- # to avoid unnecessary padding
159+ # Now split feat into chunks so that decoding can be done chunk by chunk
160+ input_num_frames = (feat .shape [0 ] + 2
161+ - self .model_left_context - self .model_right_context )
162+ for i in range (0 , output_len , subsampled_frames_per_chunk ):
163+ # input len:418 -> output len:140 -> output chunk:[0, 17, 34, 51, 68, 85, 102, 119, 136]
164+ first_output = i * self .frame_subsampling_factor
165+ last_output = min (input_num_frames ,
166+ first_output + (subsampled_frames_per_chunk - 1 ) * self .frame_subsampling_factor )
167+ first_input = first_output
168+ last_input = last_output + self .model_left_context + self .model_right_context
169+ input_x = feat [first_input :last_input + 1 , :]
170+ if ivector_rxfilename :
171+ ivector_index = (
172+ first_output + last_output ) // 2 // self .ivector_period
173+ input_ivector = ivector [ivector_index , :].reshape (1 , - 1 )
174+ feat_list .append (np .concatenate ((input_x ,
175+ np .repeat (input_ivector , input_x .shape [0 ], axis = 0 )),
176+ axis = - 1 ))
177+ else :
178+ feat_list .append (input_x )
179+
126180 padded_feat = pad_sequence (
127181 [torch .from_numpy (feat ).float () for feat in feat_list ],
128182 batch_first = True )
183+
184+ assert sum ([math .ceil (l / subsampled_frames_per_chunk ) for l in output_len_list ]) \
185+ == padded_feat .shape [0 ]
186+
129187 return key_list , padded_feat , output_len_list
0 commit comments