
Commit d0d2067

add new datacollator upgrade examples (#1816)

* upgrade predict_glue.py
* test doc api
* revert
* update collator
* revert dureader_robus
* minor fix
* add glue sample
* stash
* stash
* upgrade examples for new datacollator
* fix squad sample
* add token classification collator
* fix doc
* fix collator for squad
* minor fix
* upgrade msra_ner predict
* add some doc
* add more check

1 parent 1df6643 commit d0d2067

6 files changed: 285 additions, 71 deletions


examples/information_extraction/msra_ner/predict.py

Lines changed: 8 additions & 11 deletions
@@ -26,7 +26,7 @@
 
 import paddlenlp as ppnlp
 from datasets import load_dataset
-from paddlenlp.data import Stack, Tuple, Pad, Dict
+from paddlenlp.data import DataCollatorForTokenClassification
 from paddlenlp.transformers import BertForTokenClassification, BertTokenizer
 
 parser = argparse.ArgumentParser()
@@ -75,6 +75,7 @@ def do_predict(args):
     # Create dataset, tokenizer and dataloader.
     train_examples, predict_examples = load_dataset(
         'msra_ner', split=('train', 'test'))
+    column_names = train_examples.column_names
     tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
 
     label_list = train_examples.features['ner_tags'].feature.names
@@ -104,17 +105,14 @@ def tokenize_and_align_labels(examples):
         return tokenized_inputs
 
     ignore_label = -100
-    batchify_fn = lambda samples, fn=Dict({
-        'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
-        'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
-        'seq_len': Stack(),
-        'labels': Pad(axis=0, pad_val=ignore_label)  # label
-    }): fn(samples)
+    batchify_fn = DataCollatorForTokenClassification(tokenizer)
 
     id2label = dict(enumerate(label_list))
 
     predict_examples = predict_examples.select(range(len(predict_examples) - 1))
-    predict_ds = predict_examples.map(tokenize_and_align_labels, batched=True)
+    predict_ds = predict_examples.map(tokenize_and_align_labels,
+                                      batched=True,
+                                      remove_columns=column_names)
     predict_data_loader = DataLoader(
         dataset=predict_ds,
         collate_fn=batchify_fn,
@@ -133,11 +131,10 @@ def tokenize_and_align_labels(examples):
     pred_list = []
     len_list = []
     for step, batch in enumerate(predict_data_loader):
-        input_ids, token_type_ids, length, labels = batch
-        logits = model(input_ids, token_type_ids)
+        logits = model(batch['input_ids'], batch['token_type_ids'])
        pred = paddle.argmax(logits, axis=-1)
        pred_list.append(pred.numpy())
-        len_list.append(length.numpy())
+        len_list.append(batch['seq_len'].numpy())
 
     preds = parse_decodes(predict_examples, id2label, pred_list, len_list)
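For context, a minimal sketch of how the new collate_fn is expected to behave, assuming PaddleNLP is installed; the model name is the example's default and the toy feature dicts and label values are illustrative, not taken from this commit. DataCollatorForTokenClassification is called on a list of per-example dicts (as a DataLoader collate_fn) and returns a single padded dict, which is why the predict loop above now reads batch['input_ids'] and batch['seq_len'].

from paddlenlp.data import DataCollatorForTokenClassification
from paddlenlp.transformers import BertTokenizer

# "bert-base-multilingual-uncased" is the msra_ner example's default model.
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-uncased")
batchify_fn = DataCollatorForTokenClassification(tokenizer)

# Two toy examples of different length, shaped like tokenize_and_align_labels output.
samples = [
    {"input_ids": [101, 2769, 102], "token_type_ids": [0, 0, 0], "labels": [0, 1, 0]},
    {"input_ids": [101, 2769, 4263, 872, 102], "token_type_ids": [0] * 5,
     "labels": [0, 1, 1, 0, 0]},
]

batch = batchify_fn(samples)
# batch is a dict: 'input_ids' / 'token_type_ids' are padded to the longest sample,
# and 'labels' is padded with -100 so padded positions are ignored by the loss.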

examples/information_extraction/msra_ner/train.py

Lines changed: 17 additions & 17 deletions
@@ -30,7 +30,7 @@
 from paddlenlp.transformers import BertForTokenClassification, BertTokenizer
 from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
 from paddlenlp.transformers import ErnieCtmForTokenClassification, ErnieCtmTokenizer
-from paddlenlp.data import Stack, Tuple, Pad, Dict
+from paddlenlp.data import DataCollatorForTokenClassification
 from paddlenlp.utils.log import logger
 
 MODEL_CLASSES = {
@@ -68,13 +68,12 @@ def evaluate(model, loss_fct, metric, data_loader, label_num, mode="valid"):
     metric.reset()
     avg_loss, precision, recall, f1_score = 0, 0, 0, 0
     for batch in data_loader:
-        input_ids, token_type_ids, length, labels = batch
-        logits = model(input_ids, token_type_ids)
-        loss = loss_fct(logits, labels)
+        logits = model(batch['input_ids'], batch['token_type_ids'])
+        loss = loss_fct(logits, batch['labels'])
        avg_loss = paddle.mean(loss)
        preds = logits.argmax(axis=2)
        num_infer_chunks, num_label_chunks, num_correct_chunks = metric.compute(
-            length, preds, labels)
+            batch['seq_len'], preds, batch['labels'])
        metric.update(num_infer_chunks.numpy(),
                      num_label_chunks.numpy(), num_correct_chunks.numpy())
     precision, recall, f1_score = metric.accumulate()
@@ -125,16 +124,14 @@ def tokenize_and_align_labels(examples):
         return tokenized_inputs
 
     train_ds = train_ds.select(range(len(train_ds) - 1))
-    train_ds = train_ds.map(tokenize_and_align_labels, batched=True)
+    column_names = train_ds.column_names
+    train_ds = train_ds.map(tokenize_and_align_labels,
+                            batched=True,
+                            remove_columns=column_names)
 
     ignore_label = -100
 
-    batchify_fn = lambda samples, fn=Dict({
-        'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int32'),  # input
-        'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int32'),  # segment
-        'seq_len': Stack(dtype='int64'),  # seq_len
-        'labels': Pad(axis=0, pad_val=ignore_label, dtype='int64')  # label
-    }): fn(samples)
+    batchify_fn = DataCollatorForTokenClassification(tokenizer, ignore_label)
 
     train_batch_sampler = paddle.io.DistributedBatchSampler(
         train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True)
@@ -148,7 +145,9 @@ def tokenize_and_align_labels(examples):
 
     test_ds = raw_datasets['test']
     test_ds = test_ds.select(range(len(test_ds) - 1))
-    test_ds = test_ds.map(tokenize_and_align_labels, batched=True)
+    test_ds = test_ds.map(tokenize_and_align_labels,
+                          batched=True,
+                          remove_columns=column_names)
 
     test_data_loader = DataLoader(
         dataset=test_ds,
@@ -160,7 +159,9 @@ def tokenize_and_align_labels(examples):
     if args.dataset == "peoples_daily_ner":
         dev_ds = raw_datasets['validation']
         dev_ds = dev_ds.select(range(len(dev_ds) - 1))
-        dev_ds = dev_ds.map(tokenize_and_align_labels, batched=True)
+        dev_ds = dev_ds.map(tokenize_and_align_labels,
+                            batched=True,
+                            remove_columns=column_names)
 
         dev_data_loader = DataLoader(
             dataset=dev_ds,
@@ -205,9 +206,8 @@ def tokenize_and_align_labels(examples):
     for epoch in range(args.num_train_epochs):
         for step, batch in enumerate(train_data_loader):
             global_step += 1
-            input_ids, token_type_ids, _, labels = batch
-            logits = model(input_ids, token_type_ids)
-            loss = loss_fct(logits, labels)
+            logits = model(batch['input_ids'], batch['token_type_ids'])
+            loss = loss_fct(logits, batch['labels'])
            avg_loss = paddle.mean(loss)
            if global_step % args.logging_steps == 0:
                print(
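To make the batch-format change concrete: the removed Dict/Pad/Stack lambda yielded a positional tuple, while the new collator yields a dict keyed by feature name, which is why the loops above now index batch['input_ids'], batch['seq_len'] and batch['labels']. Below is a rough, self-contained restatement of the old behaviour next to the dict shape the new collator is expected to produce; it uses only numpy, and the pad values are assumed to match the script's defaults.

import numpy as np

def old_style_batchify(samples, pad_id=0, pad_type_id=0, ignore_label=-100):
    # Roughly what the removed lambda did: pad each field and return a tuple
    # in a fixed order (input_ids, token_type_ids, seq_len, labels).
    max_len = max(len(s["input_ids"]) for s in samples)
    pad = lambda key, val: np.array(
        [s[key] + [val] * (max_len - len(s[key])) for s in samples], dtype="int64")
    return (pad("input_ids", pad_id),
            pad("token_type_ids", pad_type_id),
            np.array([s["seq_len"] for s in samples], dtype="int64"),
            pad("labels", ignore_label))

def dict_style_batchify(samples, **kwargs):
    # Shape of the batches produced by DataCollatorForTokenClassification:
    # the same padded arrays, but addressed by key instead of by position.
    input_ids, token_type_ids, seq_len, labels = old_style_batchify(samples, **kwargs)
    return {"input_ids": input_ids, "token_type_ids": token_type_ids,
            "seq_len": seq_len, "labels": labels}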

examples/language_model/bert/run_glue.py

Lines changed: 7 additions & 14 deletions
@@ -27,7 +27,7 @@
 from paddle.metric import Metric, Accuracy, Precision, Recall
 
 from datasets import load_dataset
-from paddlenlp.data import Stack, Tuple, Pad, Dict
+from paddlenlp.data import default_data_collator, DataCollatorWithPadding
 from paddlenlp.data.sampler import SamplerHelper
 from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
 from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer
@@ -196,10 +196,9 @@ def evaluate(model, loss_fct, metric, data_loader):
     model.eval()
     metric.reset()
     for batch in data_loader:
-        input_ids, segment_ids, labels = batch
-        logits = model(input_ids, segment_ids)
-        loss = loss_fct(logits, labels)
-        correct = metric.compute(logits, labels)
+        logits = model(batch['input_ids'], batch['token_type_ids'])
+        loss = loss_fct(logits, batch['labels'])
+        correct = metric.compute(logits, batch['labels'])
        metric.update(correct)
     res = metric.accumulate()
     if isinstance(metric, AccuracyAndF1):
@@ -266,11 +265,7 @@ def preprocess_function(examples):
        remove_columns=columns)
     train_batch_sampler = paddle.io.DistributedBatchSampler(
         train_ds, batch_size=args.batch_size, shuffle=True)
-    batchify_fn = lambda samples, fn=Dict({
-        'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
-        'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
-        'labels': Stack(dtype="int64" if label_list else "float32")  # label
-    }): fn(samples)
+    batchify_fn = DataCollatorWithPadding(tokenizer)
     train_data_loader = DataLoader(
         dataset=train_ds,
         batch_sampler=train_batch_sampler,
@@ -358,13 +353,11 @@ def preprocess_function(examples):
     for epoch in range(args.num_train_epochs):
         for step, batch in enumerate(train_data_loader):
             global_step += 1
-
-            input_ids, segment_ids, labels = batch
            with paddle.amp.auto_cast(
                    args.use_amp,
                    custom_white_list=["layer_norm", "softmax", "gelu"]):
-                logits = model(input_ids, segment_ids)
-                loss = loss_fct(logits, labels)
+                logits = model(batch['input_ids'], batch['token_type_ids'])
+                loss = loss_fct(logits, batch['labels'])
            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
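For the GLUE script the labels are scalars, so DataCollatorWithPadding only has to pad the tokenizer outputs and simply carries the label field along, making the old per-field Pad/Stack spec unnecessary. A hedged sketch of the intended usage follows; the model name and toy ids are assumptions, not part of the commit.

from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
batchify_fn = DataCollatorWithPadding(tokenizer)

# Features as produced by preprocess_function: tokenizer outputs plus a scalar label.
features = [
    {"input_ids": [101, 7592, 102], "token_type_ids": [0, 0, 0], "labels": 1},
    {"input_ids": [101, 2204, 3185, 999, 102], "token_type_ids": [0] * 5, "labels": 0},
]

batch = batchify_fn(features)
# 'input_ids' / 'token_type_ids' are padded to the longest sample; 'labels' is
# stacked, so the training loop reads batch['labels'] instead of unpacking a tuple.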

examples/machine_reading_comprehension/SQuAD/run_squad.py

Lines changed: 17 additions & 27 deletions
@@ -28,7 +28,7 @@
 
 import paddlenlp as ppnlp
 
-from paddlenlp.data import Pad, Stack, Tuple, Dict
+from paddlenlp.data import default_data_collator, DataCollatorWithPadding
 from paddlenlp.transformers import BertForQuestionAnswering, BertTokenizer, ErnieForQuestionAnswering, ErnieTokenizer, FunnelForQuestionAnswering, FunnelTokenizer
 from paddlenlp.transformers import LinearDecayWithWarmup
 from paddlenlp.metrics.squad import squad_evaluate, compute_prediction
@@ -170,19 +170,18 @@ def set_seed(args):
 
 
 @paddle.no_grad()
-def evaluate(model, data_loader, raw_dataset, args):
+def evaluate(model, data_loader, raw_dataset, features, args):
     model.eval()
 
     all_start_logits = []
     all_end_logits = []
     tic_eval = time.time()
 
     for batch in data_loader:
-        input_ids, token_type_ids, attention_mask = batch
        start_logits_tensor, end_logits_tensor = model(
-            input_ids,
-            token_type_ids=token_type_ids,
-            attention_mask=attention_mask)
+            batch['input_ids'],
+            token_type_ids=batch['token_type_ids'],
+            attention_mask=batch['attention_mask'])
 
        for idx in range(start_logits_tensor.shape[0]):
            if len(all_start_logits) % 1000 == 0 and len(all_start_logits):
@@ -194,7 +193,7 @@ def evaluate(model, data_loader, raw_dataset, args):
            all_end_logits.append(end_logits_tensor.numpy()[idx])
 
     all_predictions, all_nbest_json, scores_diff_json = compute_prediction(
-        raw_dataset, data_loader.dataset, (all_start_logits, all_end_logits),
+        raw_dataset, features, (all_start_logits, all_end_logits),
        args.version_2_with_negative, args.n_best_size, args.max_answer_length,
        args.null_score_diff_threshold)
 
@@ -262,13 +261,7 @@ def run(args):
            num_proc=4)
        train_batch_sampler = paddle.io.DistributedBatchSampler(
            train_ds, batch_size=args.batch_size, shuffle=True)
-        train_batchify_fn = lambda samples, fn=Dict({
-            "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
-            "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
-            'attention_mask': Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
-            "start_positions": Stack(dtype="int64"),
-            "end_positions": Stack(dtype="int64")
-        }): fn(samples)
+        train_batchify_fn = DataCollatorWithPadding(tokenizer)
 
        train_data_loader = DataLoader(
            dataset=train_ds,
@@ -304,12 +297,12 @@ def run(args):
        for epoch in range(num_train_epochs):
            for step, batch in enumerate(train_data_loader):
                global_step += 1
-                input_ids, token_type_ids, attention_mask, start_positions, end_positions = batch
                logits = model(
-                    input_ids=input_ids,
-                    token_type_ids=token_type_ids,
-                    attention_mask=attention_mask)
-                loss = criterion(logits, (start_positions, end_positions))
+                    input_ids=batch['input_ids'],
+                    token_type_ids=batch['token_type_ids'],
+                    attention_mask=batch['attention_mask'])
+                loss = criterion(logits, (batch['start_positions'],
+                                          batch['end_positions']))
                if global_step % args.logging_steps == 0:
                    print(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
@@ -344,20 +337,17 @@ def run(args):
            num_proc=4)
        dev_batch_sampler = paddle.io.BatchSampler(
            dev_ds, batch_size=args.batch_size, shuffle=False)
-
-        dev_batchify_fn = lambda samples, fn=Dict({
-            "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
-            "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
-            "attention_mask": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
-        }): fn(samples)
+        dev_ds_for_model = dev_ds.remove_columns(
+            ["example_id", "offset_mapping"])
+        dev_batchify_fn = DataCollatorWithPadding(tokenizer)
 
        dev_data_loader = DataLoader(
-            dataset=dev_ds,
+            dataset=dev_ds_for_model,
            batch_sampler=dev_batch_sampler,
            collate_fn=dev_batchify_fn,
            return_list=True)
 
-        evaluate(model, dev_data_loader, dev_examples, args)
+        evaluate(model, dev_data_loader, dev_examples, dev_ds, args)
 
 
 if __name__ == "__main__":
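The SQuAD dev features carry bookkeeping columns ('example_id', 'offset_mapping') that the model cannot consume and the padding collator cannot batch, so the diff strips them before building the DataLoader while the untouched featurized dataset is still handed to evaluate() for compute_prediction. A condensed sketch of that pattern follows, assuming a datasets.Dataset-style dev_ds; the helper name build_dev_loader is hypothetical.

from paddle.io import DataLoader, BatchSampler
from paddlenlp.data import DataCollatorWithPadding

def build_dev_loader(dev_ds, tokenizer, batch_size):
    # Keep the full featurized dataset for post-processing (offsets, example ids),
    # but batch only the columns the model understands.
    dev_ds_for_model = dev_ds.remove_columns(["example_id", "offset_mapping"])
    dev_batch_sampler = BatchSampler(
        dev_ds_for_model, batch_size=batch_size, shuffle=False)
    dev_data_loader = DataLoader(
        dataset=dev_ds_for_model,
        batch_sampler=dev_batch_sampler,
        collate_fn=DataCollatorWithPadding(tokenizer),
        return_list=True)
    # The un-stripped dev_ds is what evaluate(...) passes on to compute_prediction
    # as `features`.
    return dev_data_loader, dev_ds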
