diff --git a/egs/aishell/s10/chain/common.py b/egs/aishell/s10/chain/common.py index 584b83b2500..4c7ae7bc7e7 100644 --- a/egs/aishell/s10/chain/common.py +++ b/egs/aishell/s10/chain/common.py @@ -43,7 +43,19 @@ def load_checkpoint(filename, model): for k in keys: assert k in checkpoint - model.load_state_dict(checkpoint['state_dict']) + if not list(model.state_dict().keys())[0].startswith('module.') \ + and list(checkpoint['state_dict'])[0].startswith('module.'): + # the checkpoint was saved by DDP + logging.info('load checkpoint from DDP') + dst_state_dict = model.state_dict() + src_state_dict = checkpoint['state_dict'] + for key in dst_state_dict.keys(): + src_key = '{}.{}'.format('module', key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict) + else: + model.load_state_dict(checkpoint['state_dict']) epoch = checkpoint['epoch'] learning_rate = checkpoint['learning_rate'] @@ -52,7 +64,10 @@ def load_checkpoint(filename, model): return epoch, learning_rate, objf -def save_checkpoint(filename, model, epoch, learning_rate, objf): +def save_checkpoint(filename, model, epoch, learning_rate, objf, local_rank=0): + if local_rank != 0: + return + logging.info('Save checkpoint to {filename}: epoch={epoch}, ' 'learning_rate={learning_rate}, objf={objf}'.format( filename=filename, @@ -68,8 +83,17 @@ def save_checkpoint(filename, model, epoch, learning_rate, objf): torch.save(checkpoint, filename) -def save_training_info(filename, model_path, current_epoch, learning_rate, objf, - best_objf, best_epoch): +def save_training_info(filename, + model_path, + current_epoch, + learning_rate, + objf, + best_objf, + best_epoch, + local_rank=0): + if local_rank != 0: + return + with open(filename, 'w') as f: f.write('model_path: {}\n'.format(model_path)) f.write('epoch: {}\n'.format(current_epoch)) diff --git a/egs/aishell/s10/chain/ddp_train.py b/egs/aishell/s10/chain/ddp_train.py new file mode 100644 index 00000000000..76c9f0eb103 --- /dev/null +++ b/egs/aishell/s10/chain/ddp_train.py @@ -0,0 +1,392 @@ +#!/usr/bin/env python3 + +# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 + +import logging +import os +import sys +import warnings + +# disable warnings when loading tensorboard +warnings.simplefilter(action='ignore', category=FutureWarning) + +import numpy as np +import torch +import torch.distributed as dist +import torch.optim as optim +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_value_ +from torch.utils.tensorboard import SummaryWriter + +import kaldi +import kaldi_pybind.chain as chain +import kaldi_pybind.fst as fst + +from chain_loss import KaldiChainObjfFunction +from common import load_checkpoint +from common import save_checkpoint +from common import save_training_info +from common import setup_logger +from egs_dataset import get_egs_dataloader +from model import get_chain_model +from options import get_args + + +def get_validation_objf(dataloader, model, device, criterion, opts, den_graph): + total_objf = 0. + total_weight = 0. + total_frames = 0. # for display only + + model.eval() + + for batch_idx, batch in enumerate(dataloader): + key_list, feature_list, supervision_list = batch + + assert len(key_list) == len(feature_list) == len(supervision_list) + batch_size = len(key_list) + + for n in range(batch_size): + feats = feature_list[n] + assert feats.ndim == 3 + + # at this point, feats is [N, T, C] + feats = feats.to(device) + + with torch.no_grad(): + nnet_output, xent_output = model(feats) + + # at this point, nnet_output is: [N, T, C] + # refer to kaldi/src/chain/chain-training.h + # the output should be organized as + # [all sequences for frame 0] + # [all sequences for frame 1] + # [etc.] + nnet_output = nnet_output.permute(1, 0, 2) + # at this point, nnet_output is: [T, N, C] + nnet_output = nnet_output.contiguous().view(-1, + nnet_output.shape[-1]) + + # at this point, xent_output is: [N, T, C] + xent_output = xent_output.permute(1, 0, 2) + # at this point, xent_output is: [T, N, C] + xent_output = xent_output.contiguous().view(-1, + xent_output.shape[-1]) + objf_l2_term_weight = criterion(opts, den_graph, + supervision_list[n], nnet_output, + xent_output) + objf = objf_l2_term_weight[0] + + objf_l2_term_weight = objf_l2_term_weight.cpu() + + total_objf += objf_l2_term_weight[0].item() + total_weight += objf_l2_term_weight[2].item() + + num_frames = nnet_output.shape[0] + total_frames += num_frames + + return total_objf, total_weight, total_frames + + +def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, + criterion, current_epoch, opts, den_graph, tf_writer): + model.train() + + total_objf = 0. + total_weight = 0. + total_frames = 0. # for display only + + for batch_idx, batch in enumerate(dataloader): + key_list, feature_list, supervision_list = batch + assert len(key_list) == len(feature_list) == len(supervision_list) + batch_size = len(key_list) + for n in range(batch_size): + feats = feature_list[n] + assert feats.ndim == 3 + + # at this point, feats is [N, T, C] + feats = feats.to(device) + nnet_output, xent_output = model(feats) + + # at this point, nnet_output is: [N, T, C] + # refer to kaldi/src/chain/chain-training.h + # the output should be organized as + # [all sequences for frame 0] + # [all sequences for frame 1] + # [etc.] + nnet_output = nnet_output.permute(1, 0, 2) + # at this point, nnet_output is: [T, N, C] + nnet_output = nnet_output.contiguous().view(-1, + nnet_output.shape[-1]) + + # at this point, xent_output is: [N, T, C] + xent_output = xent_output.permute(1, 0, 2) + # at this point, xent_output is: [T, N, C] + xent_output = xent_output.contiguous().view(-1, + xent_output.shape[-1]) + objf_l2_term_weight = criterion(opts, den_graph, + supervision_list[n], nnet_output, + xent_output) + objf = objf_l2_term_weight[0] + optimizer.zero_grad() + objf.backward() + + clip_grad_value_(model.parameters(), 5.0) + + optimizer.step() + + objf_l2_term_weight = objf_l2_term_weight.detach().cpu() + + total_objf += objf_l2_term_weight[0].item() + total_weight += objf_l2_term_weight[2].item() + num_frames = nnet_output.shape[0] + total_frames += num_frames + + if batch_idx % 100 == 0: + logging.info( + 'Device ({}) processing {}/{}({:.6f}%) global average objf: {:.6f} over {} ' + 'frames, current batch average objf: {:.6f} over {} frames, epoch {}' + .format( + device.index, batch_idx, len(dataloader), + float(batch_idx) / len(dataloader) * 100, + total_objf / total_weight, total_frames, + objf_l2_term_weight[0].item() / + objf_l2_term_weight[2].item(), num_frames, current_epoch)) + + if valid_dataloader and batch_idx % 1000 == 0: + total_valid_objf, total_valid_weight, total_valid_frames = get_validation_objf( + dataloader=valid_dataloader, + model=model, + device=device, + criterion=criterion, + opts=opts, + den_graph=den_graph) + + model.train() + + logging.info( + 'Validation average objf: {:.6f} over {} frames'.format( + total_valid_objf / total_valid_weight, total_valid_frames)) + + tf_writer.add_scalar('train/global_valid_average_objf', + total_valid_objf / total_valid_weight, + batch_idx + current_epoch * len(dataloader)) + + if device.index == 0 and batch_idx % 100 == 0: + tf_writer.add_scalar('train/global_average_objf', + total_objf / total_weight, + batch_idx + current_epoch * len(dataloader)) + tf_writer.add_scalar( + 'train/current_batch_average_objf', + objf_l2_term_weight[0].item() / objf_l2_term_weight[2].item(), + batch_idx + current_epoch * len(dataloader)) + + state_dict = model.state_dict() + for key, value in state_dict.items(): + # skip batchnorm parameters + if value.dtype != torch.float32: + continue + if 'running_mean' in key or 'running_var' in key: + continue + + with torch.no_grad(): + frobenius_norm = torch.norm(value, p='fro') + + tf_writer.add_scalar( + 'train/parameters/{}'.format(key), frobenius_norm, + batch_idx + current_epoch * len(dataloader)) + + return total_objf / total_weight + + +def main(): + args = get_args() + setup_logger('{}/log-train-device-{}'.format(args.dir, args.device_id), + args.log_level) + logging.info(' '.join(sys.argv)) + + if torch.cuda.is_available() == False: + logging.error('No GPU detected!') + sys.exit(-1) + + dist.init_process_group('nccl', + rank=args.device_id, + world_size=args.world_size) + + # WARNING(fangjun): we have to select GPU at the very + # beginning; otherwise you will get trouble later + kaldi.SelectGpuDevice(device_id=args.device_id) + kaldi.CuDeviceAllowMultithreading() + + device = torch.device('cuda', args.device_id) + + den_fst = fst.StdVectorFst.Read(args.den_fst_filename) + + opts = chain.ChainTrainingOptions() + opts.l2_regularize = args.l2_regularize + opts.xent_regularize = args.xent_regularize + opts.leaky_hmm_coefficient = args.leaky_hmm_coefficient + + den_graph = chain.DenominatorGraph(fst=den_fst, num_pdfs=args.output_dim) + + model = get_chain_model( + feat_dim=args.feat_dim, + output_dim=args.output_dim, + lda_mat_filename=args.lda_mat_filename, + hidden_dim=args.hidden_dim, + bottleneck_dim=args.bottleneck_dim, + prefinal_bottleneck_dim=args.prefinal_bottleneck_dim, + kernel_size_list=args.kernel_size_list, + subsampling_factor_list=args.subsampling_factor_list) + + start_epoch = 0 + num_epochs = args.num_epochs + learning_rate = args.learning_rate + best_objf = -100000 + + if args.checkpoint: + start_epoch, learning_rate, best_objf = load_checkpoint( + args.checkpoint, model) + logging.info( + 'Device ({device_id}) loaded from checkpoint: start epoch {start_epoch}, ' + 'learning rate {learning_rate}, best objf {best_objf}'.format( + device_id=args.device_id, + start_epoch=start_epoch, + learning_rate=learning_rate, + best_objf=best_objf)) + + model.to(device) + + model = DDP(model, device_ids=[args.device_id]) + + dataloader = get_egs_dataloader(egs_dir_or_scp=args.cegs_dir, + egs_left_context=args.egs_left_context, + egs_right_context=args.egs_right_context, + shuffle=False, + world_size=args.world_size, + local_rank=args.device_id) + + if args.device_id == 0: + valid_dataloader = get_egs_dataloader( + egs_dir_or_scp=args.valid_cegs_scp, + egs_left_context=args.egs_left_context, + egs_right_context=args.egs_right_context, + shuffle=False) + else: + valid_dataloader = None + + optimizer = optim.Adam(model.parameters(), + lr=learning_rate, + weight_decay=5e-4) + + criterion = KaldiChainObjfFunction.apply + + if args.device_id == 0: + tf_writer = SummaryWriter(log_dir='{}/tensorboard'.format(args.dir)) + else: + tf_writer = None + + best_epoch = start_epoch + best_model_path = os.path.join(args.dir, 'best_model.pt') + best_epoch_info_filename = os.path.join(args.dir, 'best-epoch-info') + + dist.barrier() + + try: + for epoch in range(start_epoch, args.num_epochs): + learning_rate = 1e-3 * pow(0.4, epoch) + for param_group in optimizer.param_groups: + param_group['lr'] = learning_rate + + logging.info('epoch {}, learning rate {}'.format( + epoch, learning_rate)) + + if tf_writer: + tf_writer.add_scalar('learning_rate', learning_rate, epoch) + + objf = train_one_epoch(dataloader=dataloader, + valid_dataloader=valid_dataloader, + model=model, + device=device, + optimizer=optimizer, + criterion=criterion, + current_epoch=epoch, + opts=opts, + den_graph=den_graph, + tf_writer=tf_writer) + + if best_objf is None: + best_objf = objf + best_epoch = epoch + + # the higher, the better + if objf > best_objf: + best_objf = objf + best_epoch = epoch + save_checkpoint(filename=best_model_path, + model=model, + epoch=epoch, + learning_rate=learning_rate, + objf=objf, + local_rank=args.device_id) + save_training_info(filename=best_epoch_info_filename, + model_path=best_model_path, + current_epoch=epoch, + learning_rate=learning_rate, + objf=best_objf, + best_objf=best_objf, + best_epoch=best_epoch, + local_rank=args.device_id) + + # we always save the model for every epoch + model_path = os.path.join(args.dir, 'epoch-{}.pt'.format(epoch)) + save_checkpoint(filename=model_path, + model=model, + epoch=epoch, + learning_rate=learning_rate, + objf=objf, + local_rank=args.device_id) + + epoch_info_filename = os.path.join(args.dir, + 'epoch-{}-info'.format(epoch)) + save_training_info(filename=epoch_info_filename, + model_path=model_path, + current_epoch=epoch, + learning_rate=learning_rate, + objf=objf, + best_objf=best_objf, + best_epoch=best_epoch, + local_rank=args.device_id) + + except KeyboardInterrupt: + # save the model when ctrl-c is pressed + model_path = os.path.join(args.dir, + 'epoch-{}-interrupted.pt'.format(epoch)) + # use a very small objf for interrupted model + objf = -100000 + save_checkpoint(model_path, + model=model, + epoch=epoch, + learning_rate=learning_rate, + objf=objf, + local_rank=args.device_id) + + epoch_info_filename = os.path.join( + args.dir, 'epoch-{}-interrupted-info'.format(epoch)) + save_training_info(filename=epoch_info_filename, + model_path=model_path, + current_epoch=epoch, + learning_rate=learning_rate, + objf=objf, + best_objf=best_objf, + best_epoch=best_epoch, + local_rank=args.device_id) + + if tf_writer: + tf_writer.close() + logging.warning('Done') + + +if __name__ == '__main__': + torch.manual_seed(20191227) + main() diff --git a/egs/aishell/s10/chain/egs_dataset.py b/egs/aishell/s10/chain/egs_dataset.py index e081880a3c4..96dc8d1e469 100755 --- a/egs/aishell/s10/chain/egs_dataset.py +++ b/egs/aishell/s10/chain/egs_dataset.py @@ -21,8 +21,12 @@ def get_egs_dataloader(egs_dir_or_scp, egs_left_context, egs_right_context, - shuffle=True): - + shuffle=True, + world_size=None, + local_rank=None): + ''' + world_size and local_rank is for DistributedDataParallel training. + ''' dataset = NnetChainExampleDataset(egs_dir_or_scp=egs_dir_or_scp) frame_subsampling_factor = 3 @@ -34,11 +38,22 @@ def get_egs_dataloader(egs_dir_or_scp, egs_right_context=egs_right_context, frame_subsampling_factor=frame_subsampling_factor) + if world_size: + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=local_rank) + # sampler and shuffle are mutually exclusive; + # it will raise an exception if you set both + shuffle = False + + else: + sampler = None + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0, - collate_fn=collate_fn) + collate_fn=collate_fn, + sampler=sampler) return dataloader diff --git a/egs/aishell/s10/chain/model.py b/egs/aishell/s10/chain/model.py index ff27b2d9cd8..6158dff22ec 100644 --- a/egs/aishell/s10/chain/model.py +++ b/egs/aishell/s10/chain/model.py @@ -66,6 +66,16 @@ def get_chain_model(feat_dim, ''' +def constrain_orthonormal_hook(model, unused_x): + if model.training == False: + return + + with torch.no_grad(): + for m in model.modules(): + if hasattr(m, 'constrain_orthonormal'): + m.constrain_orthonormal() + + # Create a network like the above one class ChainModel(nn.Module): @@ -76,7 +86,7 @@ def __init__(self, hidden_dim=1024, bottleneck_dim=128, prefinal_bottleneck_dim=256, - kernel_size_list=[2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2], + kernel_size_list=[3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3], subsampling_factor_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1], frame_subsampling_factor=3): super().__init__() @@ -140,6 +150,8 @@ def __init__(self, affine=False) self.has_LDA = False + self.register_forward_pre_hook(constrain_orthonormal_hook) + def forward(self, x): # input x is of shape: [batch_size, seq_len, feat_dim] = [N, T, C] assert x.ndim == 3 @@ -208,28 +220,19 @@ def forward(self, x): return nnet_output, xent_output - def constrain_orthonormal(self): - for i in range(len(self.tdnnfs)): - self.tdnnfs[i].constrain_orthonormal() - - self.prefinal_l.constrain_orthonormal() - self.prefinal_chain.constrain_orthonormal() - self.prefinal_xent.constrain_orthonormal() - if __name__ == '__main__': logging.basicConfig( level=logging.DEBUG, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", ) - feat_dim = 43 - output_dim = 4344 + feat_dim = 40 + output_dim = 3456 model = ChainModel(feat_dim=feat_dim, output_dim=output_dim) - logging.info(model) + # logging.info(model) N = 1 T = 150 + 27 + 27 C = feat_dim * 3 x = torch.arange(N * T * C).reshape(N, T, C).float() nnet_output, xent_output = model(x) print(x.shape, nnet_output.shape, xent_output.shape) - model.constrain_orthonormal() diff --git a/egs/aishell/s10/chain/options.py b/egs/aishell/s10/chain/options.py index 4cc5fa59168..8c3f1c74383 100644 --- a/egs/aishell/s10/chain/options.py +++ b/egs/aishell/s10/chain/options.py @@ -99,6 +99,21 @@ def _set_training_args(parser): help='leaky hmm coefficient', type=float) + # PyTorch DistributedDataParallel (ddp) parameters + parser.add_argument( + '--train.use-ddp', + dest='use_ddp', + help="true to use PyTorch's built-in DistributedDataParallel trainer", + type=_str2bool) + + # note that we use device id as local rank. + + parser.add_argument('--train.ddp.world-size', + dest='world_size', + help='world size in ddp', + default=1, + type=int) + def _check_training_args(args): assert os.path.isdir(args.cegs_dir) @@ -118,6 +133,9 @@ def _check_training_args(args): if args.checkpoint: assert os.path.exists(args.checkpoint) + if args.use_ddp: + assert args.world_size >= 1 + def _check_inference_args(args): assert args.checkpoint is not None diff --git a/egs/aishell/s10/chain/tdnnf_layer.py b/egs/aishell/s10/chain/tdnnf_layer.py index bcc5a5f6e56..6b98dd12b03 100644 --- a/egs/aishell/s10/chain/tdnnf_layer.py +++ b/egs/aishell/s10/chain/tdnnf_layer.py @@ -136,9 +136,6 @@ def forward(self, x): return x - def constrain_orthonormal(self): - self.linear.constrain_orthonormal() - class FactorizedTDNN(nn.Module): ''' @@ -210,9 +207,6 @@ def forward(self, x): x = self.bypass_scale * input_x[:, :, ::self.s] + x return x - def constrain_orthonormal(self): - self.linear.constrain_orthonormal() - def _test_constrain_orthonormal(): @@ -251,11 +245,17 @@ def compute_loss(M): kernel_size=3, subsampling_factor=1) loss = [] - model.constrain_orthonormal() + + for m in model.modules(): + if hasattr(m, 'constrain_orthonormal'): + m.constrain_orthonormal() + loss.append( compute_loss(model.linear.conv.state_dict()['weight'].reshape(128, -1))) for i in range(5): - model.constrain_orthonormal() + for m in model.modules(): + if hasattr(m, 'constrain_orthonormal'): + m.constrain_orthonormal() loss.append( compute_loss(model.linear.conv.state_dict()['weight'].reshape( 128, -1))) diff --git a/egs/aishell/s10/chain/train.py b/egs/aishell/s10/chain/train.py index 84bcb4f12b8..0c7cd1021aa 100644 --- a/egs/aishell/s10/chain/train.py +++ b/egs/aishell/s10/chain/train.py @@ -140,10 +140,6 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, num_frames = nnet_output.shape[0] total_frames += num_frames - if np.random.choice(4) == 0: - with torch.no_grad(): - model.constrain_orthonormal() - if batch_idx % 100 == 0: logging.info( 'Process {}/{}({:.6f}%) global average objf: {:.6f} over {} ' @@ -274,9 +270,11 @@ def main(): best_epoch = start_epoch best_model_path = os.path.join(args.dir, 'best_model.pt') best_epoch_info_filename = os.path.join(args.dir, 'best-epoch-info') + + lr = learning_rate try: for epoch in range(start_epoch, args.num_epochs): - learning_rate = 1e-3 * pow(0.4, epoch) + learning_rate = lr * pow(0.4, epoch) for param_group in optimizer.param_groups: param_group['lr'] = learning_rate diff --git a/egs/aishell/s10/local/run_chain.sh b/egs/aishell/s10/local/run_chain.sh index 25b944fde3d..7a498e365e8 100755 --- a/egs/aishell/s10/local/run_chain.sh +++ b/egs/aishell/s10/local/run_chain.sh @@ -5,11 +5,10 @@ set -e -stage=10 +stage=0 -# GPU device id to use (count from 0). -# you can also set `CUDA_VISIBLE_DEVICES` and set `device_id=0` -device_id=3 +export CUDA_VISIBLE_DEVICES="0,3" +device_id=0 nj=10 @@ -199,30 +198,71 @@ if [[ $stage -le 13 ]]; then train_checkpoint=./exp/chain/train/best_model.pt fi - # sort the options alphabetically - python3 ./chain/train.py \ - --bottleneck-dim $bottleneck_dim \ - --checkpoint=${train_checkpoint:-} \ - --device-id $device_id \ - --dir exp/chain/train \ - --feat-dim $feat_dim \ - --hidden-dim $hidden_dim \ - --is-training true \ - --kernel-size-list "$kernel_size_list" \ - --log-level $log_level \ - --output-dim $output_dim \ - --prefinal-bottleneck-dim $prefinal_bottleneck_dim \ - --subsampling-factor-list "$subsampling_factor_list" \ - --train.cegs-dir exp/chain/merged_egs \ - --train.den-fst exp/chain/den.fst \ - --train.egs-left-context $egs_left_context \ - --train.egs-right-context $egs_right_context \ - --train.l2-regularize 5e-5 \ - --train.leaky-hmm-coefficient 0.1 \ - --train.lr $lr \ - --train.num-epochs $num_epochs \ - --train.valid-cegs-scp exp/chain/egs/valid_diagnostic.scp \ - --train.xent-regularize 0.1 + num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + + if [[ $num_gpus -gt 1 ]]; then + echo "$0: training with ddp..." + + export MASTER_ADDR=localhost + export MASTER_PORT=6666 + + for ((i = 0; i < $num_gpus; ++i)); do + # sort the options alphabetically + python3 ./chain/ddp_train.py \ + --bottleneck-dim $bottleneck_dim \ + --checkpoint=${train_checkpoint:-} \ + --device-id $i \ + --dir exp/chain/train \ + --feat-dim $feat_dim \ + --hidden-dim $hidden_dim \ + --is-training true \ + --kernel-size-list "$kernel_size_list" \ + --log-level $log_level \ + --output-dim $output_dim \ + --prefinal-bottleneck-dim $prefinal_bottleneck_dim \ + --subsampling-factor-list "$subsampling_factor_list" \ + --train.cegs-dir exp/chain/merged_egs \ + --train.ddp.world-size $num_gpus \ + --train.den-fst exp/chain/den.fst \ + --train.egs-left-context $egs_left_context \ + --train.egs-right-context $egs_right_context \ + --train.l2-regularize 5e-5 \ + --train.leaky-hmm-coefficient 0.1 \ + --train.lr $lr \ + --train.num-epochs $num_epochs \ + --train.use-ddp true \ + --train.valid-cegs-scp exp/chain/egs/valid_diagnostic.scp \ + --train.xent-regularize 0.1 & + done + wait + else + echo "$0: training with single gpu..." + # sort the options alphabetically + python3 ./chain/train.py \ + --bottleneck-dim $bottleneck_dim \ + --checkpoint=${train_checkpoint:-} \ + --device-id $device_id \ + --dir exp/chain/train \ + --feat-dim $feat_dim \ + --hidden-dim $hidden_dim \ + --is-training true \ + --kernel-size-list "$kernel_size_list" \ + --log-level $log_level \ + --output-dim $output_dim \ + --prefinal-bottleneck-dim $prefinal_bottleneck_dim \ + --subsampling-factor-list "$subsampling_factor_list" \ + --train.cegs-dir exp/chain/merged_egs \ + --train.den-fst exp/chain/den.fst \ + --train.egs-left-context $egs_left_context \ + --train.egs-right-context $egs_right_context \ + --train.l2-regularize 5e-5 \ + --train.leaky-hmm-coefficient 0.1 \ + --train.lr $lr \ + --train.num-epochs $num_epochs \ + --train.use-ddp false \ + --train.valid-cegs-scp exp/chain/egs/valid_diagnostic.scp \ + --train.xent-regularize 0.1 + fi fi if [[ $stage -le 14 ]]; then diff --git a/egs/aishell/s10/run.sh b/egs/aishell/s10/run.sh index f6f3b6e4cb0..2c9d0ba30ab 100755 --- a/egs/aishell/s10/run.sh +++ b/egs/aishell/s10/run.sh @@ -25,7 +25,7 @@ data_url=www.openslr.org/resources/33 nj=30 -stage=14 +stage=0 if [[ $stage -le 0 ]]; then local/download_and_untar.sh $data $data_url data_aishell || exit 1