@@ -76,18 +76,26 @@ def evaluate(model, criterion, metric, data_loader, label_num):
76
76
77
77
def batchify_fn (batch , no_entity_id , ignore_label = - 100 , max_seq_len = 512 ):
78
78
texts , labels , seq_lens = [], [], []
79
+ # 2 for [CLS] and [SEP]
80
+ batch_max_seq = max ([len (example ["tokens" ]) for example in batch ]) + 2
81
+ # Truncation: Handle max sequence length
82
+ # If max_seq_len == 0, then do nothing and keep the real length.
83
+ # If max_seq_len > 0 and
84
+ # all the input sequence len is over the max_seq_len,
85
+ # then we truncate it.
86
+ if max_seq_len > 0 :
87
+ batch_max_seq = min (batch_max_seq , max_seq_len )
79
88
for example in batch :
80
89
texts .append ("" .join (example ["tokens" ]))
81
- # 2 for [CLS] and [SEP]
82
- seq_lens .append (len (example ["tokens" ]) + 2 )
83
90
label = example ["labels" ]
84
- if len (label ) > max_seq_len - 2 :
85
- label = label [:(max_seq_len - 2 )]
91
+ # 2 for [CLS] and [SEP]
92
+ if len (label ) > batch_max_seq - 2 :
93
+ label = label [:(batch_max_seq - 2 )]
86
94
label = [no_entity_id ] + label + [no_entity_id ]
87
- if len (label ) < max_seq_len :
88
- label += [ignore_label ] * (max_seq_len - len (label ))
95
+ seq_lens .append (len (label ))
96
+ if len (label ) < batch_max_seq :
97
+ label += [ignore_label ] * (batch_max_seq - len (label ))
89
98
labels .append (label )
90
-
91
99
labels = np .array (labels , dtype = "int64" )
92
100
seq_lens = np .array (seq_lens )
93
101
return texts , labels , seq_lens
0 commit comments