egs/aishell/s10/chain/egs_dataset.py (16 additions & 6 deletions)

@@ -4,6 +4,7 @@
 # Apache 2.0

 import glob
+import os

 import numpy as np
 import torch
@@ -17,9 +18,12 @@
 from common import splice_feats


-def get_egs_dataloader(egs_dir, egs_left_context, egs_right_context):
+def get_egs_dataloader(egs_dir_or_scp,
+                       egs_left_context,
+                       egs_right_context,
+                       shuffle=True):

-    dataset = NnetChainExampleDataset(egs_dir=egs_dir)
+    dataset = NnetChainExampleDataset(egs_dir_or_scp=egs_dir_or_scp)
     frame_subsampling_factor = 3

     # we have merged egs offline, so batch size is 1
@@ -32,6 +36,7 @@ def get_egs_dataloader(egs_dir, egs_left_context, egs_right_context):

     dataloader = DataLoader(dataset,
                             batch_size=batch_size,
+                            shuffle=shuffle,
                             num_workers=0,
                             collate_fn=collate_fn)
     return dataloader
@@ -50,11 +55,16 @@ def read_nnet_chain_example(rxfilename):

 class NnetChainExampleDataset(Dataset):

-    def __init__(self, egs_dir):
+    def __init__(self, egs_dir_or_scp):
         '''
-        We assume that there exist many cegs.*.scp files inside egs_dir
+        If egs_dir_or_scp is a directory, we assume that there exist many cegs.*.scp files
+        inside egs_dir_or_scp.
         '''
-        self.scps = glob.glob('{}/cegs.*.scp'.format(egs_dir))
+        if os.path.isdir(egs_dir_or_scp):
+            self.scps = glob.glob('{}/cegs.*.scp'.format(egs_dir_or_scp))
+        else:
+            self.scps = [egs_dir_or_scp]
+
         assert len(self.scps) > 0
         self.items = list()
         for scp in self.scps:
@@ -171,7 +181,7 @@ def __call__(self, batch):

 def _test_nnet_chain_example_dataset():
     egs_dir = 'exp/chain/merged_egs'
-    dataset = NnetChainExampleDataset(egs_dir=egs_dir)
+    dataset = NnetChainExampleDataset(egs_dir_or_scp=egs_dir)
     egs_left_context = 29
     egs_right_context = 29
     frame_subsampling_factor = 3
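With this change, get_egs_dataloader accepts either a directory containing merged cegs.*.scp files or a single scp file, and shuffling can be disabled for validation passes. A minimal usage sketch (the paths are placeholders; the context values follow the test function above):

    from egs_dataset import get_egs_dataloader

    # training egs: a directory of merged cegs.*.scp files, shuffled as before
    train_dl = get_egs_dataloader('exp/chain/merged_egs',
                                  egs_left_context=29,
                                  egs_right_context=29)

    # validation egs: a single scp file, read in a fixed order
    valid_dl = get_egs_dataloader('exp/chain/merged_egs/valid_cegs.scp',
                                  egs_left_context=29,
                                  egs_right_context=29,
                                  shuffle=False)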
egs/aishell/s10/chain/inference.py (9 additions & 7 deletions)

@@ -30,13 +30,15 @@ def main():
     else:
         device = torch.device('cuda', args.device_id)

-    model = get_chain_model(feat_dim=args.feat_dim,
-                            output_dim=args.output_dim,
-                            lda_mat_filename=args.lda_mat_filename,
-                            hidden_dim=args.hidden_dim,
-                            bottleneck_dim=args.bottleneck_dim,
-                            time_stride_list=args.time_stride_list,
-                            conv_stride_list=args.conv_stride_list)
+    model = get_chain_model(
+        feat_dim=args.feat_dim,
+        output_dim=args.output_dim,
+        lda_mat_filename=args.lda_mat_filename,
+        hidden_dim=args.hidden_dim,
+        bottleneck_dim=args.bottleneck_dim,
+        prefinal_bottleneck_dim=args.prefinal_bottleneck_dim,
+        kernel_size_list=args.kernel_size_list,
+        subsampling_factor_list=args.subsampling_factor_list)

     load_checkpoint(args.checkpoint, model)

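Taken together with the model.py changes below, the inference entry point now threads the prefinal bottleneck dimension and the kernel-size/subsampling-factor lists through to the network. A condensed setup sketch (the hyperparameter values and checkpoint path are placeholders, and the import location of load_checkpoint is an assumption about this recipe's layout):

    import torch

    from common import load_checkpoint  # assumption: load_checkpoint lives in common.py
    from model import get_chain_model

    device = torch.device('cpu')  # or torch.device('cuda', device_id)

    model = get_chain_model(
        feat_dim=43,  # placeholders; in practice these come from options.py
        output_dim=4344,
        lda_mat_filename=None,
        hidden_dim=1024,
        bottleneck_dim=128,
        prefinal_bottleneck_dim=256,
        kernel_size_list=[2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2],
        subsampling_factor_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1])

    load_checkpoint('exp/chain/best_model.pt', model)  # placeholder path
    model.to(device)
    model.eval()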
egs/aishell/s10/chain/model.py (33 additions & 21 deletions)

@@ -19,15 +19,18 @@ def get_chain_model(feat_dim,
                     output_dim,
                     hidden_dim,
                     bottleneck_dim,
-                    time_stride_list,
-                    conv_stride_list,
+                    prefinal_bottleneck_dim,
+                    kernel_size_list,
+                    subsampling_factor_list,
                     lda_mat_filename=None):
     model = ChainModel(feat_dim=feat_dim,
                        output_dim=output_dim,
                        lda_mat_filename=lda_mat_filename,
                        hidden_dim=hidden_dim,
-                       time_stride_list=time_stride_list,
-                       conv_stride_list=conv_stride_list)
+                       bottleneck_dim=bottleneck_dim,
+                       prefinal_bottleneck_dim=prefinal_bottleneck_dim,
+                       kernel_size_list=kernel_size_list,
+                       subsampling_factor_list=subsampling_factor_list)
     return model


@@ -72,55 +75,58 @@ def __init__(self,
                  lda_mat_filename=None,
                  hidden_dim=1024,
                  bottleneck_dim=128,
-                 time_stride_list=[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
-                 conv_stride_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1],
+                 prefinal_bottleneck_dim=256,
+                 kernel_size_list=[2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2],
+                 subsampling_factor_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1],
                  frame_subsampling_factor=3):
         super().__init__()

         # at present, we support only frame_subsampling_factor to be 3
         assert frame_subsampling_factor == 3

-        assert len(time_stride_list) == len(conv_stride_list)
-        num_layers = len(time_stride_list)
+        assert len(kernel_size_list) == len(subsampling_factor_list)
+        num_layers = len(kernel_size_list)

         # tdnn1_affine requires [N, T, C]
         self.tdnn1_affine = nn.Linear(in_features=feat_dim * 3,
                                       out_features=hidden_dim)

         # tdnn1_batchnorm requires [N, C, T]
-        self.tdnn1_batchnorm = nn.BatchNorm1d(num_features=hidden_dim)
+        self.tdnn1_batchnorm = nn.BatchNorm1d(num_features=hidden_dim,
+                                              affine=False)

         tdnnfs = []
         for i in range(num_layers):
-            time_stride = time_stride_list[i]
-            conv_stride = conv_stride_list[i]
+            kernel_size = kernel_size_list[i]
+            subsampling_factor = subsampling_factor_list[i]
             layer = FactorizedTDNN(dim=hidden_dim,
                                    bottleneck_dim=bottleneck_dim,
-                                   time_stride=time_stride,
-                                   conv_stride=conv_stride)
+                                   kernel_size=kernel_size,
+                                   subsampling_factor=subsampling_factor)
             tdnnfs.append(layer)

         # tdnnfs requires [N, C, T]
         self.tdnnfs = nn.ModuleList(tdnnfs)

         # prefinal_l affine requires [N, C, T]
-        self.prefinal_l = OrthonormalLinear(dim=hidden_dim,
-                                            bottleneck_dim=bottleneck_dim * 2,
-                                            time_stride=0)
+        self.prefinal_l = OrthonormalLinear(
+            dim=hidden_dim,
+            bottleneck_dim=prefinal_bottleneck_dim,
+            kernel_size=1)

         # prefinal_chain requires [N, C, T]
         self.prefinal_chain = PrefinalLayer(big_dim=hidden_dim,
-                                            small_dim=bottleneck_dim * 2)
+                                            small_dim=prefinal_bottleneck_dim)

         # output_affine requires [N, T, C]
-        self.output_affine = nn.Linear(in_features=bottleneck_dim * 2,
+        self.output_affine = nn.Linear(in_features=prefinal_bottleneck_dim,
                                        out_features=output_dim)

         # prefinal_xent requires [N, C, T]
         self.prefinal_xent = PrefinalLayer(big_dim=hidden_dim,
-                                           small_dim=bottleneck_dim * 2)
+                                           small_dim=prefinal_bottleneck_dim)

-        self.output_xent_affine = nn.Linear(in_features=bottleneck_dim * 2,
+        self.output_xent_affine = nn.Linear(in_features=prefinal_bottleneck_dim,
                                             out_features=output_dim)

         if lda_mat_filename:
@@ -130,7 +136,8 @@ def __init__(self,
             self.has_LDA = True
         else:
             logging.info('replace LDA with BatchNorm')
-            self.input_batch_norm = nn.BatchNorm1d(num_features=feat_dim * 3)
+            self.input_batch_norm = nn.BatchNorm1d(num_features=feat_dim * 3,
+                                                   affine=False)
             self.has_LDA = False

     def forward(self, x):
@@ -211,9 +218,14 @@ def constrain_orthonormal(self):


 if __name__ == '__main__':
+    logging.basicConfig(
+        level=logging.DEBUG,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
     feat_dim = 43
     output_dim = 4344
     model = ChainModel(feat_dim=feat_dim, output_dim=output_dim)
+    logging.info(model)
     N = 1
     T = 150 + 27 + 27
     C = feat_dim * 3
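The renaming can be read off the changed defaults: each old time_stride of 1 becomes a kernel_size of 2 and a time_stride of 0 becomes a kernel_size of 1, while conv_stride_list carries over unchanged as subsampling_factor_list (still subsampling by 3 at the fourth layer). A sketch of constructing the model with the new keyword arguments, using the defaults and the dimensions from the __main__ test above:

    from model import ChainModel

    # defaults from the diff above; feat_dim/output_dim follow the __main__ test
    model = ChainModel(
        feat_dim=43,
        output_dim=4344,
        hidden_dim=1024,
        bottleneck_dim=128,
        prefinal_bottleneck_dim=256,
        kernel_size_list=[2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2],
        subsampling_factor_list=[1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1])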
egs/aishell/s10/chain/options.py (41 additions & 14 deletions)

@@ -54,6 +54,11 @@ def _set_training_args(parser):
                         help='cegs dir containing comibined cegs.*.scp',
                         type=str)

+    parser.add_argument('--train.valid-cegs-scp',
+                        dest='valid_cegs_scp',
+                        help='validation cegs scp',
+                        type=str)
+
     parser.add_argument('--train.den-fst',
                         dest='den_fst_filename',
                         help='denominator fst filename',
@@ -84,9 +89,20 @@ def _set_training_args(parser):
                         help='l2 regularize',
                         type=float)

+    parser.add_argument('--train.xent-regularize',
+                        dest='xent_regularize',
+                        help='xent regularize',
+                        type=float)
+
+    parser.add_argument('--train.leaky-hmm-coefficient',
+                        dest='leaky_hmm_coefficient',
+                        help='leaky hmm coefficient',
+                        type=float)


 def _check_training_args(args):
     assert os.path.isdir(args.cegs_dir)
+    assert os.path.isfile(args.valid_cegs_scp)

     assert os.path.isfile(args.den_fst_filename)

@@ -95,7 +111,9 @@ def _check_training_args(args):

     assert args.num_epochs > 0
     assert args.learning_rate > 0
-    assert args.l2_regularize > 0
+    assert args.l2_regularize >= 0
+    assert args.xent_regularize >= 0
+    assert args.leaky_hmm_coefficient >= 0

     if args.checkpoint:
         assert os.path.exists(args.checkpoint)
@@ -130,18 +148,21 @@ def _check_args(args):
     assert args.output_dim > 0
     assert args.hidden_dim > 0
     assert args.bottleneck_dim > 0
+    assert args.prefinal_bottleneck_dim > 0

-    assert args.time_stride_list is not None
-    assert len(args.time_stride_list) > 0
+    assert args.kernel_size_list is not None
+    assert len(args.kernel_size_list) > 0

-    assert args.conv_stride_list is not None
-    assert len(args.conv_stride_list) > 0
+    assert args.subsampling_factor_list is not None
+    assert len(args.subsampling_factor_list) > 0

-    args.time_stride_list = [int(k) for k in args.time_stride_list.split(', ')]
+    args.kernel_size_list = [int(k) for k in args.kernel_size_list.split(', ')]

-    args.conv_stride_list = [int(k) for k in args.conv_stride_list.split(', ')]
+    args.subsampling_factor_list = [
+        int(k) for k in args.subsampling_factor_list.split(', ')
+    ]

-    assert len(args.time_stride_list) == len(args.conv_stride_list)
+    assert len(args.kernel_size_list) == len(args.subsampling_factor_list)

     assert args.log_level in ['debug', 'info', 'warning']

@@ -202,15 +223,21 @@ def get_args():
                         required=True,
                         type=int)

-    parser.add_argument('--time-stride-list',
-                        dest='time_stride_list',
-                        help='time stride list',
+    parser.add_argument('--prefinal-bottleneck-dim',
+                        dest='prefinal_bottleneck_dim',
+                        help='nn prefinal bottleneck dimension',
+                        required=True,
+                        type=int)
+
+    parser.add_argument('--kernel-size-list',
+                        dest='kernel_size_list',
+                        help='kernel_size_list',
+                        required=True,
+                        type=str)

-    parser.add_argument('--conv-stride-list',
-                        dest='conv_stride_list',
-                        help='conv stride list',
+    parser.add_argument('--subsampling-factor-list',
+                        dest='subsampling_factor_list',
+                        help='subsampling_factor_list',
                         required=True,
                         type=str)

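Note that _check_args splits the list-valued options on ', ' (a comma followed by a space), so the flag values must be quoted and spaced accordingly on the command line. A small self-contained sketch of that parsing step:

    # mirrors the parsing in _check_args; '2,2,2' (no spaces) would raise
    # ValueError here because split(', ') would not separate the items
    kernel_size_list = '2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2'
    parsed = [int(k) for k in kernel_size_list.split(', ')]
    assert parsed == [2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2]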