Commit 793f572

compute validation objf and switch back to [-1, 0, 1] for the linear layer.
1 parent 8c32349 commit 793f572

Showing 5 changed files with 139 additions and 28 deletions.
egs/aishell/s10/chain/egs_dataset.py

Lines changed: 12 additions & 6 deletions
@@ -4,6 +4,7 @@
 # Apache 2.0
 
 import glob
+import os
 
 import numpy as np
 import torch
@@ -17,12 +18,12 @@
 from common import splice_feats
 
 
-def get_egs_dataloader(egs_dir,
+def get_egs_dataloader(egs_dir_or_scp,
                        egs_left_context,
                        egs_right_context,
                        shuffle=True):
 
-    dataset = NnetChainExampleDataset(egs_dir=egs_dir)
+    dataset = NnetChainExampleDataset(egs_dir_or_scp=egs_dir_or_scp)
     frame_subsampling_factor = 3
 
     # we have merged egs offline, so batch size is 1
@@ -54,11 +55,16 @@ def read_nnet_chain_example(rxfilename):
 
 class NnetChainExampleDataset(Dataset):
 
-    def __init__(self, egs_dir):
+    def __init__(self, egs_dir_or_scp):
         '''
-        We assume that there exist many cegs.*.scp files inside egs_dir
+        If egs_dir_or_scp is a directory, we assume that there exist many cegs.*.scp files
+        inside egs_dir_or_scp.
         '''
-        self.scps = glob.glob('{}/cegs.*.scp'.format(egs_dir))
+        if os.path.isdir(egs_dir_or_scp):
+            self.scps = glob.glob('{}/cegs.*.scp'.format(egs_dir_or_scp))
+        else:
+            self.scps = [egs_dir_or_scp]
+
         assert len(self.scps) > 0
         self.items = list()
         for scp in self.scps:
@@ -175,7 +181,7 @@ def __call__(self, batch):
 
 def _test_nnet_chain_example_dataset():
     egs_dir = 'exp/chain/merged_egs'
-    dataset = NnetChainExampleDataset(egs_dir=egs_dir)
+    dataset = NnetChainExampleDataset(egs_dir_or_scp=egs_dir)
     egs_left_context = 29
     egs_right_context = 29
     frame_subsampling_factor = 3
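
The same entry point now serves both training and validation. A minimal usage sketch (not part of the commit; the directory comes from the test function above and the scp path from run_chain.sh below):

from egs_dataset import get_egs_dataloader

# training: pass a directory and the dataset globs cegs.*.scp inside it
train_dl = get_egs_dataloader(egs_dir_or_scp='exp/chain/merged_egs',
                              egs_left_context=29,
                              egs_right_context=29,
                              shuffle=True)

# validation: pass a single scp file and it is used as-is
valid_dl = get_egs_dataloader(egs_dir_or_scp='exp/chain/egs/valid_diagnostic.scp',
                              egs_left_context=29,
                              egs_right_context=29,
                              shuffle=False)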

egs/aishell/s10/chain/options.py

Lines changed: 22 additions & 2 deletions
@@ -54,6 +54,11 @@ def _set_training_args(parser):
                         help='cegs dir containing comibined cegs.*.scp',
                         type=str)
 
+    parser.add_argument('--train.valid-cegs-scp',
+                        dest='valid_cegs_scp',
+                        help='validation cegs scp',
+                        type=str)
+
     parser.add_argument('--train.den-fst',
                         dest='den_fst_filename',
                         help='denominator fst filename',
@@ -84,9 +89,20 @@ def _set_training_args(parser):
                         help='l2 regularize',
                         type=float)
 
+    parser.add_argument('--train.xent-regularize',
+                        dest='xent_regularize',
+                        help='xent regularize',
+                        type=float)
+
+    parser.add_argument('--train.leaky-hmm-coefficient',
+                        dest='leaky_hmm_coefficient',
+                        help='leaky hmm coefficient',
+                        type=float)
+
 
 def _check_training_args(args):
     assert os.path.isdir(args.cegs_dir)
+    assert os.path.isfile(args.valid_cegs_scp)
 
     assert os.path.isfile(args.den_fst_filename)
 
@@ -95,7 +111,9 @@ def _check_training_args(args):
 
     assert args.num_epochs > 0
     assert args.learning_rate > 0
-    assert args.l2_regularize > 0
+    assert args.l2_regularize >= 0
+    assert args.xent_regularize >= 0
+    assert args.leaky_hmm_coefficient >= 0
 
     if args.checkpoint:
         assert os.path.exists(args.checkpoint)
@@ -140,7 +158,9 @@ def _check_args(args):
 
     args.kernel_size_list = [int(k) for k in args.kernel_size_list.split(', ')]
 
-    args.subsampling_factor_list = [int(k) for k in args.subsampling_factor_list.split(', ')]
+    args.subsampling_factor_list = [
+        int(k) for k in args.subsampling_factor_list.split(', ')
+    ]
 
     assert len(args.kernel_size_list) == len(args.subsampling_factor_list)
 

egs/aishell/s10/chain/tdnnf_layer.py

Lines changed: 5 additions & 5 deletions
@@ -175,7 +175,7 @@ def __init__(self,
         # since we want to use `stride`
         self.affine = nn.Conv1d(in_channels=bottleneck_dim,
                                 out_channels=dim,
-                                kernel_size=kernel_size,
+                                kernel_size=1,
                                 stride=subsampling_factor)
 
         # batchnorm requires [N, C, T]
@@ -204,7 +204,7 @@ def forward(self, x):
 
         # TODO(fangjun): implement GeneralDropoutComponent in PyTorch
 
-        if self.linear.kernel_size == 2:
+        if self.linear.kernel_size == 3:
             x = self.bypass_scale * input_x[:, :, self.s:-self.s:self.s] + x
         else:
             x = self.bypass_scale * input_x[:, :, ::self.s] + x
@@ -248,7 +248,7 @@ def compute_loss(M):
 
     model = FactorizedTDNN(dim=1024,
                            bottleneck_dim=128,
-                           kernel_size=2,
+                           kernel_size=3,
                            subsampling_factor=1)
     loss = []
     model.constrain_orthonormal()
@@ -279,10 +279,10 @@ def _test_factorized_tdnn():
     y = model(x)
     assert y.size(2) == T
 
-    # case 1: kernel_size == 2, subsampling_factor == 1
+    # case 1: kernel_size == 3, subsampling_factor == 1
     model = FactorizedTDNN(dim=C,
                            bottleneck_dim=2,
-                           kernel_size=2,
+                           kernel_size=3,
                            subsampling_factor=1)
     y = model(x)
     assert y.size(2) == T - 2
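
The '[-1, 0, 1]' in the commit message refers to the time offsets covered by the factorized layer's linear convolution: with kernel_size=3 and no padding, each output frame sees the previous, current, and next input frames, the output loses one frame at each edge, and the bypass connection must be cropped to match. A minimal sketch of that shape bookkeeping, assuming subsampling_factor == 1 and hypothetical dimensions:

import torch
import torch.nn as nn

N, C, T = 1, 8, 10
x = torch.randn(N, C, T)  # Conv1d expects [N, C, T]

# kernel_size=3 covers offsets [-1, 0, 1]; without padding one frame
# is lost at each edge, matching `assert y.size(2) == T - 2` above
linear = nn.Conv1d(in_channels=C, out_channels=C, kernel_size=3)
y = linear(x)
assert y.size(2) == T - 2

# the bypass connection is cropped identically to stay aligned; with
# s == 1 this is what input_x[:, :, self.s:-self.s:self.s] computes
bypass = x[:, :, 1:-1]
assert bypass.shape == y.shape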

egs/aishell/s10/chain/train.py

Lines changed: 94 additions & 12 deletions
@@ -15,7 +15,6 @@
 import torch
 import torch.optim as optim
 from torch.nn.utils import clip_grad_value_
-from torch.optim.lr_scheduler import MultiStepLR
 from torch.utils.tensorboard import SummaryWriter
 
 import kaldi
@@ -32,8 +31,63 @@
 from options import get_args
 
 
-def train_one_epoch(dataloader, model, device, optimizer, criterion,
-                    current_epoch, opts, den_graph, tf_writer):
+def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
+    total_objf = 0.
+    total_weight = 0.
+    total_frames = 0.  # for display only
+
+    model.eval()
+
+    for batch_idx, batch in enumerate(dataloader):
+        key_list, feature_list, supervision_list = batch
+
+        assert len(key_list) == len(feature_list) == len(supervision_list)
+        batch_size = len(key_list)
+
+        for n in range(batch_size):
+            feats = feature_list[n]
+            assert feats.ndim == 3
+
+            # at this point, feats is [N, T, C]
+            feats = feats.to(device)
+
+            with torch.no_grad():
+                nnet_output, xent_output = model(feats)
+
+            # at this point, nnet_output is: [N, T, C]
+            # refer to kaldi/src/chain/chain-training.h
+            # the output should be organized as
+            # [all sequences for frame 0]
+            # [all sequences for frame 1]
+            # [etc.]
+            nnet_output = nnet_output.permute(1, 0, 2)
+            # at this point, nnet_output is: [T, N, C]
+            nnet_output = nnet_output.contiguous().view(-1,
+                                                        nnet_output.shape[-1])
+
+            # at this point, xent_output is: [N, T, C]
+            xent_output = xent_output.permute(1, 0, 2)
+            # at this point, xent_output is: [T, N, C]
+            xent_output = xent_output.contiguous().view(-1,
+                                                        xent_output.shape[-1])
+
+            objf_l2_term_weight = criterion(opts, den_graph,
+                                            supervision_list[n], nnet_output,
+                                            xent_output)
+            objf = objf_l2_term_weight[0]
+
+            objf_l2_term_weight = objf_l2_term_weight.cpu()
+
+            total_objf += objf_l2_term_weight[0].item()
+            total_weight += objf_l2_term_weight[2].item()
+
+            num_frames = nnet_output.shape[0]
+            total_frames += num_frames
+
+    return total_objf, total_weight, total_frames
+
+
+def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
+                    criterion, current_epoch, opts, den_graph, tf_writer):
     model.train()
 
     total_objf = 0.
@@ -75,8 +129,8 @@ def train_one_epoch(dataloader, model, device, optimizer, criterion,
         optimizer.zero_grad()
         objf.backward()
 
-        # TODO(fangjun): how to choose this value or do we need this ?
         clip_grad_value_(model.parameters(), 5.0)
+
         optimizer.step()
 
         objf_l2_term_weight = objf_l2_term_weight.detach().cpu()
@@ -101,6 +155,25 @@ def train_one_epoch(dataloader, model, device, optimizer, criterion,
                     objf_l2_term_weight[0].item() /
                     objf_l2_term_weight[2].item(), num_frames, current_epoch))
 
+        if batch_idx % 500 == 0:
+            total_valid_objf, total_valid_weight, total_valid_frames = get_validation_objf(
+                dataloader=valid_dataloader,
+                model=model,
+                device=device,
+                criterion=criterion,
+                opts=opts,
+                den_graph=den_graph)
+
+            model.train()
+
+            logging.info(
+                'Validation average objf: {:.6f} over {} frames'.format(
+                    total_valid_objf / total_valid_weight, total_valid_frames))
+
+            tf_writer.add_scalar('train/global_valid_average_objf',
+                                 total_valid_objf / total_valid_weight,
+                                 batch_idx + current_epoch * len(dataloader))
+
         if batch_idx % 100 == 0:
             tf_writer.add_scalar('train/global_average_objf',
                                  total_objf / total_weight,
@@ -145,10 +218,10 @@ def main():
 
     den_fst = fst.StdVectorFst.Read(args.den_fst_filename)
 
-    # TODO(fangjun): pass these options from commandline
     opts = chain.ChainTrainingOptions()
-    opts.l2_regularize = 5e-4
-    opts.leaky_hmm_coefficient = 0.1
+    opts.l2_regularize = args.l2_regularize
+    opts.xent_regularize = args.xent_regularize
+    opts.leaky_hmm_coefficient = args.leaky_hmm_coefficient
 
     den_graph = chain.DenominatorGraph(fst=den_fst, num_pdfs=args.output_dim)
 
@@ -179,16 +252,21 @@ def main():
 
     model.to(device)
 
-    dataloader = get_egs_dataloader(egs_dir=args.cegs_dir,
+    dataloader = get_egs_dataloader(egs_dir_or_scp=args.cegs_dir,
                                     egs_left_context=args.egs_left_context,
                                     egs_right_context=args.egs_right_context,
                                     shuffle=True)
 
+    valid_dataloader = get_egs_dataloader(
+        egs_dir_or_scp=args.valid_cegs_scp,
+        egs_left_context=args.egs_left_context,
+        egs_right_context=args.egs_right_context,
+        shuffle=False)
+
     optimizer = optim.Adam(model.parameters(),
                            lr=learning_rate,
-                           weight_decay=args.l2_regularize)
+                           weight_decay=5e-4)
 
-    scheduler = MultiStepLR(optimizer, milestones=[1, 2, 3, 4, 5], gamma=0.5)
     criterion = KaldiChainObjfFunction.apply
 
     tf_writer = SummaryWriter(log_dir='{}/tensorboard'.format(args.dir))
@@ -198,12 +276,17 @@ def main():
     best_epoch_info_filename = os.path.join(args.dir, 'best-epoch-info')
     try:
         for epoch in range(start_epoch, args.num_epochs):
-            learning_rate = scheduler.get_lr()[0]
+            learning_rate = 1e-3 * pow(0.4, epoch)
+            for param_group in optimizer.param_groups:
+                param_group['lr'] = learning_rate
+
             logging.info('epoch {}, learning rate {}'.format(
                 epoch, learning_rate))
+
             tf_writer.add_scalar('learning_rate', learning_rate, epoch)
 
             objf = train_one_epoch(dataloader=dataloader,
+                                   valid_dataloader=valid_dataloader,
                                    model=model,
                                    device=device,
                                    optimizer=optimizer,
@@ -212,7 +295,6 @@ def main():
                                    opts=opts,
                                    den_graph=den_graph,
                                    tf_writer=tf_writer)
-            scheduler.step()
 
             if best_objf is None:
                 best_objf = objf
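
With MultiStepLR gone, the learning rate is set by hand at the start of each epoch as 1e-3 * 0.4 ** epoch, written directly into every optimizer param group. A standalone sketch of the schedule with a toy model (the constants are the ones hard-coded above):

import torch
import torch.optim as optim

model = torch.nn.Linear(4, 4)  # stand-in for the real network
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

for epoch in range(6):
    # exponential decay: 1e-3, 4e-4, 1.6e-4, ...
    learning_rate = 1e-3 * pow(0.4, epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate
    print('epoch {}, learning rate {}'.format(epoch, learning_rate))

The same curve could also be produced by torch.optim.lr_scheduler.ExponentialLR with gamma=0.4; assigning param_group['lr'] directly is simply the explicit form, with no scheduler object to carry around.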

egs/aishell/s10/local/run_chain.sh

Lines changed: 6 additions & 3 deletions
@@ -33,7 +33,7 @@ lr=1e-3
 hidden_dim=1024
 bottleneck_dim=128
 prefinal_bottleneck_dim=256
-kernel_size_list="2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2" # comma separated list
+kernel_size_list="3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3" # comma separated list
 subsampling_factor_list="1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1" # comma separated list
 
 log_level=info # valid values: debug, info, warning
@@ -179,9 +179,12 @@ if [[ $stage -le 8 ]]; then
     --train.den-fst exp/chain/den.fst \
     --train.egs-left-context $egs_left_context \
     --train.egs-right-context $egs_right_context \
-    --train.l2-regularize 5e-4 \
+    --train.l2-regularize 5e-5 \
+    --train.leaky-hmm-coefficient 0.1 \
     --train.lr $lr \
-    --train.num-epochs $num_epochs
+    --train.num-epochs $num_epochs \
+    --train.valid-cegs-scp exp/chain/egs/valid_diagnostic.scp \
+    --train.xent-regularize 0.1
 fi
 
 if [[ $stage -le 9 ]]; then
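
One detail worth flagging in the lists above: options.py splits them with split(', '), so the delimiter is a comma followed by exactly one space, as written here. A quick sketch of the parse, using the values from this script:

kernel_size_list = "3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3"
subsampling_factor_list = "1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1"

kernels = [int(k) for k in kernel_size_list.split(', ')]
subsampling = [int(k) for k in subsampling_factor_list.split(', ')]

assert len(kernels) == len(subsampling) == 12
# layer 4 subsamples by 3 and uses kernel_size 1 (no extra context)
assert kernels[3] == 1 and subsampling[3] == 3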
