1313import pickle as pkl
1414
1515from utils .common_utils import dump_to_pkl , load_from_pkl , get_param_num , get_trainable_param_num , \
16- transfer_to_gpu , transform_params2tensors
16+ transfer_to_gpu , transform_params2tensors , get_layer_class
1717from utils .philly_utils import HDFSDirectTransferer , open_and_move , convert_to_tmppath , \
1818 convert_to_hdfspath , move_from_local_to_hdfs
1919from Model import Model
2424from core .LRScheduler import LRScheduler
2525from settings import ProblemTypes
2626from block_zoo import Linear
27+ from block_zoo import CRF
28+ from losses .CRFLoss import CRFLoss
2729
2830
2931class LearningMachine (object ):
@@ -169,6 +171,8 @@ def train(self, optimizer, loss_fn):
169171 (not self .model .module .layers [tmp_output_layer_id ].layer_conf .last_hidden_softmax ):
170172 logits_softmax [tmp_output_layer_id ] = nn .functional .softmax (
171173 logits [tmp_output_layer_id ], dim = - 1 )
174+ elif isinstance (get_layer_class (self .model , tmp_output_layer_id ), CRF ):
175+ pass
172176 else :
173177 logits_softmax [tmp_output_layer_id ] = logits [tmp_output_layer_id ]
174178 else :
@@ -177,6 +181,8 @@ def train(self, optimizer, loss_fn):
177181 (not self .model .layers [tmp_output_layer_id ].layer_conf .last_hidden_softmax ):
178182 logits_softmax [tmp_output_layer_id ] = nn .functional .softmax (
179183 logits [tmp_output_layer_id ], dim = - 1 )
184+ elif isinstance (get_layer_class (self .model , tmp_output_layer_id ), CRF ):
185+ pass
180186 else :
181187 logits_softmax [tmp_output_layer_id ] = logits [tmp_output_layer_id ]
182188
@@ -194,8 +200,9 @@ def train(self, optimizer, loss_fn):
194200 prediction_scores_all = None
195201 elif ProblemTypes [self .problem .problem_type ] == ProblemTypes .sequence_tagging :
196202 logits = list (logits .values ())[0 ]
197- logits_softmax = list (logits_softmax .values ())[0 ]
198- assert len (logits_softmax .shape ) == 3 , 'The dimension of your output is %s, but we need [batch_size*GPUs, sequence length, representation dim]' % (str (list (logits_softmax .shape )), )
203+ if not isinstance (get_layer_class (self .model , tmp_output_layer_id ), CRF ):
204+ logits_softmax = list (logits_softmax .values ())[0 ]
205+ assert len (logits_softmax .shape ) == 3 , 'The dimension of your output is %s, but we need [batch_size*GPUs, sequence length, representation dim]' % (str (list (logits_softmax .shape )), )
199206 prediction_scores = None
200207 prediction_scores_all = None
201208 elif ProblemTypes [self .problem .problem_type ] == ProblemTypes .regression :
@@ -214,16 +221,25 @@ def train(self, optimizer, loss_fn):
214221 if ProblemTypes [self .problem .problem_type ] == ProblemTypes .sequence_tagging :
215222 # Transform output shapes for metric evaluation
216223 # for seq_tag_f1 metric
217- prediction_indices = logits_softmax .data .max (2 )[1 ].cpu ().numpy () # [batch_size, seq_len]
218- streaming_recoder .record_one_row ([self .problem .decode (prediction_indices , length_batches [i ]['target' ][self .conf .answer_column_name [0 ]].numpy ()),
219- prediction_scores , self .problem .decode (target_batches [i ][self .conf .answer_column_name [0 ]],
220- length_batches [i ]['target' ][self .conf .answer_column_name [0 ]].numpy ())], keep_dim = False )
224+ if isinstance (get_layer_class (self .model , tmp_output_layer_id ), CRF ):
225+ forward_score , scores , masks , tag_seq , transitions , layer_conf = logits
226+ prediction_indices = tag_seq .cpu ().numpy ()
227+ streaming_recoder .record_one_row ([self .problem .decode (prediction_indices , length_batches [i ]['target' ][self .conf .answer_column_name [0 ]].numpy ()),
228+ prediction_scores , self .problem .decode (
229+ target_batches [i ][self .conf .answer_column_name [0 ]],
230+ length_batches [i ]['target' ][self .conf .answer_column_name [0 ]].numpy ())], keep_dim = False )
221231
222- # pytorch's CrossEntropyLoss only support this
223- logits_flat [self .conf .output_layer_id [0 ]] = logits .view (- 1 , logits .size (2 )) # [batch_size * seq_len, # of tags]
224- #target_batches[i] = target_batches[i].view(-1) # [batch_size * seq_len]
225- # [batch_size * seq_len]
226- target_batches [i ][self .conf .answer_column_name [0 ]] = target_batches [i ][self .conf .answer_column_name [0 ]].reshape (- 1 )
232+ else :
233+ prediction_indices = logits_softmax .data .max (2 )[1 ].cpu ().numpy () # [batch_size, seq_len]
234+ # pytorch's CrossEntropyLoss only support this
235+ logits_flat [self .conf .output_layer_id [0 ]] = logits .view (- 1 , logits .size (2 )) # [batch_size * seq_len, # of tags]
236+ streaming_recoder .record_one_row ([self .problem .decode (prediction_indices , length_batches [i ]['target' ][self .conf .answer_column_name [0 ]].numpy ()),
237+ prediction_scores , self .problem .decode (
238+ target_batches [i ][self .conf .answer_column_name [0 ]],
239+ length_batches [i ]['target' ][self .conf .answer_column_name [0 ]].numpy ())], keep_dim = False )
240+
241+ target_batches [i ][self .conf .answer_column_name [0 ]] = target_batches [i ][
242+ self .conf .answer_column_name [0 ]].reshape (- 1 )
227243
228244 elif ProblemTypes [self .problem .problem_type ] == ProblemTypes .classification :
229245 prediction_indices = logits_softmax .detach ().max (1 )[1 ].cpu ().numpy ()
@@ -260,7 +276,10 @@ def train(self, optimizer, loss_fn):
260276 for single_target in self .conf .answer_column_name :
261277 if isinstance (target_batches [i ][single_target ], torch .Tensor ):
262278 target_batches [i ][single_target ] = transfer_to_gpu (target_batches [i ][single_target ])
263- loss = loss_fn (logits_flat , target_batches [i ])
279+ if isinstance (loss_fn .loss_fn [0 ], CRFLoss ):
280+ loss = loss_fn .loss_fn [0 ](forward_score , scores , masks , list (target_batches [i ].values ())[0 ], transitions , layer_conf )
281+ else :
282+ loss = loss_fn (logits_flat , target_batches [i ])
264283
265284 all_costs .append (loss .item ())
266285 optimizer .zero_grad ()
@@ -297,7 +316,7 @@ def train(self, optimizer, loss_fn):
297316
298317 if torch .cuda .device_count () > 1 :
299318 logging .info ("Epoch %d batch idx: %d; lr: %f; since last log, loss=%f; %s" % \
300- (epoch , i * torch .cuda .device_count (), lr_scheduler .get_lr (), np .mean (all_costs ), result ))
319+ (epoch , i * torch .cuda .device_count (), lr_scheduler .get_lr (), np .sum (all_costs ), result ))
301320 else :
302321 logging .info ("Epoch %d batch idx: %d; lr: %f; since last log, loss=%f; %s" % \
303322 (epoch , i , lr_scheduler .get_lr (), np .mean (all_costs ), result ))
@@ -473,18 +492,29 @@ def evaluate(self, data, length, target, input_types, evaluator,
473492 logits_flat = {}
474493 if ProblemTypes [self .problem .problem_type ] == ProblemTypes .sequence_tagging :
475494 logits = list (logits .values ())[0 ]
476- logits_softmax = list (logits_softmax .values ())[0 ]
477- # Transform output shapes for metric evaluation
478- # for seq_tag_f1 metric
479- prediction_indices = logits_softmax .data .max (2 )[1 ].cpu ().numpy () # [batch_size, seq_len]
480- streaming_recoder .record_one_row (
481- [self .problem .decode (prediction_indices , length_batches [i ]['target' ][self .conf .answer_column_name [0 ]].numpy ()), prediction_pos_scores ,
482- self .problem .decode (target_batches [i ], length_batches [i ]['target' ][self .conf .answer_column_name [0 ]].numpy ())], keep_dim = False )
495+ if isinstance (get_layer_class (self .model , tmp_output_layer_id ), CRF ):
496+ forward_score , scores , masks , tag_seq , transitions , layer_conf = logits
497+ prediction_indices = tag_seq .cpu ().numpy ()
498+ streaming_recoder .record_one_row (
499+ [self .problem .decode (prediction_indices , length_batches [i ]['target' ][self .conf .answer_column_name [0 ]].numpy ()),
500+ prediction_pos_scores ,
501+ self .problem .decode (target_batches [i ], length_batches [i ]['target' ][self .conf .answer_column_name [0 ]].numpy ())],
502+ keep_dim = False )
503+ else :
504+ logits_softmax = list (logits_softmax .values ())[0 ]
505+ # Transform output shapes for metric evaluation
506+ # for seq_tag_f1 metric
507+ prediction_indices = logits_softmax .data .max (2 )[1 ].cpu ().numpy () # [batch_size, seq_len]
508+ # pytorch's CrossEntropyLoss only support this
509+ logits_flat [self .conf .output_layer_id [0 ]] = logits .view (- 1 , logits .size (2 )) # [batch_size * seq_len, # of tags]
510+ streaming_recoder .record_one_row (
511+ [self .problem .decode (prediction_indices , length_batches [i ]['target' ][self .conf .answer_column_name [0 ]].numpy ()),
512+ prediction_pos_scores ,
513+ self .problem .decode (target_batches [i ], length_batches [i ]['target' ][self .conf .answer_column_name [0 ]].numpy ())],
514+ keep_dim = False )
483515
484- # pytorch's CrossEntropyLoss only support this
485- logits_flat [self .conf .output_layer_id [0 ]] = logits .view (- 1 , logits .size (2 )) # [batch_size * seq_len, # of tags]
486- #target_batches[i] = target_batches[i].view(-1) # [batch_size * seq_len]
487- target_batches [i ][self .conf .answer_column_name [0 ]] = target_batches [i ][self .conf .answer_column_name [0 ]].reshape (- 1 ) # [batch_size * seq_len]
516+ target_batches [i ][self .conf .answer_column_name [0 ]] = target_batches [i ][
517+ self .conf .answer_column_name [0 ]].reshape (- 1 ) # [batch_size * seq_len]
488518
489519 if to_predict :
490520 prediction_batch = self .problem .decode (prediction_indices , length_batches [i ][key_random ].numpy ())
@@ -547,8 +577,13 @@ def evaluate(self, data, length, target, input_types, evaluator,
547577 predict_stream_recoder .record_one_row ([prediction ])
548578
549579 if to_predict :
550- logits_len = len (list (logits .values ())[0 ]) \
551- if ProblemTypes [self .problem .problem_type ] == ProblemTypes .mrc else len (logits )
580+ if ProblemTypes [self .problem .problem_type ] == ProblemTypes .mrc :
581+ logits_len = len (list (logits .values ())[0 ])
582+ elif ProblemTypes [self .problem .problem_type ] == ProblemTypes .sequence_tagging and isinstance (get_layer_class (self .model , tmp_output_layer_id ), CRF ):
583+ # for sequence_tagging task, logits is tuple type which index 3 is tag_seq [batch_size*seq_len]
584+ logits_len = logits [3 ].size (0 )
585+ else :
586+ logits_len = len (logits )
552587 for sample_idx in range (logits_len ):
553588 while True :
554589 sample = fin .readline ().rstrip ()
@@ -564,7 +599,10 @@ def evaluate(self, data, length, target, input_types, evaluator,
564599 for single_target in self .conf .answer_column_name :
565600 if isinstance (target_batches [i ][single_target ], torch .Tensor ):
566601 target_batches [i ][single_target ] = transfer_to_gpu (target_batches [i ][single_target ])
567- loss = loss_fn (logits_flat , target_batches [i ])
602+ if isinstance (loss_fn .loss_fn [0 ], CRFLoss ):
603+ loss = loss_fn .loss_fn [0 ](forward_score , scores , masks , list (target_batches [i ].values ())[0 ], transitions , layer_conf )
604+ else :
605+ loss = loss_fn (logits_flat , target_batches [i ])
568606 loss_recoder .record ('loss' , loss .item ())
569607
570608 del loss , logits , logits_softmax , logits_flat
@@ -686,9 +724,14 @@ def predict(self, predict_data_path, output_path, file_columns, predict_fields=[
686724
687725 if ProblemTypes [self .problem .problem_type ] == ProblemTypes .sequence_tagging :
688726 logits = list (logits .values ())[0 ]
689- logits_softmax = list (logits_softmax .values ())[0 ]
690- # Transform output shapes for metric evaluation
691- prediction_indices = logits_softmax .data .max (2 )[1 ].cpu ().numpy () # [batch_size, seq_len]
727+ if isinstance (get_layer_class (self .model , tmp_output_layer_id ), CRF ):
728+ forward_score , scores , masks , tag_seq , transitions , layer_conf = logits
729+ prediction_indices = tag_seq .cpu ().numpy ()
730+ else :
731+ logits_softmax = list (logits_softmax .values ())[0 ]
732+ # Transform output shapes for metric evaluation
733+ # for seq_tag_f1 metric
734+ prediction_indices = logits_softmax .data .max (2 )[1 ].cpu ().numpy () # [batch_size, seq_len]
692735 prediction_batch = self .problem .decode (prediction_indices , length_batches [i ][key_random ].numpy ())
693736 for prediction_sample in prediction_batch :
694737 streaming_recoder .record ('prediction' , " " .join (prediction_sample ))
0 commit comments