diff --git a/egs/aishell/s10/chain/model.py b/egs/aishell/s10/chain/model.py
index 39d7acb765a..d8d01b29d20 100644
--- a/egs/aishell/s10/chain/model.py
+++ b/egs/aishell/s10/chain/model.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 # Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
+# Copyright 2019-2020 JD AI, Beijing, China (author: Lu Fan)
 # Apache 2.0
 
 import logging
@@ -20,14 +21,13 @@ def get_chain_model(feat_dim,
                     hidden_dim,
                     bottleneck_dim,
                     time_stride_list,
-                    conv_stride_list,
                     lda_mat_filename=None):
     model = ChainModel(feat_dim=feat_dim,
                        output_dim=output_dim,
                        lda_mat_filename=lda_mat_filename,
                        hidden_dim=hidden_dim,
-                       time_stride_list=time_stride_list,
-                       conv_stride_list=conv_stride_list)
+                       bottleneck_dim=bottleneck_dim,
+                       time_stride_list=time_stride_list)
     return model
 
 
@@ -72,15 +72,14 @@ def __init__(self,
                  lda_mat_filename=None,
                  hidden_dim=1024,
                  bottleneck_dim=128,
-                 time_stride_list=[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
-                 conv_stride_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1],
+                 time_stride_list=[1, 1, 1, 0, 3, 3, 3, 3, 3, 3, 3, 3],
                  frame_subsampling_factor=3):
         super().__init__()
 
         # at present, we support only frame_subsampling_factor to be 3
         assert frame_subsampling_factor == 3
-
-        assert len(time_stride_list) == len(conv_stride_list)
+        self.frame_subsampling_factor = frame_subsampling_factor
+        self.time_stride_list = time_stride_list
         num_layers = len(time_stride_list)
 
         # tdnn1_affine requires [N, T, C]
@@ -93,11 +92,9 @@ def __init__(self,
         tdnnfs = []
         for i in range(num_layers):
             time_stride = time_stride_list[i]
-            conv_stride = conv_stride_list[i]
             layer = FactorizedTDNN(dim=hidden_dim,
                                    bottleneck_dim=bottleneck_dim,
-                                   time_stride=time_stride,
-                                   conv_stride=conv_stride)
+                                   time_stride=time_stride)
             tdnnfs.append(layer)
 
         # tdnnfs requires [N, C, T]
@@ -105,8 +102,7 @@ def __init__(self,
 
         # prefinal_l affine requires [N, C, T]
         self.prefinal_l = OrthonormalLinear(dim=hidden_dim,
-                                            bottleneck_dim=bottleneck_dim * 2,
-                                            time_stride=0)
+                                            bottleneck_dim=bottleneck_dim * 2)
 
         # prefinal_chain requires [N, C, T]
         self.prefinal_chain = PrefinalLayer(big_dim=hidden_dim,
@@ -174,6 +170,13 @@ def forward(self, x):
         # tdnnf requires input of shape [N, C, T]
         for i in range(len(self.tdnnfs)):
            x = self.tdnnfs[i](x)
+            # subsample manually; keep the context frames at both edges un-strided
+            if self.tdnnfs[i].time_stride == 0:
+                cur_context = sum(self.time_stride_list[i:])
+                x_left = x[:, :, :cur_context]
+                x_mid = x[:, :, cur_context:-cur_context:self.frame_subsampling_factor]
+                x_right = x[:, :, -cur_context:]
+                x = torch.cat([x_left, x_mid, x_right], dim=2)
 
         # at this point, x is [N, C, T]
 
diff --git a/egs/aishell/s10/chain/options.py b/egs/aishell/s10/chain/options.py
index 5a6e04f9ba7..8506e9c1a23 100644
--- a/egs/aishell/s10/chain/options.py
+++ b/egs/aishell/s10/chain/options.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 # Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
+# Copyright 2020 JD AI, Beijing, China (author: Lu Fan)
 # Apache 2.0
 
 import argparse
@@ -134,15 +135,8 @@ def _check_args(args):
     assert args.time_stride_list is not None
     assert len(args.time_stride_list) > 0
 
-    assert args.conv_stride_list is not None
-    assert len(args.conv_stride_list) > 0
-
     args.time_stride_list = [int(k) for k in args.time_stride_list.split(', ')]
-    args.conv_stride_list = [int(k) for k in args.conv_stride_list.split(', ')]
-
-    assert len(args.time_stride_list) == len(args.conv_stride_list)
-
     assert args.log_level in ['debug', 'info', 'warning']
 
@@ -208,12 +202,6 @@ def get_args():
                         required=True,
                         type=str)
 
-    parser.add_argument('--conv-stride-list',
-                        dest='conv_stride_list',
-                        help='conv stride list',
-                        required=True,
-                        type=str)
-
     parser.add_argument('--log-level',
                         dest='log_level',
                         help='log level. valid values: debug, info, warning',
diff --git a/egs/aishell/s10/chain/tdnnf_layer.py b/egs/aishell/s10/chain/tdnnf_layer.py
index cf3c5a11862..71f5f8b0c3e 100644
--- a/egs/aishell/s10/chain/tdnnf_layer.py
+++ b/egs/aishell/s10/chain/tdnnf_layer.py
@@ -53,24 +53,16 @@ def _constrain_orthonormal_internal(M):
 
 class OrthonormalLinear(nn.Module):
 
-    def __init__(self, dim, bottleneck_dim, time_stride):
+    def __init__(self, dim, bottleneck_dim, kernel_size=1, dilation=1):
         super().__init__()
-        assert time_stride in [0, 1]
-        # WARNING(fangjun): kaldi uses [-1, 0] for the first linear layer
-        # and [0, 1] for the second affine layer;
-        # we use [-1, 0, 1] for the first linear layer if time_stride == 1
-
-        if time_stride == 0:
-            kernel_size = 1
-        else:
-            kernel_size = 3
-
         self.kernel_size = kernel_size
+        self.dilation = dilation
 
         # conv requires [N, C, T]
         self.conv = nn.Conv1d(in_channels=dim,
                               out_channels=bottleneck_dim,
                               kernel_size=kernel_size,
+                              dilation=dilation,
                               bias=False)
 
     def forward(self, x):
@@ -116,7 +108,7 @@ def __init__(self, big_dim, small_dim):
         self.batchnorm1 = nn.BatchNorm1d(num_features=big_dim)
         self.linear = OrthonormalLinear(dim=big_dim,
                                         bottleneck_dim=small_dim,
-                                        time_stride=0)
+                                        kernel_size=1)
         self.batchnorm2 = nn.BatchNorm1d(num_features=small_dim)
 
     def forward(self, x):
@@ -161,29 +153,32 @@ def __init__(self,
                  dim,
                  bottleneck_dim,
                  time_stride,
-                 conv_stride,
                  bypass_scale=0.66):
         super().__init__()
 
-        assert conv_stride in [1, 3]
         assert abs(bypass_scale) <= 1
 
         self.bypass_scale = bypass_scale
+        self.time_stride = time_stride
 
-        self.conv_stride = conv_stride
+        if time_stride > 0:
+            kernel_size, dilation = 2, time_stride
+        else:
+            kernel_size, dilation = 1, 1
 
         # linear requires [N, C, T]
         self.linear = OrthonormalLinear(dim=dim,
                                         bottleneck_dim=bottleneck_dim,
-                                        time_stride=time_stride)
+                                        kernel_size=kernel_size,
+                                        dilation=dilation)
 
         # affine requires [N, C, T]
         # WARNING(fangjun): we do not use nn.Linear here
         # since we want to use `stride`
         self.affine = nn.Conv1d(in_channels=bottleneck_dim,
                                 out_channels=dim,
-                                kernel_size=1,
-                                stride=conv_stride)
+                                kernel_size=kernel_size,
+                                dilation=dilation)
 
         # batchnorm requires [N, C, T]
         self.batchnorm = nn.BatchNorm1d(num_features=dim)
@@ -213,10 +208,11 @@ def forward(self, x):
 
         # TODO(fangjun): implement GeneralDropoutComponent in PyTorch
 
-        if self.linear.kernel_size == 3:
-            x = self.bypass_scale * input_x[:, :, 1:-1:self.conv_stride] + x
+        # at this point, x is [N, C, T]
+        if self.linear.kernel_size == 2:
+            x = self.bypass_scale * input_x[:, :, self.linear.dilation:-self.linear.dilation] + x
         else:
-            x = self.bypass_scale * input_x[:, :, ::self.conv_stride] + x
+            x = self.bypass_scale * input_x + x
         return x
 
     def constrain_orthonormal(self):
@@ -257,8 +253,7 @@ def compute_loss(M):
 
     model = FactorizedTDNN(dim=1024,
                            bottleneck_dim=128,
-                           time_stride=1,
-                           conv_stride=3)
+                           time_stride=1)
     loss = []
     model.constrain_orthonormal()
     loss.append(
@@ -278,40 +273,29 @@ def _test_factorized_tdnn():
     N = 1
     T = 10
     C = 4
-
-    # case 0: time_stride == 1, conv_stride == 1
+    # https://pytorch.org/docs/stable/nn.html?highlight=conv1d#torch.nn.Conv1d
+    # T_out = floor((T + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
+    # case 0: time_stride == 1, kernel_size == 2, dilation == 1
     model = FactorizedTDNN(dim=C,
                            bottleneck_dim=2,
-                           time_stride=1,
-                           conv_stride=1)
+                           time_stride=1)
     x = torch.arange(N * T * C).reshape(N, C, T).float()
     y = model(x)
     assert y.size(2) == T - 2
 
-    # case 1: time_stride == 0, conv_stride == 1
+    # case 1: time_stride == 0, kernel_size == 1, dilation == 1
     model = FactorizedTDNN(dim=C,
                            bottleneck_dim=2,
-                           time_stride=0,
-                           conv_stride=1)
+                           time_stride=0)
     y = model(x)
     assert y.size(2) == T
 
-    # case 2: time_stride == 1, conv_stride == 3
+    # case 2: time_stride == 3, kernel_size == 2, dilation == 3
     model = FactorizedTDNN(dim=C,
                            bottleneck_dim=2,
-                           time_stride=1,
-                           conv_stride=3)
+                           time_stride=3)
     y = model(x)
-    assert y.size(2) == math.ceil((T - 2) / 3)
-
-    # case 3: time_stride == 0, conv_stride == 3
-    model = FactorizedTDNN(dim=C,
-                           bottleneck_dim=2,
-                           time_stride=0,
-                           conv_stride=3)
-    y = model(x)
-    assert y.size(2) == math.ceil(T / 3)
-
+    assert y.size(2) == T - 2 * 3  # each conv trims dilation * (kernel_size - 1) = 3 frames
 
 if __name__ == '__main__':
     torch.manual_seed(20200130)
diff --git a/egs/aishell/s10/chain/train.py b/egs/aishell/s10/chain/train.py
index 1f5c6824c97..31ba0315942 100644
--- a/egs/aishell/s10/chain/train.py
+++ b/egs/aishell/s10/chain/train.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 # Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
+# Copyright 2019-2020 JD AI, Beijing, China (author: Lu Fan)
 # Apache 2.0
 
 import logging
@@ -100,15 +101,28 @@ def train_one_epoch(dataloader, model, device, optimizer, criterion,
                 total_objf / total_weight, total_frames,
                 objf_l2_term_weight[0].item() /
                 objf_l2_term_weight[2].item(), num_frames, current_epoch))
+            log_norm = ["{}: {:.4f}".format(name, torch.norm(parms)) \
+                        for name, parms in model.named_parameters() \
+                        if "affine" in name or "linear" in name]
+            logging.info("Process {}/{}({:.6f}%) l2-norm is:[ {} ]".format(
+                batch_idx, len(dataloader),
+                float(batch_idx) / len(dataloader) * 100, " ".join(log_norm)))
 
         if batch_idx % 100 == 0:
+            current_iter = batch_idx + current_epoch * len(dataloader)
             tf_writer.add_scalar('train/global_average_objf',
                                  total_objf / total_weight,
-                                 batch_idx + current_epoch * len(dataloader))
+                                 current_iter)
             tf_writer.add_scalar(
                 'train/current_batch_average_objf',
                 objf_l2_term_weight[0].item() /
                 objf_l2_term_weight[2].item(),
-                batch_idx + current_epoch * len(dataloader))
+                current_iter)
+            for name, parms in model.named_parameters():
+                tf_writer.add_histogram(f'train/norm/{name}',
+                                        parms.clone().cpu().data.numpy(),
+                                        current_iter)
+                tf_writer.add_scalar(f'train/l2_norm/{name}', torch.norm(parms),
+                                     current_iter)
 
     return total_objf / total_weight
@@ -142,8 +156,7 @@ def main():
         lda_mat_filename=args.lda_mat_filename,
         hidden_dim=args.hidden_dim,
         bottleneck_dim=args.bottleneck_dim,
-        time_stride_list=args.time_stride_list,
-        conv_stride_list=args.conv_stride_list)
+        time_stride_list=args.time_stride_list)
 
     start_epoch = 0
     num_epochs = args.num_epochs
diff --git a/egs/aishell/s10/local/run_chain.sh b/egs/aishell/s10/local/run_chain.sh
index 06b5d47e89f..350e5fe9f06 100755
--- a/egs/aishell/s10/local/run_chain.sh
+++ b/egs/aishell/s10/local/run_chain.sh
@@ -1,6 +1,7 @@
 #!/bin/bash
 
 # Copyright 2019 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
+# Copyright 2020 JD AI, Beijing, China (author: Lu Fan)
 # Apache 2.0
 
 set -e
@@ -33,7 +34,6 @@ lr=1e-3
 hidden_dim=1024
 bottleneck_dim=128
 time_stride_list="1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1" # comma separated list
-conv_stride_list="1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1" # comma separated list
 
 log_level=info # valid values: debug, info, warning
 
@@ -164,7 +164,6 @@ if [[ $stage -le 8 ]]; then
   python3 ./chain/train.py \
     --bottleneck-dim $bottleneck_dim \
     --checkpoint=${train_checkpoint:-} \
-    --conv-stride-list "$conv_stride_list" \
     --device-id $device_id \
     --dir exp/chain/train \
     --feat-dim $feat_dim \
@@ -194,7 +193,6 @@ if [[ $stage -le 9 ]]; then
   python3 ./chain/inference.py \
     --bottleneck-dim $bottleneck_dim \
     --checkpoint $inference_checkpoint \
-    --conv-stride-list "$conv_stride_list" \
     --device-id $device_id \
     --dir exp/chain/inference/$x \
     --feat-dim $feat_dim \
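Note (not part of the patch): with conv_stride gone, FactorizedTDNN builds both its linear and affine convolutions with kernel_size=2 and dilation=time_stride, so every TDNN-F layer trims 2 * time_stride frames and the 3x frame subsampling is now done explicitly in ChainModel.forward(). The following standalone sketch assumes only PyTorch and uses toy dimensions; it mirrors that kernel_size/dilation choice and checks the T_out arithmetic asserted by case 2 of _test_factorized_tdnn().

# Standalone sketch with toy dimensions, not part of the recipe's code.
import torch
import torch.nn as nn

N, C, T = 1, 4, 10
time_stride = 3  # corresponds to case 2 in _test_factorized_tdnn()

# the factorized "linear" and "affine" convolutions of one TDNN-F layer
linear = nn.Conv1d(in_channels=C, out_channels=2, kernel_size=2,
                   dilation=time_stride, bias=False)
affine = nn.Conv1d(in_channels=2, out_channels=C, kernel_size=2,
                   dilation=time_stride)

x = torch.arange(N * C * T).reshape(N, C, T).float()
y = affine(linear(x))

# Conv1d with stride=1 and padding=0: T_out = T_in - dilation * (kernel_size - 1),
# so the two convolutions together trim 2 * time_stride frames.
assert y.size(2) == T - 2 * time_stride  # 10 - 6 == 4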