@@ -17,6 +17,7 @@
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_value_
+from torch.utils.data.distributed import DistributedSampler
 from torch.utils.tensorboard import SummaryWriter

 import kaldi
@@ -30,10 +31,11 @@
 from common import setup_logger
 from device_utils import allocate_gpu_devices
 from egs_dataset import get_egs_dataloader
+from libs.nnet3.train.dropout_schedule import _get_dropout_proportions
 from model import get_chain_model
 from options import get_args

-def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None):
+def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None, dropout=0.):
     total_objf = 0.
     total_weight = 0.
     total_frames = 0.  # for display only
@@ -48,7 +50,7 @@ def get_objf(batch, model, device, criterion, opts, den_graph, training, optimiz
     # at this point, feats is [N, T, C]
     feats = feats.to(device)
     if training:
-        nnet_output, xent_output = model(feats)
+        nnet_output, xent_output = model(feats, dropout=dropout)
     else:
         with torch.no_grad():
             nnet_output, xent_output = model(feats)
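Passing the proportion through the forward call means dropout is applied with whatever value the training loop computed for that batch, rather than being fixed module state. A minimal sketch of a forward() with that shape, using a hypothetical two-head toy network (the actual chain model is defined in model.py and is more involved):

```python
import torch
import torch.nn.functional as F

class TinyChainSketch(torch.nn.Module):
    # Hypothetical stand-in for the real chain model: it only illustrates
    # how a per-call `dropout` argument can drive F.dropout.
    def __init__(self, feat_dim=40, hidden_dim=64, output_dim=100):
        super().__init__()
        self.affine = torch.nn.Linear(feat_dim, hidden_dim)
        self.chain_head = torch.nn.Linear(hidden_dim, output_dim)
        self.xent_head = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x, dropout=0.):
        x = torch.relu(self.affine(x))
        # the proportion comes from the caller, so it can change every batch
        x = F.dropout(x, p=dropout, training=self.training)
        return self.chain_head(x), self.xent_head(x)
```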
@@ -106,17 +108,20 @@ def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
     return total_objf, total_weight, total_frames


-def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
-                    criterion, current_epoch, opts, den_graph, tf_writer, rank):
+def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, criterion,
+                    current_epoch, num_epochs, opts, den_graph, tf_writer, rank, dropout_schedule):
     total_objf = 0.
     total_weight = 0.
     total_frames = 0.  # for display only

     model.train()
-
     for batch_idx, batch in enumerate(dataloader):
+        data_fraction = (batch_idx + current_epoch *
+                         len(dataloader)) / (len(dataloader) * num_epochs)
+        _, dropout = _get_dropout_proportions(
+            dropout_schedule, data_fraction)[0]
         curr_batch_objf, curr_batch_weight, curr_batch_frames = get_objf(
-            batch, model, device, criterion, opts, den_graph, True, optimizer)
+            batch, model, device, criterion, opts, den_graph, True, optimizer, dropout=dropout)

         total_objf += curr_batch_objf
         total_weight += curr_batch_weight
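The two new statements turn overall training progress into a dropout proportion: data_fraction is the fraction of all minibatches (across every epoch) processed so far, and the schedule maps that fraction to a proportion. The sketch below uses a hypothetical toy_dropout_at standing in for Kaldi's _get_dropout_proportions (which additionally supports per-component schedules and explicit value@fraction anchors) to show the intended piecewise-linear behaviour for a schedule string such as '0,0.2,0':

```python
def toy_dropout_at(schedule, data_fraction):
    """Hypothetical helper: linearly interpolate a comma-separated schedule
    such as '0,0.2,0' over evenly spaced points in [0, 1]."""
    values = [float(v) for v in schedule.split(',')]
    if len(values) == 1:
        return values[0]
    pos = data_fraction * (len(values) - 1)   # position along the polyline
    i = min(int(pos), len(values) - 2)        # index of the left anchor
    t = pos - i                               # interpolation weight in [0, 1]
    return values[i] + t * (values[i + 1] - values[i])

# Example with made-up sizes: 500 batches per epoch, 6 epochs, currently at
# epoch 2, batch 250 -> data_fraction = (250 + 2 * 500) / (500 * 6) ~= 0.42
print(toy_dropout_at('0,0.2,0', (250 + 2 * 500) / (500 * 6)))  # ~0.17
```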
@@ -159,6 +164,11 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
                 'train/current_batch_average_objf',
                 curr_batch_objf / curr_batch_weight,
                 batch_idx + current_epoch * len(dataloader))
+
+            tf_writer.add_scalar(
+                'train/current_dropout',
+                dropout,
+                batch_idx + current_epoch * len(dataloader))

             state_dict = model.state_dict()
             for key, value in state_dict.items():
@@ -205,10 +215,10 @@ def main():
 def process_job(learning_rate, device_id=None, local_rank=None):
     args = get_args()
     if local_rank != None:
-        setup_logger('{}/train/logs/log-train-rank-{}'.format(args.dir, local_rank),
+        setup_logger('{}/logs/log-train-rank-{}'.format(args.dir, local_rank),
                      args.log_level)
     else:
-        setup_logger('{}/train/logs/log-train-single-GPU'.format(args.dir), args.log_level)
+        setup_logger('{}/logs/log-train-single-GPU'.format(args.dir), args.log_level)

     logging.info(' '.join(sys.argv))

@@ -249,7 +259,6 @@ def process_job(learning_rate, device_id=None, local_rank=None):
     opts.leaky_hmm_coefficient = args.leaky_hmm_coefficient

     den_graph = chain.DenominatorGraph(fst=den_fst, num_pdfs=args.output_dim)
-

     model = get_chain_model(
         feat_dim=args.feat_dim,
@@ -325,6 +334,9 @@ def process_job(learning_rate, device_id=None, local_rank=None):

         if tf_writer:
             tf_writer.add_scalar('learning_rate', curr_learning_rate, epoch)
+
+        if dataloader.sampler and isinstance(dataloader.sampler, DistributedSampler):
+            dataloader.sampler.set_epoch(epoch)

         objf = train_one_epoch(dataloader=dataloader,
                                valid_dataloader=valid_dataloader,
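Calling set_epoch before each epoch matters under DDP: DistributedSampler seeds its shuffle with the epoch number, so without it every rank would replay the same ordering every epoch. A small self-contained illustration (num_replicas and rank are passed explicitly here only so the snippet runs without an initialized process group):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reseeds the shuffle -> a new order each epoch
    print([batch[0].tolist() for batch in loader])
```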
@@ -333,10 +345,12 @@ def process_job(learning_rate, device_id=None, local_rank=None):
                                optimizer=optimizer,
                                criterion=criterion,
                                current_epoch=epoch,
+                               num_epochs=num_epochs,
                                opts=opts,
                                den_graph=den_graph,
                                tf_writer=tf_writer,
-                               rank=local_rank)
+                               rank=local_rank,
+                               dropout_schedule=args.dropout_schedule)

         if best_objf is None:
             best_objf = objf