@@ -33,114 +33,94 @@
 from model import get_chain_model
 from options import get_args
 
-def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
+def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None):
     total_objf = 0.
     total_weight = 0.
     total_frames = 0.  # for display only
+
+    key_list, feature_list, supervision_list = batch
+    assert len(key_list) == len(feature_list) == len(supervision_list)
+    batch_size = len(key_list)
+    for n in range(batch_size):
+        feats = feature_list[n]
+        assert feats.ndim == 3
+
+        # at this point, feats is [N, T, C]
+        feats = feats.to(device)
+        if training:
+            nnet_output, xent_output = model(feats)
+        else:
+            with torch.no_grad():
+                nnet_output, xent_output = model(feats)
 
-    model.eval()
+        # at this point, nnet_output is: [N, T, C]
+        # refer to kaldi/src/chain/chain-training.h
+        # the output should be organized as
+        # [all sequences for frame 0]
+        # [all sequences for frame 1]
+        # [etc.]
+        nnet_output = nnet_output.permute(1, 0, 2)
+        # at this point, nnet_output is: [T, N, C]
+        nnet_output = nnet_output.contiguous().view(-1,
+                                                    nnet_output.shape[-1])
+
+        # at this point, xent_output is: [N, T, C]
+        xent_output = xent_output.permute(1, 0, 2)
+        # at this point, xent_output is: [T, N, C]
+        xent_output = xent_output.contiguous().view(-1,
+                                                    xent_output.shape[-1])
+        objf_l2_term_weight = criterion(opts, den_graph,
+                                        supervision_list[n], nnet_output,
+                                        xent_output)
+        objf = objf_l2_term_weight[0]
+        if training:
+            optimizer.zero_grad()
+            objf.backward()
+            clip_grad_value_(model.parameters(), 5.0)
+            optimizer.step()
 
-    for batch_idx, batch in enumerate(dataloader):
-        key_list, feature_list, supervision_list = batch
+        objf_l2_term_weight = objf_l2_term_weight.detach().cpu()
 
-        assert len(key_list) == len(feature_list) == len(supervision_list)
-        batch_size = len(key_list)
+        total_objf += objf_l2_term_weight[0].item()
+        total_weight += objf_l2_term_weight[2].item()
+        num_frames = nnet_output.shape[0]
+        total_frames += num_frames
 
-        for n in range(batch_size):
-            feats = feature_list[n]
-            assert feats.ndim == 3
+    return total_objf, total_weight, total_frames
 
-            # at this point, feats is [N, T, C]
-            feats = feats.to(device)
 
-            with torch.no_grad():
-                nnet_output, xent_output = model(feats)
+def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
+    total_objf = 0.
+    total_weight = 0.
+    total_frames = 0.  # for display only
+
+    model.eval()
 
-            # at this point, nnet_output is: [N, T, C]
-            # refer to kaldi/src/chain/chain-training.h
-            # the output should be organized as
-            # [all sequences for frame 0]
-            # [all sequences for frame 1]
-            # [etc.]
-            nnet_output = nnet_output.permute(1, 0, 2)
-            # at this point, nnet_output is: [T, N, C]
-            nnet_output = nnet_output.contiguous().view(-1,
-                                                        nnet_output.shape[-1])
-
-            # at this point, xent_output is: [N, T, C]
-            xent_output = xent_output.permute(1, 0, 2)
-            # at this point, xent_output is: [T, N, C]
-            xent_output = xent_output.contiguous().view(-1,
-                                                        xent_output.shape[-1])
-            objf_l2_term_weight = criterion(opts, den_graph,
                                            supervision_list[n], nnet_output,
-                                            xent_output)
-            objf = objf_l2_term_weight[0]
-
-            objf_l2_term_weight = objf_l2_term_weight.cpu()
-
-            total_objf += objf_l2_term_weight[0].item()
-            total_weight += objf_l2_term_weight[2].item()
-
-            num_frames = nnet_output.shape[0]
-            total_frames += num_frames
+    for batch_idx, batch in enumerate(dataloader):
+        objf, weight, frames = get_objf(
+            batch, model, device, criterion, opts, den_graph, False)
+        total_objf += objf
+        total_weight += weight
+        total_frames += frames
 
     return total_objf, total_weight, total_frames
 
 
 def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
                     criterion, current_epoch, opts, den_graph, tf_writer, rank):
-    model.train()
-
     total_objf = 0.
     total_weight = 0.
     total_frames = 0.  # for display only
 
-    for batch_idx, batch in enumerate(dataloader):
-        key_list, feature_list, supervision_list = batch
-        assert len(key_list) == len(feature_list) == len(supervision_list)
-        batch_size = len(key_list)
-        for n in range(batch_size):
-            feats = feature_list[n]
-            assert feats.ndim == 3
-
-            # at this point, feats is [N, T, C]
-            feats = feats.to(device)
-            nnet_output, xent_output = model(feats)
-
-            # at this point, nnet_output is: [N, T, C]
-            # refer to kaldi/src/chain/chain-training.h
-            # the output should be organized as
-            # [all sequences for frame 0]
-            # [all sequences for frame 1]
-            # [etc.]
-            nnet_output = nnet_output.permute(1, 0, 2)
-            # at this point, nnet_output is: [T, N, C]
-            nnet_output = nnet_output.contiguous().view(-1,
-                                                        nnet_output.shape[-1])
-
-            # at this point, xent_output is: [N, T, C]
-            xent_output = xent_output.permute(1, 0, 2)
-            # at this point, xent_output is: [T, N, C]
-            xent_output = xent_output.contiguous().view(-1,
-                                                        xent_output.shape[-1])
-            objf_l2_term_weight = criterion(opts, den_graph,
-                                            supervision_list[n], nnet_output,
-                                            xent_output)
-            objf = objf_l2_term_weight[0]
-            optimizer.zero_grad()
-            objf.backward()
-
-            clip_grad_value_(model.parameters(), 5.0)
-
-            optimizer.step()
+    model.train()
 
-            objf_l2_term_weight = objf_l2_term_weight.detach().cpu()
+    for batch_idx, batch in enumerate(dataloader):
+        curr_batch_objf, curr_batch_weight, curr_batch_frames = get_objf(
+            batch, model, device, criterion, opts, den_graph, True, optimizer)
 
-            total_objf += objf_l2_term_weight[0].item()
-            total_weight += objf_l2_term_weight[2].item()
-            num_frames = nnet_output.shape[0]
-            total_frames += num_frames
+        total_objf += curr_batch_objf
+        total_weight += curr_batch_weight
+        total_frames += curr_batch_frames
 
         if batch_idx % 100 == 0:
             logging.info(
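
Note: the permute-then-flatten step above is the interface contract with Kaldi's chain criterion. A small self-contained sketch (not part of the commit; shapes are invented) shows why row t*N + n of the flattened output holds frame t of sequence n:

import torch

N, T, C = 2, 3, 4  # sequences, frames per sequence, network outputs
nnet_output = torch.arange(N * T * C, dtype=torch.float32).view(N, T, C)

# [N, T, C] -> [T, N, C]: frame t of every sequence becomes adjacent,
# the layout kaldi/src/chain/chain-training.h asks for
reordered = nnet_output.permute(1, 0, 2)

# permute() only changes strides, hence the contiguous() before view();
# [T, N, C] -> [T * N, C]
flat = reordered.contiguous().view(-1, reordered.shape[-1])

assert torch.equal(flat[0], nnet_output[0, 0])  # all sequences for frame 0 ...
assert torch.equal(flat[1], nnet_output[1, 0])
assert torch.equal(flat[2], nnet_output[0, 1])  # ... then all sequences for frame 1
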
@@ -150,8 +130,8 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
                 device.index, batch_idx, len(dataloader),
                 float(batch_idx) / len(dataloader) * 100,
                 total_objf / total_weight, total_frames,
-                objf_l2_term_weight[0].item() /
-                objf_l2_term_weight[2].item(), num_frames, current_epoch))
+                curr_batch_objf / curr_batch_weight,
+                curr_batch_frames, current_epoch))
 
         if valid_dataloader and batch_idx % 1000 == 0:
             total_valid_objf, total_valid_weight, total_valid_frames = get_validation_objf(
@@ -161,7 +141,6 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
                 criterion=criterion,
                 opts=opts,
                 den_graph=den_graph)
-
             model.train()
 
             logging.info(
@@ -178,7 +157,7 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
                 batch_idx + current_epoch * len(dataloader))
             tf_writer.add_scalar(
                 'train/current_batch_average_objf',
-                objf_l2_term_weight[0].item() / objf_l2_term_weight[2].item(),
+                curr_batch_objf / curr_batch_weight,
                 batch_idx + current_epoch * len(dataloader))
 
         state_dict = model.state_dict()
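
For context, tf_writer is a TensorBoard summary writer created by the caller of train_one_epoch. A minimal sketch, assuming the standard torch.utils.tensorboard API (the log directory, value, and step below are invented):

from torch.utils.tensorboard import SummaryWriter

tf_writer = SummaryWriter(log_dir='exp/tensorboard')  # directory is an assumption
# same tag as in the diff; a dummy value and step stand in for the real ones
tf_writer.add_scalar('train/current_batch_average_objf', -0.05, 0)
tf_writer.close()
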
@@ -206,20 +185,24 @@ def main():
         if args.multiple_machine:
             # Suppose we have submitted multiple jobs with SGE (Sun Grid Engine)
             local_rank = int(os.environ['SGE_TASK_ID']) - 1
-            process_job(learning_rate, local_rank)
+            process_job(learning_rate, local_rank=local_rank)
         else:
             proc = []
+            if args.device_ids != None:
+                assert len(args.device_ids) >= args.world_size
             for i in range(args.world_size):
-                p = Process(target=process_job, args=(learning_rate, i))
+                device_id = None if args.device_ids == None else args.device_ids[i]
+                p = Process(target=process_job, args=(learning_rate, device_id, i))
                 proc.append(p)
                 p.start()
             for p in proc:
                 p.join()
     else:
-        process_job(args.learning_rate)
+        device_id = None if args.device_ids == None else args.device_ids[0]
+        process_job(args.learning_rate, device_id)
 
 
-def process_job(learning_rate, local_rank=None):
+def process_job(learning_rate, device_id=None, local_rank=None):
     args = get_args()
     if local_rank != None:
         setup_logger('{}/train/logs/log-train-rank-{}'.format(args.dir, local_rank),
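
The new dispatch pins rank i to args.device_ids[i] when --device-ids is given, and otherwise leaves GPU selection to process_job. A stripped-down, runnable sketch of that mapping (argument values invented; the print stands in for the real training job):

from multiprocessing import Process

def process_job(learning_rate, device_id=None, local_rank=None):
    # placeholder body; the real process_job builds the model and trains
    print('rank', local_rank, 'device', device_id, 'lr', learning_rate)

def dispatch(learning_rate, world_size, device_ids=None):
    # device_ids pins rank i to device_ids[i]; None keeps auto-allocation
    if device_ids is not None:
        assert len(device_ids) >= world_size
    procs = []
    for i in range(world_size):
        device_id = None if device_ids is None else device_ids[i]
        p = Process(target=process_job, args=(learning_rate, device_id, i))
        procs.append(p)
        p.start()
    for p in procs:
        p.join()

if __name__ == '__main__':
    dispatch(1e-3, world_size=2, device_ids=[0, 1])
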
@@ -233,12 +216,13 @@ def process_job(learning_rate, local_rank=None):
         logging.error('No GPU detected!')
         sys.exit(-1)
 
-    devices = allocate_gpu_devices(1)
-    if len(devices) < 1:
-        logging.error('Allocate GPU failed!')
-        sys.exit(-1)
+    if device_id == None:
+        devices = allocate_gpu_devices(1)
+        if len(devices) < 1:
+            logging.error('Allocate GPU failed!')
+            sys.exit(-1)
+        device_id = devices[0][0]
 
-    device_id = devices[0][0]
     logging.info('device: {}'.format(device_id))
 
     if args.use_ddp:
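
Finally, a toy end-to-end check of the refactor: with get_objf as defined above in scope, the same helper drives both the training call site (training=True, optimizer passed, backward and step run inside) and the validation call site (training=False, forward under torch.no_grad()). Everything below is a stand-in; the real model, criterion, opts and den_graph come from the repo and Kaldi:

import torch
from torch import nn, optim
from torch.nn.utils import clip_grad_value_  # get_objf's module needs this too

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 4)

    def forward(self, x):
        y = self.linear(x)
        return y, y  # (nnet_output, xent_output), mirroring the chain model

def toy_criterion(opts, den_graph, supervision, nnet_output, xent_output):
    # mimics the chain criterion's return shape: [objf, l2_term, weight]
    objf = -nnet_output.pow(2).mean()
    return torch.stack([objf, torch.zeros(()), torch.ones(())])

model = ToyModel()
optimizer = optim.Adam(model.parameters())
batch = (['utt-0'], [torch.randn(1, 6, 8)], [None])  # (keys, feats [N, T, C], supervision)

objf, weight, frames = get_objf(batch, model, torch.device('cpu'),
                                toy_criterion, None, None, True, optimizer)
objf, weight, frames = get_objf(batch, model, torch.device('cpu'),
                                toy_criterion, None, None, False)
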