1515import torch
1616import torch .optim as optim
1717from torch .nn .utils import clip_grad_value_
18- from torch .optim .lr_scheduler import MultiStepLR
1918from torch .utils .tensorboard import SummaryWriter
2019
2120import kaldi
3231from options import get_args
3332
3433
35- def train_one_epoch (dataloader , model , device , optimizer , criterion ,
36- current_epoch , opts , den_graph , tf_writer ):
def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
    """Compute the accumulated chain objective over a validation set.

    Puts the model into eval mode and runs a forward pass (no gradients)
    over every utterance in ``dataloader``, accumulating the objective
    value returned by ``criterion``.  NOTE: the model is left in eval
    mode; the caller is responsible for switching back to train mode.

    Args:
        dataloader: iterable yielding ``(key_list, feature_list,
            supervision_list)`` batches; each feature tensor is expected
            to be ``[N, T, C]``.
        model: network whose forward returns
            ``(nnet_output, xent_output)``, both ``[N, T, C]``.
        device: ``torch.device`` the features are moved to.
        criterion: chain objective function, called as
            ``criterion(opts, den_graph, supervision, nnet_output,
            xent_output)`` and returning a tensor laid out as
            ``[objf, l2_term, weight]``.
        opts: chain training options, passed through to ``criterion``.
        den_graph: denominator graph, passed through to ``criterion``.

    Returns:
        A tuple ``(total_objf, total_weight, total_frames)``; divide the
        first by the second to get the average objf.  ``total_frames``
        is for display only.
    """
    total_objf = 0.
    total_weight = 0.
    total_frames = 0.  # for display only

    model.eval()

    for batch_idx, batch in enumerate(dataloader):
        key_list, feature_list, supervision_list = batch

        assert len(key_list) == len(feature_list) == len(supervision_list)
        batch_size = len(key_list)

        for n in range(batch_size):
            feats = feature_list[n]
            assert feats.ndim == 3

            # at this point, feats is [N, T, C]
            feats = feats.to(device)

            # validation only: no gradient tracking for the forward pass
            with torch.no_grad():
                nnet_output, xent_output = model(feats)

            # refer to kaldi/src/chain/chain-training.h:
            # the output should be organized as
            #   [all sequences for frame 0]
            #   [all sequences for frame 1]
            #   [etc.]
            # so reorder [N, T, C] -> [T, N, C] -> [T*N, C].
            nnet_output = nnet_output.permute(1, 0, 2)
            nnet_output = nnet_output.contiguous().view(-1,
                                                        nnet_output.shape[-1])

            xent_output = xent_output.permute(1, 0, 2)
            xent_output = xent_output.contiguous().view(-1,
                                                        xent_output.shape[-1])

            objf_l2_term_weight = criterion(opts, den_graph,
                                            supervision_list[n], nnet_output,
                                            xent_output)

            # [objf, l2_term, weight]; move to CPU once before reading
            # the scalar entries.  (The per-utterance `objf` variable the
            # training loop keeps for backward() is not needed here.)
            objf_l2_term_weight = objf_l2_term_weight.cpu()

            total_objf += objf_l2_term_weight[0].item()
            total_weight += objf_l2_term_weight[2].item()

            # nnet_output has been flattened to [T*N, C], so shape[0] is
            # the number of frames in this utterance.
            total_frames += nnet_output.shape[0]

    return total_objf, total_weight, total_frames
