
Commit 8e94a97

Fix batchify function of token classification task of FasterErnie. (#1583)
* fix token cls
* fix faster tokenizer pad_to_max_seq bug
* fix batch_max_seq
1 parent 2197402


examples/faster/faster_ernie/token_cls/train.py

Lines changed: 15 additions & 7 deletions
@@ -76,18 +76,26 @@ def evaluate(model, criterion, metric, data_loader, label_num):
 
 def batchify_fn(batch, no_entity_id, ignore_label=-100, max_seq_len=512):
     texts, labels, seq_lens = [], [], []
+    # 2 for [CLS] and [SEP]
+    batch_max_seq = max([len(example["tokens"]) for example in batch]) + 2
+    # Truncation: handle the max sequence length.
+    # If max_seq_len == 0, then do nothing and keep the real length.
+    # If max_seq_len > 0 and the input sequence length exceeds max_seq_len,
+    # then we truncate it.
+    if max_seq_len > 0:
+        batch_max_seq = min(batch_max_seq, max_seq_len)
     for example in batch:
         texts.append("".join(example["tokens"]))
-        # 2 for [CLS] and [SEP]
-        seq_lens.append(len(example["tokens"]) + 2)
         label = example["labels"]
-        if len(label) > max_seq_len - 2:
-            label = label[:(max_seq_len - 2)]
+        # 2 for [CLS] and [SEP]
+        if len(label) > batch_max_seq - 2:
+            label = label[:(batch_max_seq - 2)]
         label = [no_entity_id] + label + [no_entity_id]
-        if len(label) < max_seq_len:
-            label += [ignore_label] * (max_seq_len - len(label))
+        seq_lens.append(len(label))
+        if len(label) < batch_max_seq:
+            label += [ignore_label] * (batch_max_seq - len(label))
         labels.append(label)
-
     labels = np.array(labels, dtype="int64")
     seq_lens = np.array(seq_lens)
     return texts, labels, seq_lens
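
For reference, a minimal usage sketch of the fixed batchify_fn follows. The two-example batch, the label values, and the no_entity_id argument below are illustrative assumptions, not part of the commit; np is numpy, as imported in train.py.

import numpy as np

def batchify_fn(batch, no_entity_id, ignore_label=-100, max_seq_len=512):
    texts, labels, seq_lens = [], [], []
    # 2 for [CLS] and [SEP]
    batch_max_seq = max([len(example["tokens"]) for example in batch]) + 2
    if max_seq_len > 0:
        batch_max_seq = min(batch_max_seq, max_seq_len)
    for example in batch:
        texts.append("".join(example["tokens"]))
        label = example["labels"]
        if len(label) > batch_max_seq - 2:
            label = label[:(batch_max_seq - 2)]
        label = [no_entity_id] + label + [no_entity_id]
        seq_lens.append(len(label))
        if len(label) < batch_max_seq:
            label += [ignore_label] * (batch_max_seq - len(label))
        labels.append(label)
    labels = np.array(labels, dtype="int64")
    seq_lens = np.array(seq_lens)
    return texts, labels, seq_lens

# Hypothetical batch: two examples of different lengths.
batch = [
    {"tokens": list("北京欢迎你"), "labels": [1, 2, 0, 0, 0]},
    {"tokens": list("你好"), "labels": [0, 0]},
]
texts, labels, seq_lens = batchify_fn(batch, no_entity_id=0)
# Labels are padded to the longest sequence in the batch (5 + 2 = 7),
# not to the fixed max_seq_len of 512 as before the fix.
print(labels.shape)  # (2, 7)
# seq_lens records the real label length including [CLS]/[SEP].
print(seq_lens)      # [7 4]

Note that seq_lens is now recorded after [CLS]/[SEP] are added and before padding, so it reflects the true (possibly truncated) sequence length rather than the raw token count.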
