Skip to content

Commit a634a5c

Browse files
authored
support delta in pytorch training and get_raw_egs by chain2 (#3991)
1 parent 44c8805 commit a634a5c

File tree

4 files changed

+108
-70
lines changed

4 files changed

+108
-70
lines changed

egs/aishell/s10/chain/egs_dataset.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from common import splice_feats
1919

20-
2120
def get_egs_dataloader(egs_dir_or_scp,
2221
egs_left_context,
2322
egs_right_context,
@@ -173,10 +172,9 @@ def __call__(self, batch):
173172

174173
end_index = start_index + frames_per_sequence
175174

176-
start_index += 1 # remove the leftmost frame added for frame shift
177-
end_index -= 1 # remove the rightmost frame added for frame shift
175+
start_index += 2 # remove the leftmost frame added for frame shift
176+
end_index -= 2 # remove the rightmost frame added for frame shift
178177
feat = feats[start_index:end_index:, :]
179-
feat = splice_feats(feat)
180178
if len(eg.inputs) > 1:
181179
repeat_ivector = torch.from_numpy(
182180
ivectors[i]).repeat(feat.shape[0], 1)
@@ -192,10 +190,9 @@ def __call__(self, batch):
192190
# the second -2 is from lda feats splicing
193191
assert batched_feat.shape[1] == frames_per_sequence - 4
194192
if len(eg.inputs) > 1:
195-
assert batched_feat.shape[2] == feats.shape[-1] * \
196-
3 + ivectors.shape[-1]
193+
assert batched_feat.shape[2] == feats.shape[-1] + ivectors.shape[-1]
197194
else:
198-
assert batched_feat.shape[2] == feats.shape[-1] * 3
195+
assert batched_feat.shape[2] == feats.shape[-1]
199196

200197
torch_feat = torch.from_numpy(batched_feat).float()
201198
feature_list.append(torch_feat)
@@ -204,7 +201,7 @@ def __call__(self, batch):
204201

205202

206203
def _test_nnet_chain_example_dataset():
207-
egs_dir = 'exp/chain/merged_egs'
204+
egs_dir = 'exp/chain_pybind/tdnnivector_delta_sp/merged_egs_chain2'
208205
dataset = NnetChainExampleDataset(egs_dir_or_scp=egs_dir)
209206
egs_left_context = 29
210207
egs_right_context = 29

egs/aishell/s10/chain/feat_dataset.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,9 @@ def __call__(self, batch):
154154
# now add model left and right context
155155
feat = _add_model_left_right_context(feat, self.model_left_context,
156156
self.model_right_context)
157-
feat = splice_feats(feat)
158157

159158
# now we split feat to chunk, then we can do decode by chunk
160-
input_num_frames = (feat.shape[0] + 2
161-
- self.model_left_context - self.model_right_context)
159+
input_num_frames = feat.shape[0] - self.model_left_context - self.model_right_context
162160
for i in range(0, output_len, subsampled_frames_per_chunk):
163161
# input len:418 -> output len:140 -> output chunk:[0, 17, 34, 51, 68, 85, 102, 119, 136]
164162
first_output = i * self.frame_subsampling_factor

egs/aishell/s10/local/run_chain.sh

Lines changed: 73 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ online_cmvn=true
1919
train_ivector=false
2020
num_epochs=6
2121
dropout_schedule=0,[email protected],[email protected],0 # you might set this to 0,0 or 0.5,0.5 to train.
22+
frame_subsampling_factor=3
23+
feat_type=delta
24+
lang=default
2225

2326
. ./path.sh
2427
. ./cmd.sh
@@ -123,20 +126,23 @@ if [[ $stage -le 12 ]]; then
123126
fi
124127

125128
if [[ $stage -le 13 ]]; then
126-
echo "$0: creating phone language-model"
127-
$mkgraph_cmd $dir/log/make_phone_lm.log \
128-
chain-est-phone-lm \
129-
"ark:gunzip -c $tree_dir/ali.*.gz | ali-to-phones $tree_dir/final.mdl ark:- ark:- |" \
130-
$dir/phone_lm.fst || exit 1
131-
fi
129+
echo "$0: Making Phone LM and denominator and normalization FST"
130+
mkdir -p $dir/den_fsts/log
131+
132+
# We may later reorganize this.
133+
cp $tree_dir/tree $dir/${lang}.tree
132134

133-
if [[ $stage -le 14 ]]; then
134-
echo "creating denominator FST"
135-
copy-transition-model $tree_dir/final.mdl $dir/0.trans_mdl
136-
cp $tree_dir/tree $dir
137-
$train_cmd $dir/log/make_den_fst.log \
138-
chain-make-den-fst $dir/tree $dir/0.trans_mdl $dir/phone_lm.fst \
139-
$dir/den.fst $dir/normalization.fst || exit 1
135+
echo "$0: creating phone language-model"
136+
$train_cmd $dir/den_fsts/log/make_phone_lm_${lang}.log \
137+
chain-est-phone-lm --num-extra-lm-states=2000 \
138+
"ark:gunzip -c $tree_dir/ali.*.gz | ali-to-phones $tree_dir/final.mdl ark:- ark:- |" \
139+
$dir/den_fsts/${lang}.phone_lm.fst
140+
mkdir -p $dir/init
141+
copy-transition-model $tree_dir/final.mdl $dir/init/${lang}_trans.mdl
142+
echo "$0: creating denominator FST"
143+
$train_cmd $dir/den_fsts/log/make_den_fst.log \
144+
chain-make-den-fst $dir/${lang}.tree $dir/init/${lang}_trans.mdl $dir/den_fsts/${lang}.phone_lm.fst \
145+
$dir/den_fsts/${lang}.den.fst $dir/den_fsts/${lang}.normalization.fst || exit 1;
140146
fi
141147

142148
# You should know how to calculate your model's left/right context **manually**
@@ -159,57 +165,63 @@ log_level=info # valid values: debug, info, warning
159165
# true to save network output as kaldi::CompressedMatrix
160166
# false to save it as kaldi::Matrix<float>
161167
save_nn_output_as_compressed=false
162-
163-
if [[ $stage -le 15 ]]; then
164-
echo "$0: generating egs"
165-
steps/nnet3/chain/get_egs.sh \
166-
--alignment-subsampling-factor 3 \
167-
--cmd "$train_cmd" \
168+
if [ $stage -le 14 ]; then
169+
echo "$0: about to dump raw egs."
170+
# Dump raw egs.
171+
steps/chain2/get_raw_egs.sh --cmd "$train_cmd" \
172+
--lang "${lang}" \
168173
--online-cmvn $online_cmvn \
169174
--online-ivector-dir "$train_ivector_dir" \
170-
--frame-subsampling-factor 3 \
171-
--frames-overlap-per-eg 0 \
172-
--frames-per-eg $frames_per_eg \
173-
--frames-per-iter $frames_per_iter \
174-
--generate-egs-scp true \
175175
--left-context $egs_left_context \
176-
--left-context-initial -1 \
177-
--left-tolerance 5 \
178176
--right-context $egs_right_context \
179-
--right-context-final -1 \
180-
--right-tolerance 5 \
181-
--srand 0 \
182-
--stage -10 \
183-
$train_data_dir \
184-
$dir $lat_dir $dir/egs
177+
--frame-subsampling-factor $frame_subsampling_factor \
178+
--alignment-subsampling-factor $frame_subsampling_factor \
179+
--frames-per-chunk $frames_per_eg \
180+
--feat-type $feat_type \
181+
${train_data_dir} ${dir} ${lat_dir} ${dir}/raw_egs
185182
fi
186183

184+
if [ $stage -le 15 ]; then
185+
echo "$0: about to process egs"
186+
steps/chain2/process_egs.sh --cmd "$train_cmd" \
187+
--num-repeats 1 \
188+
${dir}/raw_egs ${dir}/processed_egs
189+
fi
187190

188-
if [[ $stage -le 16 ]]; then
189-
echo "$0: merging egs"
190-
mkdir -p $dir/merged_egs
191-
num_egs=$(ls -1 $dir/egs/cegs*.ark | wc -l)
192-
193-
$train_cmd --max-jobs-run $nj JOB=1:$num_egs $dir/merged_egs/log/merge_egs.JOB.log \
194-
nnet3-chain-shuffle-egs ark:$dir/egs/cegs.JOB.ark ark:- \| \
195-
nnet3-chain-merge-egs --minibatch-size=$minibatch_size ark:- \
196-
ark,scp:$dir/merged_egs/cegs.JOB.ark,$dir/merged_egs/cegs.JOB.scp || exit 1
197-
198-
rm $dir/egs/cegs.*.ark
191+
if [ $stage -le 16 ]; then
192+
echo "$0: about to randomize egs"
193+
steps/chain2/randomize_egs.sh --frames-per-job 3000000 \
194+
${dir}/processed_egs ${dir}/egs
199195
fi
200196

201-
feat_dim=$(cat $dir/egs/info/feat_dim)
197+
info_file=$dir/raw_egs/info.txt
198+
feat_dim=$(cat $info_file | grep 'feat_dim' | awk '{print $NF}')
202199
ivector_dim=0
203200
ivector_period=0
204201
if $train_ivector; then
205-
ivector_dim=$(cat $dir/egs/info/ivector_dim)
202+
ivector_dim=$(cat $info_file | grep 'ivector_dim' | awk '{print $NF}')
206203
ivector_period=$(cat $train_ivector_dir/ivector_period)
207204
fi
208205
echo "ivector_dim: $ivector_dim", "ivector_period, $ivector_period"
209-
output_dim=$(cat $dir/egs/info/num_pdfs)
210206

207+
merged_egs_dir=merged_egs_chain2
208+
if [[ $stage -le 19 ]]; then
209+
echo "$0: merging egs"
210+
211+
mkdir -p $dir/$merged_egs_dir
212+
num_egs=$(ls -1 $dir/egs/train.*.scp | wc -l)
213+
214+
$train_cmd --max-jobs-run $nj JOB=1:$num_egs $dir/$merged_egs_dir/log/merge_egs.JOB.log \
215+
nnet3-chain-shuffle-egs scp:$dir/egs/train.JOB.scp ark:- \| \
216+
nnet3-chain-merge-egs --minibatch-size=$minibatch_size ark:- \
217+
ark,scp:$dir/$merged_egs_dir/cegs.JOB.ark,$dir/$merged_egs_dir/cegs.JOB.scp || exit 1
218+
219+
rm $dir/raw_egs/cegs.*.ark
220+
fi
221+
222+
output_dim=$(cat $info_file | grep 'num_leaves' | awk '{print $NF}')
211223
train_dir=train${train_affix}
212-
if [[ $stage -le 17 ]]; then
224+
if [[ $stage -le 20 ]]; then
213225
echo "$0: training..."
214226

215227
mkdir -p $dir/$train_dir/tensorboard
@@ -259,11 +271,11 @@ if [[ $stage -le 17 ]]; then
259271
--output-dim $output_dim \
260272
--prefinal-bottleneck-dim $prefinal_bottleneck_dim \
261273
--subsampling-factor-list "$subsampling_factor_list" \
262-
--train.cegs-dir $dir/merged_egs \
274+
--train.cegs-dir $dir/$merged_egs_dir \
263275
--train.ddp.init-method $init_method \
264276
--train.ddp.multiple-machine $use_multiple_machine \
265277
--train.ddp.world-size $world_size \
266-
--train.den-fst $dir/den.fst \
278+
--train.den-fst $dir/den_fsts/${lang}.den.fst \
267279
--train.dropout-schedule "$dropout_schedule" \
268280
--train.egs-left-context $egs_left_context \
269281
--train.egs-right-context $egs_right_context \
@@ -272,11 +284,11 @@ if [[ $stage -le 17 ]]; then
272284
--train.lr $lr \
273285
--train.num-epochs $num_epochs \
274286
--train.use-ddp $use_ddp \
275-
--train.valid-cegs-scp $dir/egs/valid_diagnostic.scp \
287+
--train.valid-cegs-scp $dir/egs/train_subset.scp \
276288
--train.xent-regularize 0.1 || exit 1;
277289
fi
278290

279-
if [[ $stage -le 18 ]]; then
291+
if [[ $stage -le 21 ]]; then
280292
echo "inference: computing likelihood"
281293
for x in test dev; do
282294
mkdir -p $dir/$train_dir/inference/$x
@@ -288,8 +300,11 @@ if [[ $stage -le 18 ]]; then
288300
fi
289301
feat_scp="data/${x}_hires/feats.scp"
290302
if $online_cmvn; then
291-
apply-cmvn-online --spk2utt=ark:data/${x}_hires/spk2utt $dir/egs/global_cmvn.stats \
292-
scp:data/${x}_hires/feats.scp ark,scp:data/${x}_hires/data/online_cmvn_feats.ark,data/${x}_hires/online_cmvn_feats.scp
303+
if [[ "$feat_type" == "delta" ]]; then
304+
apply-cmvn-online --spk2utt=ark:data/${x}_hires/spk2utt $dir/raw_egs/global_cmvn.stats \
305+
scp:data/${x}_hires/feats.scp ark:- | add-deltas --print-args=false --delta-order=2 --delta-window=2 \
306+
ark:- ark,scp:data/${x}_hires/data/online_cmvn_feats.ark,data/${x}_hires/online_cmvn_feats.scp
307+
fi
293308
feat_scp="data/${x}_hires/online_cmvn_feats.scp"
294309
fi
295310
best_epoch=$(cat $dir/$train_dir/best-epoch-info | grep 'best epoch' | awk '{print $NF}')
@@ -318,16 +333,17 @@ if [[ $stage -le 18 ]]; then
318333
done
319334
fi
320335

321-
if [[ $stage -le 19 ]]; then
336+
if [[ $stage -le 22 ]]; then
322337
# Note: it might appear that this $lang directory is mismatched, and it is as
323338
# far as the 'topo' is concerned, but this script doesn't read the 'topo' from
324339
# the lang directory.
325340
cp $tree_dir/final.mdl $dir/final.mdl
341+
cp $tree_dir/tree $dir/tree
326342
utils/mkgraph.sh --self-loop-scale 1.0 data/lang_test $dir $dir/graph
327343
fi
328344

329345

330-
if [[ $stage -le 20 ]]; then
346+
if [[ $stage -le 23 ]]; then
331347
echo "decoding"
332348
for x in test dev; do
333349
if [[ ! -f $dir/$train_dir/inference/$x/nnet_output.scp ]]; then
@@ -340,13 +356,13 @@ if [[ $stage -le 20 ]]; then
340356
./local/decode.sh \
341357
--nj $nj \
342358
$dir/graph \
343-
$dir/0.trans_mdl \
359+
$dir/init/${lang}_trans.mdl \
344360
$dir/$train_dir/inference/$x/nnet_output.scp \
345361
$dir/$train_dir/decode_res/$x
346362
done
347363
fi
348364

349-
if [[ $stage -le 21 ]]; then
365+
if [[ $stage -le 24 ]]; then
350366
echo "scoring"
351367

352368
for x in test dev; do

egs/wsj/s5/steps/nnet3/chain2/get_raw_egs.sh

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,13 @@ online_ivector_dir= # can be used if we are including speaker information as iV
6363
cmvn_opts= # can be used for specifying CMVN options, if feature type is not lda (if lda,
6464
# it doesn't make sense to use different options than were used as input to the
6565
# LDA transform). This is used to turn off CMVN in the online-nnet experiments.
66-
66+
online_cmvn=false # Set to 'true' to replace 'apply-cmvn' by 'apply-cmvn-online' in the nnet3 input.
67+
# The configuration is passed externally via '$cmvn_opts' given to train.py,
68+
# typically as: --cmvn-opts="--config conf/online_cmvn.conf".
69+
# The global_cmvn.stats are computed by this script from the features.
70+
# Note: the online cmvn for ivector extractor it is controlled separately in
71+
# steps/online/nnet2/train_ivector_extractor.sh by --online-cmvn-iextractor
72+
feat_type= # If supplied, the raw egs will generated by feat with 'delta'
6773
lattice_lm_scale= # If supplied, the graph/lm weight of the lattices will be
6874
# used (with this scale) in generating supervisions
6975
# This is 0 by default for conventional supervised training,
@@ -226,14 +232,34 @@ else
226232
fi
227233

228234
feats="scp:$sdata/JOB/feats.scp"
229-
if [ ! -z $cmvn_opts ]; then
235+
# get the global_cmvn stats for online-cmvn,
236+
if $online_cmvn; then
237+
# create global_cmvn.stats,
238+
#
239+
# caution: the top-level nnet training script should copy
240+
# 'global_cmvn.stats' and 'online_cmvn' to its own dir.
241+
if ! matrix-sum --binary=false scp:$data/cmvn.scp - >$dir/global_cmvn.stats 2>/dev/null; then
242+
echo "$0: Error summing cmvn stats"
243+
exit 1
244+
fi
245+
touch $dir/online_cmvn
246+
feats="ark,s,cs:apply-cmvn-online $cmvn_opts --spk2utt=ark:$sdata/JOB/spk2utt $dir/global_cmvn.stats scp:$sdata/JOB/feats.scp ark:- |"
247+
else
248+
[ -f $dir/online_cmvn ] && rm $dir/online_cmvn
249+
if [ ! -z $cmvn_opts ]; then
230250
if [ ! -f $data/cmvn.scp ]; then
231251
echo "Cannot find $data/cmvn.scp. But cmvn_opts=$cmvn_opts"
232252
exit 1
233253
fi
234254
if [ `echo $cmvn_opts | fgrep -c true` -eq 1 ]; then
235255
feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- |"
236256
fi
257+
fi
258+
fi
259+
260+
delta_opts="--delta-order=2 --delta-window=2"
261+
if [[ $feat_type == "delta" ]]; then
262+
feats="$feats add-deltas $delta_opts ark:- ark:- |"
237263
fi
238264

239265
if [ $stage -le 0 ]; then
@@ -284,6 +310,7 @@ EOF
284310

285311
if [ ! -z "$online_ivector_dir" ]; then
286312
ivector_dim=$(feat-to-dim scp:$online_ivector_dir/ivector_online.scp -) || exit 1;
313+
mkdir -p $dir/info/
287314
echo $ivector_dim > $dir/info/ivector_dim
288315
echo ivector_dim $ivector_dim >> $dir/info.txt
289316
echo final.ie.id `cat $online_ivector_dir/final.ie.id` >> $dir/info.txt

0 commit comments

Comments
 (0)