87+
88+
89+ def train_one_epoch (dataloader , valid_dataloader , model , device , optimizer ,
90+ criterion , current_epoch , opts , den_graph , tf_writer ):
3791 model .train ()
3892
3993 total_objf = 0.
@@ -75,8 +129,8 @@ def train_one_epoch(dataloader, model, device, optimizer, criterion,
75129 optimizer .zero_grad ()
76130 objf .backward ()
77131
78- # TODO(fangjun): how to choose this value or do we need this ?
79132 clip_grad_value_ (model .parameters (), 5.0 )
133+
80134 optimizer .step ()
81135
82136 objf_l2_term_weight = objf_l2_term_weight .detach ().cpu ()
@@ -101,6 +155,25 @@ def train_one_epoch(dataloader, model, device, optimizer, criterion,
101155 objf_l2_term_weight [0 ].item () /
102156 objf_l2_term_weight [2 ].item (), num_frames , current_epoch ))
103157
158+ if batch_idx % 500 == 0 :
159+ total_valid_objf , total_valid_weight , total_valid_frames = get_validation_objf (
160+ dataloader = valid_dataloader ,
161+ model = model ,
162+ device = device ,
163+ criterion = criterion ,
164+ opts = opts ,
165+ den_graph = den_graph )
166+
167+ model .train ()
168+
169+ logging .info (
170+ 'Validation average objf: {:.6f} over {} frames' .format (
171+ total_valid_objf / total_valid_weight , total_valid_frames ))
172+
173+ tf_writer .add_scalar ('train/global_valid_average_objf' ,
174+ total_valid_objf / total_valid_weight ,
175+ batch_idx + current_epoch * len (dataloader ))
176+
104177 if batch_idx % 100 == 0 :
105178 tf_writer .add_scalar ('train/global_average_objf' ,
106179 total_objf / total_weight ,
@@ -145,10 +218,10 @@ def main():
145218
146219 den_fst = fst .StdVectorFst .Read (args .den_fst_filename )
147220
148- # TODO(fangjun): pass these options from commandline
149221 opts = chain .ChainTrainingOptions ()
150- opts .l2_regularize = 5e-4
151- opts .leaky_hmm_coefficient = 0.1
222+ opts .l2_regularize = args .l2_regularize
223+ opts .xent_regularize = args .xent_regularize
224+ opts .leaky_hmm_coefficient = args .leaky_hmm_coefficient
152225
153226 den_graph = chain .DenominatorGraph (fst = den_fst , num_pdfs = args .output_dim )
154227
@@ -179,16 +252,21 @@ def main():
179252
180253 model .to (device )
181254
182- dataloader = get_egs_dataloader (egs_dir = args .cegs_dir ,
255+ dataloader = get_egs_dataloader (egs_dir_or_scp = args .cegs_dir ,
183256 egs_left_context = args .egs_left_context ,
184257 egs_right_context = args .egs_right_context ,
185258 shuffle = True )
186259
260+ valid_dataloader = get_egs_dataloader (
261+ egs_dir_or_scp = args .valid_cegs_scp ,
262+ egs_left_context = args .egs_left_context ,
263+ egs_right_context = args .egs_right_context ,
264+ shuffle = False )
265+
187266 optimizer = optim .Adam (model .parameters (),
188267 lr = learning_rate ,
189- weight_decay = args . l2_regularize )
268+ weight_decay = 5e-4 )
190269
191- scheduler = MultiStepLR (optimizer , milestones = [1 , 2 , 3 , 4 , 5 ], gamma = 0.5 )
192270 criterion = KaldiChainObjfFunction .apply
193271
194272 tf_writer = SummaryWriter (log_dir = '{}/tensorboard' .format (args .dir ))
@@ -198,12 +276,17 @@ def main():
198276 best_epoch_info_filename = os .path .join (args .dir , 'best-epoch-info' )
199277 try :
200278 for epoch in range (start_epoch , args .num_epochs ):
201- learning_rate = scheduler .get_lr ()[0 ]
279+ learning_rate = 1e-3 * pow (0.4 , epoch )
280+ for param_group in optimizer .param_groups :
281+ param_group ['lr' ] = learning_rate
282+
202283 logging .info ('epoch {}, learning rate {}' .format (
203284 epoch , learning_rate ))
285+
204286 tf_writer .add_scalar ('learning_rate' , learning_rate , epoch )
205287
206288 objf = train_one_epoch (dataloader = dataloader ,
289+ valid_dataloader = valid_dataloader ,
207290 model = model ,
208291 device = device ,
209292 optimizer = optimizer ,
@@ -212,7 +295,6 @@ def main():
212295 opts = opts ,
213296 den_graph = den_graph ,
214297 tf_writer = tf_writer )
215- scheduler .step ()
216298
217299 if best_objf is None :
218300 best_objf = objf
0 commit comments