|
18 | 18 |
|
19 | 19 | import paddle
|
20 | 20 | from paddle import inference
|
21 |
| -from paddlenlp.datasets import load_dataset |
22 |
| -from paddlenlp.data import Stack, Tuple, Pad |
| 21 | +from datasets import load_dataset |
| 22 | +from paddlenlp.data import Stack, Tuple, Pad, Dict |
23 | 23 |
|
24 |
| -from run_glue import convert_example, METRIC_CLASSES, MODEL_CLASSES |
| 24 | +from run_glue import METRIC_CLASSES, MODEL_CLASSES, task_to_keys |
25 | 25 |
|
26 | 26 |
|
27 | 27 | def parse_args():
|
@@ -134,22 +134,27 @@ def main():
|
134 | 134 | args.task_name = args.task_name.lower()
|
135 | 135 | args.model_type = args.model_type.lower()
|
136 | 136 | model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
| 137 | + sentence1_key, sentence2_key = task_to_keys[args.task_name] |
137 | 138 |
|
138 |
| - test_ds = load_dataset('glue', args.task_name, splits="test") |
| 139 | + test_ds = load_dataset('glue', args.task_name, split="test") |
139 | 140 | tokenizer = tokenizer_class.from_pretrained(
|
140 | 141 | os.path.dirname(args.model_path))
|
141 | 142 |
|
142 |
| - trans_func = partial( |
143 |
| - convert_example, |
144 |
| - tokenizer=tokenizer, |
145 |
| - label_list=test_ds.label_list, |
146 |
| - max_seq_length=args.max_seq_length, |
147 |
| - is_test=True) |
148 |
| - test_ds = test_ds.map(trans_func) |
149 |
| - batchify_fn = lambda samples, fn=Tuple( |
150 |
| - Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"), # input |
151 |
| - Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"), # segment |
152 |
| - ): fn(samples) |
def preprocess_function(examples):
    """Convert a (batch of) raw GLUE example(s) into model-ready features.

    Tokenizes the task's sentence column(s) and, when a gold label is
    present, copies it under the "labels" key expected by the model.
    Reads ``sentence1_key``/``sentence2_key``, ``tokenizer`` and ``args``
    from the enclosing scope.
    """
    # Single-sentence tasks have no second key; pair tasks tokenize both.
    if sentence2_key is None:
        texts = (examples[sentence1_key],)
    else:
        texts = (examples[sentence1_key], examples[sentence2_key])
    result = tokenizer(*texts, max_seq_len=args.max_seq_length)
    if "label" in examples:
        # Rename the column to "labels" because the model will expect that.
        result["labels"] = examples["label"]
    # NOTE(review): extraction mangled indentation here — the return is placed
    # at function level, matching the upstream run_glue preprocessing; confirm.
    return result
| 152 | + |
| 153 | + test_ds = test_ds.map(preprocess_function) |
# PEP 8 (E731): use a def instead of assigning a lambda to a name — same
# call signature and behavior, but with a real __name__ for tracebacks.
def batchify_fn(samples,
                fn=Dict({
                    # Pad each field of the batch to the longest sample.
                    'input_ids': Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # input
                    'token_type_ids': Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"),  # segment
                })):
    """Collate a list of feature dicts into padded, batched arrays.

    ``fn`` is bound once at definition time (deliberate default-argument
    binding, as in the original lambda) so the Dict/Pad collators are not
    rebuilt per batch.
    """
    return fn(samples)
153 | 158 | predictor.predict(
|
154 | 159 | test_ds, batch_size=args.batch_size, collate_fn=batchify_fn)
|
155 | 160 |
|
|
0 commit comments