From d39e3cdd34bbbf861a1a0fba43cd32db9f6b90a2 Mon Sep 17 00:00:00 2001 From: fanlu Date: Mon, 2 Mar 2020 21:48:18 +0800 Subject: [PATCH 1/4] support ivector training in pytorch model --- egs/aishell/s10/chain/egs_dataset.py | 14 ++++++- egs/aishell/s10/chain/feat_dataset.py | 46 +++++++++++++++++----- egs/aishell/s10/chain/inference.py | 32 ++++++++++++--- egs/aishell/s10/chain/model.py | 10 +++-- egs/aishell/s10/chain/options.py | 26 ++++++++++++- egs/aishell/s10/chain/train.py | 1 + egs/aishell/s10/conf/online_cmvn.conf | 1 + egs/aishell/s10/local/run_chain.sh | 56 ++++++++++++++++++--------- egs/aishell/s10/run.sh | 1 + 9 files changed, 148 insertions(+), 39 deletions(-) create mode 100644 egs/aishell/s10/conf/online_cmvn.conf diff --git a/egs/aishell/s10/chain/egs_dataset.py b/egs/aishell/s10/chain/egs_dataset.py index 22bfd1333d6..97eae03b992 100755 --- a/egs/aishell/s10/chain/egs_dataset.py +++ b/egs/aishell/s10/chain/egs_dataset.py @@ -151,13 +151,17 @@ def __call__(self, batch): self.egs_left_context + self.egs_right_context # TODO(fangjun): support ivector - assert len(eg.inputs) == 1 assert eg.inputs[0].name == 'input' _feats = kaldi.FloatMatrix() eg.inputs[0].features.GetMatrix(_feats) feats = _feats.numpy() + if len(eg.inputs) > 1: + _ivectors = kaldi.FloatMatrix() + eg.inputs[1].features.GetMatrix(_ivectors) + ivectors = _ivectors.numpy() + assert feats.shape[0] == batch_size * frames_per_sequence feat_list = [] @@ -173,6 +177,9 @@ def __call__(self, batch): end_index -= 1 # remove the rightmost frame added for frame shift feat = feats[start_index:end_index:, :] feat = splice_feats(feat) + if len(eg.inputs) > 1: + repeat_ivector = torch.from_numpy(ivectors[i]).repeat(feat.shape[0], 1) + feat = torch.cat((torch.from_numpy(feat), repeat_ivector), dim=1).numpy() feat_list.append(feat) batched_feat = np.stack(feat_list, axis=0) @@ -182,7 +189,10 @@ def __call__(self, batch): # the first -2 is from extra left/right context # the second -2 is from lda feats splicing assert batched_feat.shape[1] == frames_per_sequence - 4 - assert batched_feat.shape[2] == feats.shape[-1] * 3 + if len(eg.inputs) > 1: + assert batched_feat.shape[2] == feats.shape[-1] * 3 + ivectors.shape[-1] + else: + assert batched_feat.shape[2] == feats.shape[-1] * 3 torch_feat = torch.from_numpy(batched_feat).float() feature_list.append(torch_feat) diff --git a/egs/aishell/s10/chain/feat_dataset.py b/egs/aishell/s10/chain/feat_dataset.py index 60972b44cd2..16564e59add 100755 --- a/egs/aishell/s10/chain/feat_dataset.py +++ b/egs/aishell/s10/chain/feat_dataset.py @@ -22,9 +22,10 @@ def get_feat_dataloader(feats_scp, model_left_context, model_right_context, + ivector_scp=None, batch_size=16, num_workers=10): - dataset = FeatDataset(feats_scp=feats_scp) + dataset = FeatDataset(feats_scp=feats_scp, ivector_scp=ivector_scp) collate_fn = FeatDatasetCollateFunc(model_left_context=model_left_context, model_right_context=model_right_context, @@ -55,21 +56,36 @@ def _add_model_left_right_context(x, left_context, right_context): class FeatDataset(Dataset): - def __init__(self, feats_scp): + def __init__(self, feats_scp, ivector_scp=None): assert os.path.isfile(feats_scp) self.feats_scp = feats_scp - # items is a list of [key, rxfilename] - items = list() + # items is a dict of {key: [key, rxfilename, ivec]} + items = dict() with open(feats_scp, 'r') as f: for line in f: split = line.split() assert len(split) == 2 - items.append(split) - - self.items = items + uttid, rxfilename =split + assert uttid not in items + items[uttid] = 
[uttid, rxfilename, None] + if ivector_scp: + self.ivector_scp = ivector_scp + expected_count = len(items) + n = 0 + with open(ivector_scp, 'r') as f: + for line in f: + uttid_rxfilename = line.split() + assert len(uttid_rxfilename) == 2 + uttid, rxfilename = uttid_rxfilename + assert uttid in items + items[uttid][-1] = rxfilename + n += 1 + assert n == expected_count + + self.items = list(items.values()) self.num_items = len(self.items) @@ -81,6 +97,8 @@ def __getitem__(self, i): def __str__(self): s = 'feats scp: {}\n'.format(self.feats_scp) + if self.ivector_scp: + s += 'ivector_scp scp: {}\n'.format(self.ivector_scp) s += 'num utt: {}\n'.format(self.num_items) return s @@ -105,11 +123,15 @@ def __call__(self, batch): ''' key_list = [] feat_list = [] + ivector_list = [] + ivector_len_list = [] output_len_list = [] for b in batch: - key, rxfilename = b + key, rxfilename, ivector_rxfilename = b key_list.append(key) feat = kaldi.read_mat(rxfilename).numpy() + if ivector_rxfilename: + ivector = kaldi.read_mat(ivector_rxfilename).numpy() # L // 10 * C feat_len = feat.shape[0] output_len = (feat_len + self.frame_subsampling_factor - 1) // self.frame_subsampling_factor @@ -120,10 +142,16 @@ def __call__(self, batch): feat = splice_feats(feat) feat_list.append(feat) # no need to sort the feat by length + ivector_list.append(ivector) + ivector_len_list.append(ivector.shape[0]) # the user should sort utterances by length offline # to avoid unnecessary padding padded_feat = pad_sequence( [torch.from_numpy(feat).float() for feat in feat_list], batch_first=True) - return key_list, padded_feat, output_len_list + if ivector_list: + padded_ivector = pad_sequence( + [torch.from_numpy(ivector).float() for ivector in ivector_list], + batch_first=True) + return key_list, padded_feat, output_len_list, padded_ivector, ivector_len_list diff --git a/egs/aishell/s10/chain/inference.py b/egs/aishell/s10/chain/inference.py index 6dca476b13a..b6fa3829438 100644 --- a/egs/aishell/s10/chain/inference.py +++ b/egs/aishell/s10/chain/inference.py @@ -38,6 +38,7 @@ def main(): model = get_chain_model( feat_dim=args.feat_dim, output_dim=args.output_dim, + ivector_dim=args.ivector_dim, lda_mat_filename=args.lda_mat_filename, hidden_dim=args.hidden_dim, bottleneck_dim=args.bottleneck_dim, @@ -64,15 +65,36 @@ def main(): dataloader = get_feat_dataloader( feats_scp=args.feats_scp, + ivector_scp=args.ivector_scp, model_left_context=args.model_left_context, model_right_context=args.model_right_context, - batch_size=32) - + batch_size=1, + num_workers=0) + subsampling_factor = 3 + subsampled_frames_per_chunk = args.frames_per_chunk // subsampling_factor for batch_idx, batch in enumerate(dataloader): - key_list, padded_feat, output_len_list = batch + key_list, padded_feat, output_len_list, padded_ivector, ivector_len_list = batch padded_feat = padded_feat.to(device) + if ivector_len_list: + padded_ivector = padded_ivector.to(device) with torch.no_grad(): - nnet_output, _ = model(padded_feat) + nnet_outputs = [] + input_num_frames = padded_feat.shape[1] + 2 \ + - args.model_left_context - args.model_right_context + for i in range(0, output_len_list[0], subsampled_frames_per_chunk): + # 418 -> [0, 17, 34, 51, 68, 85, 102, 119, 136] + first_output = i * subsampling_factor + last_output = min(input_num_frames, \ + first_output + (subsampled_frames_per_chunk-1) * subsampling_factor) + first_input = first_output + last_input = last_output + args.model_left_context + args.model_right_context + input_x = padded_feat[:, 
first_input:last_input+1, :] + ivector_index = (first_output + last_output) // 2 // args.ivector_period + input_ivector = padded_ivector[:, ivector_index, :] + feat = torch.cat((input_x, input_ivector.repeat((1, input_x.shape[1], 1))), dim=-1) + nnet_output_temp, _ = model(feat) + nnet_outputs.append(nnet_output_temp) + nnet_output = torch.cat(nnet_outputs, dim=1) num = len(key_list) for i in range(num): @@ -85,7 +107,7 @@ def main(): m = Matrix(m) writer.Write(key, m) - if batch_idx % 10 == 0: + if batch_idx % 100 == 0: logging.info('Processed batch {}/{} ({:.6f}%)'.format( batch_idx, len(dataloader), float(batch_idx) / len(dataloader) * 100)) diff --git a/egs/aishell/s10/chain/model.py b/egs/aishell/s10/chain/model.py index 6158dff22ec..62f43e30a82 100644 --- a/egs/aishell/s10/chain/model.py +++ b/egs/aishell/s10/chain/model.py @@ -17,6 +17,7 @@ def get_chain_model(feat_dim, output_dim, + ivector_dim, hidden_dim, bottleneck_dim, prefinal_bottleneck_dim, @@ -25,6 +26,7 @@ def get_chain_model(feat_dim, lda_mat_filename=None): model = ChainModel(feat_dim=feat_dim, output_dim=output_dim, + ivector_dim=ivector_dim, lda_mat_filename=lda_mat_filename, hidden_dim=hidden_dim, bottleneck_dim=bottleneck_dim, @@ -82,6 +84,7 @@ class ChainModel(nn.Module): def __init__(self, feat_dim, output_dim, + ivector_dim=0, lda_mat_filename=None, hidden_dim=1024, bottleneck_dim=128, @@ -97,8 +100,9 @@ def __init__(self, assert len(kernel_size_list) == len(subsampling_factor_list) num_layers = len(kernel_size_list) + input_dim = feat_dim * 3 + ivector_dim # tdnn1_affine requires [N, T, C] - self.tdnn1_affine = nn.Linear(in_features=feat_dim * 3, + self.tdnn1_affine = nn.Linear(in_features=input_dim, out_features=hidden_dim) # tdnn1_batchnorm requires [N, C, T] @@ -142,11 +146,11 @@ def __init__(self, if lda_mat_filename: logging.info('Use LDA from {}'.format(lda_mat_filename)) self.lda_A, self.lda_b = load_lda_mat(lda_mat_filename) - assert feat_dim * 3 == self.lda_A.shape[0] + assert input_dim == self.lda_A.shape[0] self.has_LDA = True else: logging.info('replace LDA with BatchNorm') - self.input_batch_norm = nn.BatchNorm1d(num_features=feat_dim * 3, + self.input_batch_norm = nn.BatchNorm1d(num_features=input_dim, affine=False) self.has_LDA = False diff --git a/egs/aishell/s10/chain/options.py b/egs/aishell/s10/chain/options.py index 0c73274f83a..c92825f3067 100644 --- a/egs/aishell/s10/chain/options.py +++ b/egs/aishell/s10/chain/options.py @@ -27,6 +27,23 @@ def _set_inference_args(parser): dest='feats_scp', help='feats.scp filename, required for inference', type=str) + + parser.add_argument('--frames-per-chunk', + dest='frames_per_chunk', + help='frames per chunk', + type=int, + default=51) + + parser.add_argument('--ivector-scp', + dest='ivector_scp', + help='ivector.scp filename, required for ivector inference', + type=str) + + parser.add_argument('--ivector-period', + dest='ivector_period', + help='ivector period', + type=int, + default=10) parser.add_argument('--model-left-context', dest='model_left_context', @@ -228,10 +245,17 @@ def get_args(): parser.add_argument('--feat-dim', dest='feat_dim', - help='nn input dimension', + help='nn input 0 dimension', required=True, type=int) + parser.add_argument('--ivector-dim', + dest='ivector_dim', + help='nn input 1 dimension', + required=False, + default=0, + type=int) + parser.add_argument('--output-dim', dest='output_dim', help='nn output dimension', diff --git a/egs/aishell/s10/chain/train.py b/egs/aishell/s10/chain/train.py index b149bae5b32..8175b362df5 
100644 --- a/egs/aishell/s10/chain/train.py +++ b/egs/aishell/s10/chain/train.py @@ -270,6 +270,7 @@ def process_job(learning_rate, local_rank=None): model = get_chain_model( feat_dim=args.feat_dim, output_dim=args.output_dim, + ivector_dim=args.ivector_dim, lda_mat_filename=args.lda_mat_filename, hidden_dim=args.hidden_dim, bottleneck_dim=args.bottleneck_dim, diff --git a/egs/aishell/s10/conf/online_cmvn.conf b/egs/aishell/s10/conf/online_cmvn.conf new file mode 100644 index 00000000000..591367e7ae9 --- /dev/null +++ b/egs/aishell/s10/conf/online_cmvn.conf @@ -0,0 +1 @@ +# configuration file for apply-cmvn-online, used when invoking online2-wav-nnet3-latgen-faster. diff --git a/egs/aishell/s10/local/run_chain.sh b/egs/aishell/s10/local/run_chain.sh index 3bc3380556b..e78fd11ca2b 100755 --- a/egs/aishell/s10/local/run_chain.sh +++ b/egs/aishell/s10/local/run_chain.sh @@ -14,13 +14,13 @@ gmm=tri3 nnet3_affix=_pybind tree_affix= tdnn_affix= - +online_cmvn=true +train_ivector=false . ./path.sh . ./cmd.sh . parse_options.sh - if [[ $stage -le 0 ]]; then echo "$0: preparing directory for low-resolution speed-perturbed data (for alignment)" utils/data/perturb_data_dir_speed_3way.sh data/$train_set data/${train_set}_sp @@ -64,20 +64,31 @@ if [[ $stage -le 3 ]]; then done fi +train_ivector_dir= +if [[ $train_ivector ]]; then + local/run_ivector_common.sh --stage $stage \ + --nj $nj \ + --train-set $train_set \ + --online-cmvn-iextractor $online_cmvn \ + --nnet3-affix "$nnet3_affix" + train_ivector_dir=exp/nnet3${nnet3_affix}/ivectors_${train_set}_sp_hires +fi +echo $train_ivector_dir + tree_dir=exp/chain${nnet3_affix}/tree_bi${tree_affix} lat_dir=exp/chain${nnet3_affix}/${gmm}_${train_set}_sp_lats dir=exp/chain${nnet3_affix}/tdnn${tdnn_affix}_sp train_data_dir=data/${train_set}_sp_hires lores_train_data_dir=data/${train_set}_sp -if [[ $stage -le 4 ]]; then +if [[ $stage -le 9 ]]; then for f in $gmm_dir/final.mdl $train_data_dir/feats.scp \ $lores_train_data_dir/feats.scp $ali_dir/ali.1.gz $gmm_dir/final.mdl; do [ ! -f $f ] && echo "$0: expected file $f to exist" && exit 1 done fi -if [[ $stage -le 5 ]]; then +if [[ $stage -le 10 ]]; then echo "$0: creating lang directory with one state per phone." # Create a version of the lang/ directory that has one state per phone in the # topo file. [note, it really has two states.. the first one is only repeated @@ -90,7 +101,7 @@ if [[ $stage -le 5 ]]; then steps/nnet3/chain/gen_topo.py $nonsilphonelist $silphonelist >data/lang_chain/topo fi -if [[ $stage -le 6 ]]; then +if [[ $stage -le 11 ]]; then # Get the alignments as lattices (gives the chain training more freedom). # use the same num-jobs as the alignments steps/align_fmllr_lats.sh --nj $nj --cmd "$train_cmd" ${lores_train_data_dir} \ @@ -98,7 +109,7 @@ if [[ $stage -le 6 ]]; then rm $lat_dir/fsts.*.gz # save space fi -if [[ $stage -le 7 ]]; then +if [[ $stage -le 12 ]]; then # Build a tree using our new topology. We know we have alignments for the # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use # those. 
@@ -107,7 +118,7 @@ if [[ $stage -le 7 ]]; then --cmd "$train_cmd" 4000 ${lores_train_data_dir} data/lang_chain $ali_dir $tree_dir fi -if [[ $stage -le 8 ]]; then +if [[ $stage -le 13 ]]; then echo "$0: creating phone language-model" $mkgraph_cmd $dir/log/make_phone_lm.log \ chain-est-phone-lm \ @@ -115,7 +126,7 @@ if [[ $stage -le 8 ]]; then $dir/phone_lm.fst || exit 1 fi -if [[ $stage -le 9 ]]; then +if [[ $stage -le 14 ]]; then echo "creating denominator FST" copy-transition-model $tree_dir/final.mdl $dir/0.trans_mdl cp $tree_dir/tree $dir @@ -145,12 +156,13 @@ log_level=info # valid values: debug, info, warning # false to save it as kaldi::Matrix save_nn_output_as_compressed=false -if [[ $stage -le 10 ]]; then +if [[ $stage -le 15 ]]; then echo "$0: generating egs" steps/nnet3/chain/get_egs.sh \ --alignment-subsampling-factor 3 \ --cmd "$train_cmd" \ - --online-cmvn true \ + --online-cmvn $online_cmvn \ + --online-ivector-dir $train_ivector_dir \ --frame-subsampling-factor 3 \ --frames-overlap-per-eg 0 \ --frames-per-eg $frames_per_eg \ @@ -169,7 +181,7 @@ if [[ $stage -le 10 ]]; then fi -if [[ $stage -le 11 ]]; then +if [[ $stage -le 16 ]]; then echo "$0: merging egs" mkdir -p $dir/merged_egs num_egs=$(ls -1 $dir/egs/cegs*.ark | wc -l) @@ -183,9 +195,11 @@ if [[ $stage -le 11 ]]; then fi feat_dim=$(cat $dir/egs/info/feat_dim) +ivector_dim=$(cat $dir/egs/info/ivector_dim) +ivector_period=$(cat $train_ivector_dir/ivector_period) output_dim=$(cat $dir/egs/info/num_pdfs) -if [[ $stage -le 12 ]]; then +if [[ $stage -le 17 ]]; then echo "$0: training..." mkdir -p $dir/train/tensorboard @@ -210,10 +224,10 @@ if [[ $stage -le 12 ]]; then use_ddp=true world_size=4 - use_multiple_machine=false + use_multiple_machine=true if $use_multiple_machine ; then # suppose you are using Sun GridEngine - cuda_train_cmd=$(echo "$cuda_train_cmd --gpu $world_size JOB=1:$world_size $dir/train/logs/job.JOB.log") + cuda_train_cmd=$(echo "$cuda_train_cmd --gpu 1 JOB=1:$world_size $dir/train/logs/job.JOB.log") else # TODO: for running with multiple GPUs on a single machine (SGE), # we should tell SGE how many GPUs we will use on that machine @@ -227,6 +241,7 @@ if [[ $stage -le 12 ]]; then --feat-dim $feat_dim \ --hidden-dim $hidden_dim \ --is-training true \ + --ivector-dim $ivector_dim \ --kernel-size-list "$kernel_size_list" \ --log-level $log_level \ --output-dim $output_dim \ @@ -248,7 +263,7 @@ if [[ $stage -le 12 ]]; then --train.xent-regularize 0.1 || exit 1; fi -if [[ $stage -le 13 ]]; then +if [[ $stage -le 18 ]]; then echo "inference: computing likelihood" for x in test dev; do mkdir -p $dir/inference/$x @@ -267,6 +282,9 @@ if [[ $stage -le 13 ]]; then --feats-scp data/${x}_hires/online_cmvn_feats.scp \ --hidden-dim $hidden_dim \ --is-training false \ + --ivector-dim $ivector_dim \ + --ivector-period $ivector_period \ + --ivector-scp exp/nnet3${nnet3_affix}/ivectors_${x}_hires/ivector_online.scp \ --log-level $log_level \ --kernel-size-list "$kernel_size_list" \ --prefinal-bottleneck-dim $prefinal_bottleneck_dim \ @@ -279,7 +297,7 @@ if [[ $stage -le 13 ]]; then done fi -if [[ $stage -le 14 ]]; then +if [[ $stage -le 19 ]]; then # Note: it might appear that this $lang directory is mismatched, and it is as # far as the 'topo' is concerned, but this script doesn't read the 'topo' from # the lang directory. @@ -288,7 +306,7 @@ if [[ $stage -le 14 ]]; then fi -if [[ $stage -le 15 ]]; then +if [[ $stage -le 20 ]]; then echo "decoding" for x in test dev; do if [[ ! 
-f $dir/inference/$x/nnet_output.scp ]]; then @@ -298,7 +316,7 @@ if [[ $stage -le 15 ]]; then fi echo "decoding $x" - ./local/decode.sh \ + ./local/decode.sh --cmd "$decode_cmd" \ --nj $nj \ $dir/graph \ $dir/0.trans_mdl \ @@ -307,7 +325,7 @@ if [[ $stage -le 15 ]]; then done fi -if [[ $stage -le 16 ]]; then +if [[ $stage -le 21 ]]; then echo "scoring" for x in test dev; do diff --git a/egs/aishell/s10/run.sh b/egs/aishell/s10/run.sh index 5bce308f968..fb6d6e12157 100755 --- a/egs/aishell/s10/run.sh +++ b/egs/aishell/s10/run.sh @@ -26,6 +26,7 @@ data_url=www.openslr.org/resources/33 nj=30 stage=13 +. utils/parse_options.sh || exit 1; if [[ $stage -le 0 ]]; then local/download_and_untar.sh $data $data_url data_aishell || exit 1 From c1216dcb27f48ec2244bcc54e4c3825d4d6a90f9 Mon Sep 17 00:00:00 2001 From: fanlu Date: Tue, 3 Mar 2020 00:05:20 +0800 Subject: [PATCH 2/4] add ivector training script --- egs/aishell/s10/local/run_ivector_common.sh | 102 ++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100755 egs/aishell/s10/local/run_ivector_common.sh diff --git a/egs/aishell/s10/local/run_ivector_common.sh b/egs/aishell/s10/local/run_ivector_common.sh new file mode 100755 index 00000000000..01692777b76 --- /dev/null +++ b/egs/aishell/s10/local/run_ivector_common.sh @@ -0,0 +1,102 @@ +#!/bin/bash + +set -e -o pipefail + + +# This script is called from local/nnet3/run_tdnn.sh and local/chain/run_tdnn.sh (and may eventually +# be called by more scripts). It contains the common feature preparation and iVector-related parts +# of the script. See those scripts for examples of usage. + + +stage=0 +nj=30 + +train_set=train # you might set this to e.g. train. +gmm=tri3 # This specifies a GMM-dir from the features of the type you're training the system on; + # it should contain alignments for 'train_set'. +online_cmvn_iextractor=false + +num_threads_ubm=8 +nnet3_affix=_cleaned # affix for exp/nnet3 directory to put iVector stuff in + +. ./cmd.sh +. ./path.sh +. utils/parse_options.sh + + +if [ $stage -le 5 ]; then + echo "$0: computing a subset of data to train the diagonal UBM." 
+ + mkdir -p exp/nnet3${nnet3_affix}/diag_ubm + temp_data_root=exp/nnet3${nnet3_affix}/diag_ubm + + # train a diagonal UBM using a subset of about a quarter of the data + num_utts_total=$(wc -l Date: Tue, 3 Mar 2020 09:57:12 +0800 Subject: [PATCH 3/4] support batch chunk --- egs/aishell/s10/chain/feat_dataset.py | 52 +++++++++++++++++++-------- egs/aishell/s10/chain/inference.py | 35 ++++++------------ egs/aishell/s10/local/run_chain.sh | 13 ++++--- 3 files changed, 57 insertions(+), 43 deletions(-) diff --git a/egs/aishell/s10/chain/feat_dataset.py b/egs/aishell/s10/chain/feat_dataset.py index 16564e59add..73a37361f0b 100755 --- a/egs/aishell/s10/chain/feat_dataset.py +++ b/egs/aishell/s10/chain/feat_dataset.py @@ -4,7 +4,7 @@ # Apache 2.0 import os - +import math import numpy as np import torch @@ -22,14 +22,18 @@ def get_feat_dataloader(feats_scp, model_left_context, model_right_context, + frames_per_chunk=51, ivector_scp=None, + ivector_period=10, batch_size=16, num_workers=10): dataset = FeatDataset(feats_scp=feats_scp, ivector_scp=ivector_scp) collate_fn = FeatDatasetCollateFunc(model_left_context=model_left_context, model_right_context=model_right_context, - frame_subsampling_factor=3) + frame_subsampling_factor=3, + frames_per_chunk=frames_per_chunk, + ivector_period=ivector_period) dataloader = DataLoader(dataset, batch_size=batch_size, @@ -108,7 +112,9 @@ class FeatDatasetCollateFunc: def __init__(self, model_left_context, model_right_context, - frame_subsampling_factor=3): + frame_subsampling_factor=3, + frames_per_chunk=51, + ivector_period=10): ''' We need `frame_subsampling_factor` because we want to know the number of output frames of different waves in the same batch @@ -116,6 +122,8 @@ def __init__(self, self.model_left_context = model_left_context self.model_right_context = model_right_context self.frame_subsampling_factor = frame_subsampling_factor + self.frames_per_chunk = frames_per_chunk + self.ivector_period = ivector_period def __call__(self, batch): ''' @@ -126,6 +134,8 @@ def __call__(self, batch): ivector_list = [] ivector_len_list = [] output_len_list = [] + subsampled_frames_per_chunk = (self.frames_per_chunk // + self.frame_subsampling_factor) for b in batch: key, rxfilename, ivector_rxfilename = b key_list.append(key) @@ -140,18 +150,32 @@ def __call__(self, batch): feat = _add_model_left_right_context(feat, self.model_left_context, self.model_right_context) feat = splice_feats(feat) - feat_list.append(feat) - # no need to sort the feat by length - ivector_list.append(ivector) - ivector_len_list.append(ivector.shape[0]) - # the user should sort utterances by length offline - # to avoid unnecessary padding + # now we split feat to chunk, then we can do decode by chunk + input_num_frames = (feat.shape[0] + 2 + - self.model_left_context - self.model_right_context) + for i in range(0, output_len, subsampled_frames_per_chunk): + # input len:418 -> output len:140 -> output chunk:[0, 17, 34, 51, 68, 85, 102, 119, 136] + first_output = i * self.frame_subsampling_factor + last_output = min(input_num_frames, \ + first_output + (subsampled_frames_per_chunk-1) * self.frame_subsampling_factor) + first_input = first_output + last_input = last_output + self.model_left_context + self.model_right_context + input_x = feat[first_input:last_input+1, :] + if ivector_rxfilename: + ivector_index = (first_output + last_output) // 2 // self.ivector_period + input_ivector = ivector[ivector_index, :].reshape(1,-1) + feat_list.append(np.concatenate((input_x, + np.repeat(input_ivector, 
input_x.shape[0], axis=0)), + axis=-1)) + else: + feat_list.append(input_x) + padded_feat = pad_sequence( [torch.from_numpy(feat).float() for feat in feat_list], batch_first=True) - if ivector_list: - padded_ivector = pad_sequence( - [torch.from_numpy(ivector).float() for ivector in ivector_list], - batch_first=True) - return key_list, padded_feat, output_len_list, padded_ivector, ivector_len_list + + assert sum([math.ceil(l / subsampled_frames_per_chunk) for l in output_len_list]) \ + == padded_feat.shape[0] + + return key_list, padded_feat, output_len_list diff --git a/egs/aishell/s10/chain/inference.py b/egs/aishell/s10/chain/inference.py index b6fa3829438..aa598b90dc4 100644 --- a/egs/aishell/s10/chain/inference.py +++ b/egs/aishell/s10/chain/inference.py @@ -6,6 +6,7 @@ import logging import os import sys +import math import torch from torch.utils.dlpack import to_dlpack @@ -68,46 +69,32 @@ def main(): ivector_scp=args.ivector_scp, model_left_context=args.model_left_context, model_right_context=args.model_right_context, - batch_size=1, - num_workers=0) + batch_size=32, + num_workers=10) subsampling_factor = 3 subsampled_frames_per_chunk = args.frames_per_chunk // subsampling_factor for batch_idx, batch in enumerate(dataloader): - key_list, padded_feat, output_len_list, padded_ivector, ivector_len_list = batch + key_list, padded_feat, output_len_list = batch padded_feat = padded_feat.to(device) - if ivector_len_list: - padded_ivector = padded_ivector.to(device) with torch.no_grad(): - nnet_outputs = [] - input_num_frames = padded_feat.shape[1] + 2 \ - - args.model_left_context - args.model_right_context - for i in range(0, output_len_list[0], subsampled_frames_per_chunk): - # 418 -> [0, 17, 34, 51, 68, 85, 102, 119, 136] - first_output = i * subsampling_factor - last_output = min(input_num_frames, \ - first_output + (subsampled_frames_per_chunk-1) * subsampling_factor) - first_input = first_output - last_input = last_output + args.model_left_context + args.model_right_context - input_x = padded_feat[:, first_input:last_input+1, :] - ivector_index = (first_output + last_output) // 2 // args.ivector_period - input_ivector = padded_ivector[:, ivector_index, :] - feat = torch.cat((input_x, input_ivector.repeat((1, input_x.shape[1], 1))), dim=-1) - nnet_output_temp, _ = model(feat) - nnet_outputs.append(nnet_output_temp) - nnet_output = torch.cat(nnet_outputs, dim=1) + nnet_output, _ = model(padded_feat) num = len(key_list) + first = 0 for i in range(num): key = key_list[i] output_len = output_len_list[i] - value = nnet_output[i, :output_len, :] + target_len = math.ceil(output_len / subsampled_frames_per_chunk) + result = nnet_output[first:first + target_len, :, :].split(1, 0) + value = torch.cat(result, dim=1)[0, :output_len, :] value = value.cpu() + first += target_len m = kaldi.SubMatrixFromDLPack(to_dlpack(value)) m = Matrix(m) writer.Write(key, m) - if batch_idx % 100 == 0: + if batch_idx % 10 == 0: logging.info('Processed batch {}/{} ({:.6f}%)'.format( batch_idx, len(dataloader), float(batch_idx) / len(dataloader) * 100)) diff --git a/egs/aishell/s10/local/run_chain.sh b/egs/aishell/s10/local/run_chain.sh index e78fd11ca2b..1ed0257fcd7 100755 --- a/egs/aishell/s10/local/run_chain.sh +++ b/egs/aishell/s10/local/run_chain.sh @@ -16,6 +16,8 @@ tree_affix= tdnn_affix= online_cmvn=true train_ivector=false +num_epochs=6 + . ./path.sh . 
./cmd.sh @@ -204,8 +206,8 @@ if [[ $stage -le 17 ]]; then mkdir -p $dir/train/tensorboard train_checkpoint= - if [[ -f $dir/train/best_model.pt ]]; then - train_checkpoint=$dir/train/best_model.pt + if [[ -f $dir/best_model.pt ]]; then + train_checkpoint=$dir/best_model.pt fi INIT_FILE=$dir/ddp_init @@ -215,7 +217,7 @@ if [[ $stage -le 17 ]]; then # init_method=tcp://127.0.0.1:7275 echo "$0: init method is $init_method" - num_epochs=6 + num_epochs=$num_epochs lr=1e-3 # use_ddp = false: training model with one GPU @@ -274,7 +276,8 @@ if [[ $stage -le 18 ]]; then scp:data/${x}_hires/feats.scp ark,scp:data/${x}_hires/data/online_cmvn_feats.ark,data/${x}_hires/online_cmvn_feats.scp best_epoch=$(cat $dir/best-epoch-info | grep 'best epoch' | awk '{print $NF}') inference_checkpoint=$dir/epoch-${best_epoch}.pt - $cuda_inference_cmd $dir/inference/logs/${x}.log python3 ./chain/inference.py \ + $cuda_inference_cmd $dir/inference/logs/${x}.log \ + python3 ./chain/inference.py \ --bottleneck-dim $bottleneck_dim \ --checkpoint $inference_checkpoint \ --dir $dir/inference/$x \ @@ -316,7 +319,7 @@ if [[ $stage -le 20 ]]; then fi echo "decoding $x" - ./local/decode.sh --cmd "$decode_cmd" \ + ./local/decode.sh \ --nj $nj \ $dir/graph \ $dir/0.trans_mdl \ From 6331827eedf51177f4f8e44a00178caf19d6105d Mon Sep 17 00:00:00 2001 From: fanlu Date: Tue, 3 Mar 2020 10:37:56 +0800 Subject: [PATCH 4/4] format and some comment fix --- egs/aishell/s10/chain/egs_dataset.py | 37 +++++++++++---------- egs/aishell/s10/chain/feat_dataset.py | 36 +++++++++++--------- egs/aishell/s10/local/run_ivector_common.sh | 2 +- 3 files changed, 42 insertions(+), 33 deletions(-) diff --git a/egs/aishell/s10/chain/egs_dataset.py b/egs/aishell/s10/chain/egs_dataset.py index 97eae03b992..fc31f9696b3 100755 --- a/egs/aishell/s10/chain/egs_dataset.py +++ b/egs/aishell/s10/chain/egs_dataset.py @@ -41,15 +41,16 @@ def get_egs_dataloader(egs_dir_or_scp, sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=world_size, rank=local_rank, shuffle=True) dataloader = DataLoader(dataset, - batch_size=batch_size, - collate_fn=collate_fn, - sampler=sampler) + batch_size=batch_size, + collate_fn=collate_fn, + sampler=sampler) else: - base_sampler = torch.utils.data.RandomSampler(dataset) - sampler = torch.utils.data.BatchSampler(base_sampler, batch_size, False) - dataloader = DataLoader(dataset, - batch_sampler=sampler, - collate_fn=collate_fn) + base_sampler = torch.utils.data.RandomSampler(dataset) + sampler = torch.utils.data.BatchSampler( + base_sampler, batch_size, False) + dataloader = DataLoader(dataset, + batch_sampler=sampler, + collate_fn=collate_fn) return dataloader @@ -146,11 +147,10 @@ def __call__(self, batch): batch_size = supervision.num_sequences - frames_per_sequence = (supervision.frames_per_sequence * \ - self.frame_subsampling_factor) + \ - self.egs_left_context + self.egs_right_context + frames_per_sequence = (supervision.frames_per_sequence * + self.frame_subsampling_factor) + \ + self.egs_left_context + self.egs_right_context - # TODO(fangjun): support ivector assert eg.inputs[0].name == 'input' _feats = kaldi.FloatMatrix() @@ -178,8 +178,10 @@ def __call__(self, batch): feat = feats[start_index:end_index:, :] feat = splice_feats(feat) if len(eg.inputs) > 1: - repeat_ivector = torch.from_numpy(ivectors[i]).repeat(feat.shape[0], 1) - feat = torch.cat((torch.from_numpy(feat), repeat_ivector), dim=1).numpy() + repeat_ivector = torch.from_numpy( + ivectors[i]).repeat(feat.shape[0], 1) + feat = torch.cat( 
+ (torch.from_numpy(feat), repeat_ivector), dim=1).numpy() feat_list.append(feat) batched_feat = np.stack(feat_list, axis=0) @@ -190,7 +192,8 @@ def __call__(self, batch): # the second -2 is from lda feats splicing assert batched_feat.shape[1] == frames_per_sequence - 4 if len(eg.inputs) > 1: - assert batched_feat.shape[2] == feats.shape[-1] * 3 + ivectors.shape[-1] + assert batched_feat.shape[2] == feats.shape[-1] * \ + 3 + ivectors.shape[-1] else: assert batched_feat.shape[2] == feats.shape[-1] * 3 @@ -232,8 +235,8 @@ def _test_nnet_chain_example_dataset(): for b in dataloader: key_list, feature_list, supervision_list = b assert feature_list[0].shape == (128, 204, 129) \ - or feature_list[0].shape == (128, 144, 129) \ - or feature_list[0].shape == (128, 165, 129) + or feature_list[0].shape == (128, 144, 129) \ + or feature_list[0].shape == (128, 165, 129) assert supervision_list[0].weight == 1 supervision_list[0].num_sequences == 128 # minibach size is 128 diff --git a/egs/aishell/s10/chain/feat_dataset.py b/egs/aishell/s10/chain/feat_dataset.py index 73a37361f0b..72cda21af55 100755 --- a/egs/aishell/s10/chain/feat_dataset.py +++ b/egs/aishell/s10/chain/feat_dataset.py @@ -62,19 +62,23 @@ class FeatDataset(Dataset): def __init__(self, feats_scp, ivector_scp=None): assert os.path.isfile(feats_scp) + if ivector_scp: + assert os.path.isfile(ivector_scp) self.feats_scp = feats_scp - # items is a dict of {key: [key, rxfilename, ivec]} + # items is a dict of [uttid, feat_rxfilename, None] + # or [uttid, feat_rxfilename, ivector_rxfilename] if ivector_scp is not None items = dict() with open(feats_scp, 'r') as f: for line in f: split = line.split() assert len(split) == 2 - uttid, rxfilename =split + uttid, rxfilename = split assert uttid not in items items[uttid] = [uttid, rxfilename, None] + self.ivector_scp = None if ivector_scp: self.ivector_scp = ivector_scp expected_count = len(items) @@ -134,14 +138,15 @@ def __call__(self, batch): ivector_list = [] ivector_len_list = [] output_len_list = [] - subsampled_frames_per_chunk = (self.frames_per_chunk // - self.frame_subsampling_factor) + subsampled_frames_per_chunk = (self.frames_per_chunk // + self.frame_subsampling_factor) for b in batch: key, rxfilename, ivector_rxfilename = b key_list.append(key) feat = kaldi.read_mat(rxfilename).numpy() if ivector_rxfilename: - ivector = kaldi.read_mat(ivector_rxfilename).numpy() # L // 10 * C + ivector = kaldi.read_mat( + ivector_rxfilename).numpy() # L // 10 * C feat_len = feat.shape[0] output_len = (feat_len + self.frame_subsampling_factor - 1) // self.frame_subsampling_factor @@ -152,30 +157,31 @@ def __call__(self, batch): feat = splice_feats(feat) # now we split feat to chunk, then we can do decode by chunk - input_num_frames = (feat.shape[0] + 2 + input_num_frames = (feat.shape[0] + 2 - self.model_left_context - self.model_right_context) for i in range(0, output_len, subsampled_frames_per_chunk): # input len:418 -> output len:140 -> output chunk:[0, 17, 34, 51, 68, 85, 102, 119, 136] first_output = i * self.frame_subsampling_factor - last_output = min(input_num_frames, \ - first_output + (subsampled_frames_per_chunk-1) * self.frame_subsampling_factor) + last_output = min(input_num_frames, + first_output + (subsampled_frames_per_chunk-1) * self.frame_subsampling_factor) first_input = first_output last_input = last_output + self.model_left_context + self.model_right_context input_x = feat[first_input:last_input+1, :] if ivector_rxfilename: - ivector_index = (first_output + last_output) // 2 // 
self.ivector_period - input_ivector = ivector[ivector_index, :].reshape(1,-1) - feat_list.append(np.concatenate((input_x, - np.repeat(input_ivector, input_x.shape[0], axis=0)), - axis=-1)) + ivector_index = ( + first_output + last_output) // 2 // self.ivector_period + input_ivector = ivector[ivector_index, :].reshape(1, -1) + feat_list.append(np.concatenate((input_x, + np.repeat(input_ivector, input_x.shape[0], axis=0)), + axis=-1)) else: feat_list.append(input_x) padded_feat = pad_sequence( [torch.from_numpy(feat).float() for feat in feat_list], batch_first=True) - + assert sum([math.ceil(l / subsampled_frames_per_chunk) for l in output_len_list]) \ - == padded_feat.shape[0] + == padded_feat.shape[0] return key_list, padded_feat, output_len_list diff --git a/egs/aishell/s10/local/run_ivector_common.sh b/egs/aishell/s10/local/run_ivector_common.sh index 01692777b76..55cc6f63631 100755 --- a/egs/aishell/s10/local/run_ivector_common.sh +++ b/egs/aishell/s10/local/run_ivector_common.sh @@ -99,4 +99,4 @@ if [[ $stage -le 8 ]]; then done fi -exit 0; \ No newline at end of file +exit 0;
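
Note (follow-up to the series, not part of any patch above): the central change is that each decoding chunk gets one i-vector row appended to every frame of its spliced features before the TDNN input, which is why model.py now builds its first affine layer on feat_dim * 3 + ivector_dim. A minimal standalone sketch of that per-chunk logic, with illustrative names (append_chunk_ivector is not in the patch) and numpy only:

    import numpy as np

    def append_chunk_ivector(chunk_feat, ivectors, first_output, last_output,
                             ivector_period=10):
        # chunk_feat: (T, feat_dim * 3) spliced features for one decoding chunk
        # ivectors:   (num_rows, ivector_dim), one row per ivector_period input frames
        # first_output / last_output: first and last input-frame index this chunk covers
        # Pick the i-vector row nearest the chunk centre, as feat_dataset.py does.
        ivector_index = (first_output + last_output) // 2 // ivector_period
        ivec = ivectors[ivector_index].reshape(1, -1)
        # Repeat that single row for every frame and append it along the feature axis.
        return np.concatenate(
            (chunk_feat, np.repeat(ivec, chunk_feat.shape[0], axis=0)), axis=-1)

For example, with the recipe's 43-dim hires features (spliced x3 -> 129 columns) and an assumed 100-dim online i-vector, each chunk becomes (T, 229), matching the feat_dim * 3 + ivector_dim assertion added to model.py and the per-sequence concatenation in egs_dataset.py.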