@@ -221,11 +221,12 @@ def train(self, train_data_set: DataLoader,
         for _ in epoch_it:
             step_it = tqdm(train_data_set, desc="Train iteration")
             avg_loss = 0
-            for s_idx, batch in enumerate(step_it):
+            for step, batch in enumerate(step_it):
                 self.model.train()
                 if distiller:
                     batch, t_batch = batch[:2]
                     t_batch = tuple(t.to(self.device) for t in t_batch)
+                    t_logits = distiller.get_teacher_logits(t_batch)
                 batch = tuple(t.to(self.device) for t in batch)
                 inputs = self.batch_mapper(batch)
                 logits = self.model(**inputs)
@@ -239,7 +240,6 @@ def train(self, train_data_set: DataLoader,
 
                 # add distillation loss if activated
                 if distiller:
-                    t_logits = distiller.get_teacher_logits(t_batch)
                     loss = distiller.distill_loss(loss, logits, t_logits)
 
                 loss.backward()
@@ -251,17 +251,157 @@ def train(self, train_data_set: DataLoader,
                 global_step += 1
                 avg_loss += loss.item()
                 if global_step % logging_steps == 0:
-                    logger.info(" global_step = %s, average loss = %s", global_step, avg_loss / s_idx)
+                    if step != 0:
+                        logger.info(
+                            " global_step = %s, average loss = %s", global_step, avg_loss / step)
                     self._get_eval(dev_data_set, "dev")
                     self._get_eval(test_data_set, "test")
                 if save_path is not None and global_step % save_steps == 0:
                     self.save_model(save_path)
 
+    def train_pseudo(
+            self, labeled_data_set: DataLoader,
+            unlabeled_data_set: DataLoader,
+            distiller: TeacherStudentDistill,
+            dev_data_set: DataLoader = None,
+            test_data_set: DataLoader = None,
+            batch_size_l: int = 8,
+            batch_size_ul: int = 8,
+            epochs: int = 100,
+            optimizer=None,
+            max_grad_norm: float = 5.0,
+            logging_steps: int = 50,
+            save_steps: int = 100,
+            save_path: str = None,
+            save_best: bool = False):
277+ """
278+ Train a tagging model
279+
280+ Args:
281+ train_data_set (DataLoader): train examples dataloader. If distiller object is
282+ provided train examples should contain a tuple of student/teacher data examples.
283+ dev_data_set (DataLoader, optional): dev examples dataloader. Defaults to None.
284+ test_data_set (DataLoader, optional): test examples dataloader. Defaults to None.
285+ batch_size_l (int, optional): batch size for the labeled dataset. Defaults to 8.
286+ batch_size_ul (int, optional): batch size for the unlabeled dataset. Defaults to 8.
287+ epochs (int, optional): num of epochs to train. Defaults to 100.
288+ optimizer (fn, optional): optimizer function. Defaults to default model optimizer.
289+ max_grad_norm (float, optional): max gradient norm. Defaults to 5.0.
290+ logging_steps (int, optional): number of steps between logging. Defaults to 50.
291+ save_steps (int, optional): number of steps between model saves. Defaults to 100.
292+ save_path (str, optional): model output path. Defaults to None.
293+ save_best (str, optional): wether to save model when result is best on dev set
294+ distiller (TeacherStudentDistill, optional): KD model for training the model using
295+ a teacher model. Defaults to None.
296+ """
+        if optimizer is None:
+            optimizer = self.get_optimizer()
+        train_batch_size_l = batch_size_l * max(1, self.n_gpus)
+        train_batch_size_ul = batch_size_ul * max(1, self.n_gpus)
+        logger.info("***** Running training *****")
+        logger.info(" Num labeled examples = %d", len(labeled_data_set.dataset))
+        logger.info(" Num unlabeled examples = %d", len(unlabeled_data_set.dataset))
+        logger.info(" Instantaneous labeled batch size per GPU/CPU = %d",
+                    batch_size_l)
+        logger.info(" Instantaneous unlabeled batch size per GPU/CPU = %d",
+                    batch_size_ul)
+        logger.info(" Total batch size labeled = %d", train_batch_size_l)
+        logger.info(" Total batch size unlabeled = %d", train_batch_size_ul)
+        global_step = 0
+        self.model.zero_grad()
+        avg_loss = 0
+        iter_l = iter(labeled_data_set)
+        iter_ul = iter(unlabeled_data_set)
+        epoch_l = 0
+        epoch_ul = 0
+        s_idx = -1
+        best_dev = 0
+        best_test = 0
+        while True:
+            logger.info("labeled epoch=%d, unlabeled epoch=%d", epoch_l, epoch_ul)
+            loss_labeled = 0
+            loss_unlabeled = 0
+            try:
+                batch_l = next(iter_l)
+                s_idx += 1
+            except StopIteration:
+                iter_l = iter(labeled_data_set)
+                epoch_l += 1
+                batch_l = next(iter_l)
+                s_idx = 0
+                avg_loss = 0
+            try:
+                batch_ul = next(iter_ul)
+            except StopIteration:
+                iter_ul = iter(unlabeled_data_set)
+                epoch_ul += 1
+                batch_ul = next(iter_ul)
+                if epoch_ul > epochs:
+                    logger.info("Done")
+                    return
+            self.model.train()
+            batch_l, t_batch_l = batch_l[:2]
+            batch_ul, t_batch_ul = batch_ul[:2]
+            t_batch_l = tuple(t.to(self.device) for t in t_batch_l)
+            t_batch_ul = tuple(t.to(self.device) for t in t_batch_ul)
+            t_logits = distiller.get_teacher_logits(t_batch_l)
+            t_logits_ul = distiller.get_teacher_logits(t_batch_ul)
+            batch_l = tuple(t.to(self.device) for t in batch_l)
+            batch_ul = tuple(t.to(self.device) for t in batch_ul)
+            inputs = self.batch_mapper(batch_l)
+            inputs_ul = self.batch_mapper(batch_ul)
+            logits = self.model(**inputs)
+            logits_ul = self.model(**inputs_ul)
+            t_labels = torch.argmax(F.log_softmax(t_logits_ul, dim=2), dim=2)
+            if self.use_crf:
+                loss_labeled = -1.0 * self.crf(
+                    logits, inputs['labels'], mask=inputs['mask'] != 0.0)
+                loss_unlabeled = -1.0 * self.crf(
+                    logits_ul, t_labels, mask=inputs_ul['mask'] != 0.0)
+            else:
+                loss_fn = CrossEntropyLoss(ignore_index=0)
+                loss_labeled = loss_fn(logits.view(-1, self.num_labels), inputs['labels'].view(-1))
+                loss_unlabeled = loss_fn(logits_ul.view(-1, self.num_labels), t_labels.view(-1))
+
+            if self.n_gpus > 1:
+                loss_labeled = loss_labeled.mean()
+                loss_unlabeled = loss_unlabeled.mean()
+
+            # add distillation loss
+            loss_labeled = distiller.distill_loss(loss_labeled, logits, t_logits)
+            loss_unlabeled = distiller.distill_loss(loss_unlabeled, logits_ul, t_logits_ul)
+
+            # sum labeled and unlabeled losses
+            loss = loss_labeled + loss_unlabeled
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
+            optimizer.step()
+            # self.model.zero_grad()
+            optimizer.zero_grad()
+            global_step += 1
+            avg_loss += loss.item()
+            if global_step % logging_steps == 0:
+                if s_idx != 0:
+                    logger.info(
+                        " global_step = %s, average loss = %s", global_step, avg_loss / s_idx)
387+ dev = self ._get_eval (dev_data_set , "dev" )
388+ test = self ._get_eval (test_data_set , "test" )
389+ if dev > best_dev :
390+ best_dev = dev
391+ best_test = test
392+ if save_path is not None and save_best :
393+ self .save_model (save_path )
394+ logger .info ("Best result: dev= %s, test= %s" , str (best_dev ), str (best_test ))
395+ if save_path is not None and global_step % save_steps == 0 :
396+ self .save_model (save_path )
397+
     def _get_eval(self, ds, set_name):
         if ds is not None:
             logits, out_label_ids = self.evaluate(ds)
             res = self.evaluate_predictions(logits, out_label_ids)
             logger.info(" {} set F1 = {}".format(set_name, res['f1']))
+            return res['f1']
+        return None
 
     def to(self, device='cpu', n_gpus=0):
         """
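
The loop also cycles two dataloaders of different lengths, restarting whichever one is exhausted and stopping after a fixed number of passes over the unlabeled data. A sketch of that pattern in isolation, assuming a generator interface (the names and the `yield` structure are illustrative, not the patch's code):

```python
from torch.utils.data import DataLoader


def cycle_labeled_unlabeled(labeled_dl: DataLoader, unlabeled_dl: DataLoader,
                            max_unlabeled_epochs: int):
    """Yield (labeled_batch, unlabeled_batch) pairs until the unlabeled loader
    has been consumed max_unlabeled_epochs times."""
    iter_l, iter_ul = iter(labeled_dl), iter(unlabeled_dl)
    epochs_ul = 0
    while True:
        try:
            batch_l = next(iter_l)
        except StopIteration:
            # The labeled set is typically smaller; restart it as often as needed.
            iter_l = iter(labeled_dl)
            batch_l = next(iter_l)
        try:
            batch_ul = next(iter_ul)
        except StopIteration:
            epochs_ul += 1
            if epochs_ul >= max_unlabeled_epochs:
                return
            iter_ul = iter(unlabeled_dl)
            batch_ul = next(iter_ul)
        yield batch_l, batch_ul
```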