
Commit e93f899 (1 parent: bdd703b)

upgrade predict_glue.py (#1779)

* upgrade predict_glue.py
* test doc api
* revert

File tree: 1 file changed (+20, -15 lines)

examples/language_model/bert/predict_glue.py

Lines changed: 20 additions & 15 deletions
@@ -18,10 +18,10 @@
 
 import paddle
 from paddle import inference
-from paddlenlp.datasets import load_dataset
-from paddlenlp.data import Stack, Tuple, Pad
+from datasets import load_dataset
+from paddlenlp.data import Stack, Tuple, Pad, Dict
 
-from run_glue import convert_example, METRIC_CLASSES, MODEL_CLASSES
+from run_glue import METRIC_CLASSES, MODEL_CLASSES, task_to_keys
 
 
 def parse_args():
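The new import replaces convert_example with task_to_keys from run_glue.py, which maps each GLUE task name to the text column names used by the Hugging Face datasets version of GLUE. The mapping itself is not part of this diff; the sketch below is an assumption based on the standard GLUE column names, not the exact contents of run_glue.py.

# Hypothetical sketch of task_to_keys (assumed, not shown in this commit).
# Single-sentence tasks map to (key, None); pair tasks map to two keys.
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

# Used in main() below: for a single-sentence task, sentence2_key is None,
# so preprocess_function tokenizes only one text column.
sentence1_key, sentence2_key = task_to_keys["sst2"]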
@@ -134,22 +134,27 @@ def main():
     args.task_name = args.task_name.lower()
     args.model_type = args.model_type.lower()
     model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    sentence1_key, sentence2_key = task_to_keys[args.task_name]
 
-    test_ds = load_dataset('glue', args.task_name, splits="test")
+    test_ds = load_dataset('glue', args.task_name, split="test")
     tokenizer = tokenizer_class.from_pretrained(
         os.path.dirname(args.model_path))
 
-    trans_func = partial(
-        convert_example,
-        tokenizer=tokenizer,
-        label_list=test_ds.label_list,
-        max_seq_length=args.max_seq_length,
-        is_test=True)
-    test_ds = test_ds.map(trans_func)
-    batchify_fn = lambda samples, fn=Tuple(
-        Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"), # input
-        Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"), # segment
-    ): fn(samples)
+    def preprocess_function(examples):
+        # Tokenize the texts
+        texts = ((examples[sentence1_key], ) if sentence2_key is None else
+                 (examples[sentence1_key], examples[sentence2_key]))
+        result = tokenizer(*texts, max_seq_len=args.max_seq_length)
+        if "label" in examples:
+            # In all cases, rename the column to labels because the model will expect that.
+            result["labels"] = examples["label"]
+        return result
+
+    test_ds = test_ds.map(preprocess_function)
+    batchify_fn = lambda samples, fn=Dict({
+        'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"), # input
+        'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"), # segment
+    }): fn(samples)
     predictor.predict(
         test_ds, batch_size=args.batch_size, collate_fn=batchify_fn)

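Because test_ds.map(preprocess_function) now yields dict-shaped samples rather than tuples, the collator switches from Tuple to paddlenlp.data.Dict, which selects each field by key before padding it to the batch maximum. A minimal, self-contained sketch of that behaviour follows; the token ids are made up and pad_val=0 stands in for the tokenizer's real pad ids.

from paddlenlp.data import Dict, Pad

# Collate dict samples by key: pad each field along axis 0 to the longest sample.
batchify_fn = Dict({
    'input_ids': Pad(axis=0, pad_val=0, dtype="int64"),
    'token_type_ids': Pad(axis=0, pad_val=0, dtype="int64"),
})

# Two tokenized samples of different lengths (ids are illustrative only).
samples = [
    {'input_ids': [101, 2054, 2003, 102], 'token_type_ids': [0, 0, 0, 0]},
    {'input_ids': [101, 7592, 102], 'token_type_ids': [0, 0, 0]},
]

input_ids, token_type_ids = batchify_fn(samples)
print(input_ids.shape)  # (2, 4): the shorter sample is padded with pad_val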