
Commit 6f0535a

Calc CHID acc by considering the global probability (#1891)
* compute acc by considering the global probability
* remove debug var
* remove debug var
* update readme results
* fix cmrc2018 max answer length
* update readme
* add lr, bs in readme
1 parent 07c0085 commit 6f0535a

3 files changed (+98, -46 lines)

examples/benchmark/clue/README.md

Lines changed: 18 additions & 13 deletions
@@ -9,29 +9,31 @@
Fine-tuning various Chinese pre-trained models gives the following results on the CLUE dev sets:


-| Model                 | AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | CLUEWSC2020 | CSL   | C<sup>3</sup> |
-| --------------------- | ----- | ----- | ------- | ----- | ----- | ----------- | ----- | ------------- |
-| RoBERTa-wwm-ext-large | 76.20 | 59.50 | 62.10   | 84.02 | 79.15 | 90.79       | 82.03 | 75.79         |
+| Model                 | AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | CLUEWSC2020 | CSL   | CMRC2018    | CHID  | C<sup>3</sup> |
+| --------------------- | ----- | ----- | ------- | ----- | ----- | ----------- | ----- | ----------- | ----- | ------------- |
+| RoBERTa-wwm-ext-large | 75.32 | 59.33 | 61.91   | 83.87 | 78.81 | 91.78       | 81.80 | 70.67/90.61 | 85.83 | 74.90         |


-The AFQMC, TNEWS, IFLYTEK, CMNLI, OCNLI, CLUEWSC2020, CSL, and C<sup>3</sup> tasks are all evaluated with Accuracy.
-The first 7 are classification tasks and the last 1 is a reading-comprehension task; the training procedures for the two task types are introduced separately below.
+The AFQMC, TNEWS, IFLYTEK, CMNLI, OCNLI, CLUEWSC2020, CSL, CHID, and C<sup>3</sup> tasks are all evaluated with Accuracy. CMRC2018 is evaluated with EM/F1.
+The first 7 are classification tasks and the last 3 are reading-comprehension tasks; the training procedures for the two task types are introduced separately below.

**NOTE: The evaluation protocol is as follows**
1. All of the tasks above tune hyperparameters via Grid Search. Classification tasks evaluate on the dev set every 100 training steps, reading-comprehension tasks evaluate on the dev set once per epoch, and the best dev-set result is the number reported in the table.

2. Grid Search ranges for the classification tasks: batch_size: 16, 32, 64; learning rates: 1e-5, 2e-5, 3e-5, 5e-5. Because CLUEWSC2020 is sensitive to batch_size, batch_size = 8 is additionally searched when evaluating CLUEWSC2020.

-3. Grid Search ranges for the reading-comprehension tasks: batch_size: 24, 32; learning rates: 1e-5, 2e-5, 3e-5.
+3. Grid Search ranges for the reading-comprehension tasks: batch_size: 24, 32; learning rates: 1e-5, 2e-5, 3e-5. All reading-comprehension tasks are trained on multiple GPUs, and the batch_size in the Grid Search is the total batch size summed over all cards.

-4. The epoch, max_seq_length, and warmup proportion used for each task are listed in the table below
+4. The fixed hyperparameters used for each task are listed in the table below

| TASK              | AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | CLUEWSC2020 | CSL  | CMRC2018 | CHID | C<sup>3</sup> |
| ----------------- | ----- | ----- | ------- | ----- | ----- | ----------- | ---- | -------- | ---- | ------------- |
| epoch             | 3     | 3     | 3       | 2     | 5     | 50          | 5    | 2        | 3    | 8             |
| max_seq_length    | 128   | 128   | 128     | 128   | 128   | 128         | 128  | 512      | 64   | 512           |
-| warmup_proportion | 0.1   | 0.1   | 0.1     | 0.1   | 0.1   | 0.1         | 0.1  | 0.1      | 0.06 | 0.05          |
-
+| warmup_proportion | 0.1   | 0.1   | 0.1     | 0.1   | 0.1   | 0.1         | 0.1  | 0.1      | 0.06 | 0.1           |
+| num_cards         | 1     | 1     | 1       | 1     | 1     | 1           | 1    | 2        | 4    | 4             |
+| learning_rate     | 1e-5  | 3e-5  | 3e-5    | 1e-5  | 1e-5  | 1e-5        | 2e-5 | 3e-5     | 1e-5 | 2e-5          |
+| batch_size        | 32    | 32    | 32      | 16    | 16    | 16          | 16   | 32       | 24   | 24            |


## Reproduce the model results with one command
@@ -100,26 +102,29 @@ eval loss: 2.476962, acc: 0.1697, eval done total : 25.794789791107178 s
```

### Launch a CLUE reading-comprehension task
-Taking the CLUE C<sup>3</sup> task as an example, launch fine-tuning for a CLUE task as follows:
+Taking the CLUE C<sup>3</sup> task as an example, launch multi-GPU fine-tuning for a CLUE task as follows:

```shell

cd mrc

mkdir roberta-wwm-ext-large
MODEL_PATH=roberta-wwm-ext-large
-BATCH_SIZE=24
+BATCH_SIZE=6
LR=2e-5

-python -u run_c3.py \
+python -m paddle.distributed.launch --gpus "0,1,2,3" run_c3.py \
    --model_name_or_path ${MODEL_PATH} \
    --batch_size ${BATCH_SIZE} \
    --learning_rate ${LR} \
    --max_seq_length 512 \
    --num_train_epochs 8 \
-    --warmup_proportion 0.05 \
+    --do_train \
+    --warmup_proportion 0.1 \
+    --gradient_accumulation_steps 3 \

```
+Note that if GPU memory cannot hold the given `batch_size`, you can pass the `gradient_accumulation_steps` argument to simulate that `batch_size`.
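The accumulation trick is simple to picture. A minimal sketch of the idea as a hypothetical Paddle training loop (not the benchmark script itself; `model`, `loss_fct`, `optimizer`, and `data_loader` are assumed to exist):

```python
import paddle

def train_with_accumulation(model, loss_fct, optimizer, data_loader,
                            accumulation_steps=3):
    # Run `accumulation_steps` small forward/backward passes before each
    # optimizer step, so the update approximates a batch that is
    # `accumulation_steps` times larger than what fits in GPU memory.
    model.train()
    for step, (input_ids, segment_ids, labels) in enumerate(data_loader):
        logits = model(input_ids=input_ids, token_type_ids=segment_ids)
        # Scale the loss so the accumulated gradient matches the mean
        # gradient of the larger simulated batch.
        loss = loss_fct(logits, labels) / accumulation_steps
        loss.backward()  # gradients accumulate across backward() calls
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.clear_grad()
```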

## Participate in the CLUE competition

examples/benchmark/clue/mrc/run_chid.py

Lines changed: 70 additions & 26 deletions
@@ -24,11 +24,9 @@
import numpy as np

import paddle
-from paddle.metric import Accuracy
import paddle.nn as nn

from datasets import load_dataset
-
from paddlenlp.data import Pad, Stack, Tuple, Dict
from paddlenlp.transformers import AutoModelForMultipleChoice, AutoTokenizer
from paddlenlp.transformers import LinearDecayWithWarmup
@@ -140,19 +138,64 @@ def set_seed(args):
    paddle.seed(args.seed)


+def calc_global_pred_results(logits):
+    logits = np.array(logits)
+    # [num_choices, tag_size]
+    logits = np.transpose(logits)
+    tmp = []
+    for i, row in enumerate(logits):
+        for j, col in enumerate(row):
+            tmp.append((i, j, col))
+    else:
+        choice = set(range(i + 1))
+        blanks = set(range(j + 1))
+    tmp = sorted(tmp, key=lambda x: x[2], reverse=True)
+    results = []
+    for i, j, v in tmp:
+        if (j in blanks) and (i in choice):
+            results.append((i, j))
+            blanks.remove(j)
+            choice.remove(i)
+    results = sorted(results, key=lambda x: x[1], reverse=False)
+    results = [i for i, j in results]
+    return results
+
+
@paddle.no_grad()
-def evaluate(model, loss_fct, metric, data_loader):
+def evaluate(model, data_loader, do_predict=False):
    model.eval()
-    metric.reset()
+    right_num, total_num = 0, 0
+    all_results = []
    for step, batch in enumerate(data_loader):
-        input_ids, segment_ids, labels = batch
+        if do_predict:
+            input_ids, segment_ids, example_ids = batch
+        else:
+            input_ids, segment_ids, labels, example_ids = batch
        logits = model(input_ids=input_ids, token_type_ids=segment_ids)
-        loss = loss_fct(logits, labels)
-        correct = metric.compute(logits, labels)
-        metric.update(correct)
-        res = metric.accumulate()
+        batch_num = example_ids.shape[0]
+        l = 0
+        r = batch_num - 1
+        batch_results = []
+        for i in range(batch_num - 1):
+            if example_ids[i] != example_ids[i + 1]:
+                r = i
+                batch_results.extend(
+                    calc_global_pred_results(logits[l:r + 1, :]))
+                l = i + 1
+        if l <= batch_num - 1:
+            batch_results.extend(
+                calc_global_pred_results(logits[l:batch_num, :]))
+        if do_predict:
+            all_results.extend(batch_results)
+        else:
+            right_num += np.sum(np.array(batch_results) == labels.numpy())
+            total_num += labels.shape[0]
    model.train()
-    return res
+    if not do_predict:
+        acc = right_num / total_num
+        print("acc", right_num, total_num, acc)
+        return acc
+    return all_results


def run(args):
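The heart of this commit: rather than scoring each blank independently via `paddle.metric.Accuracy`, `evaluate` now groups consecutive rows that share an `example_id` (all blanks of one CHID passage) and decodes them jointly with `calc_global_pred_results`, which sorts every (candidate, blank) score and greedily assigns each idiom candidate to at most one blank. A hypothetical mini-example with made-up scores:

```python
import numpy as np

# One passage with 2 blanks and 3 shared idiom candidates. Rows are
# blanks, columns are candidate scores -- the same slice of `logits`
# that evaluate() passes in. The numbers are invented.
scores = np.array([
    [0.90, 0.75, 0.10],   # blank 0 prefers candidate 0
    [0.80, 0.30, 0.20],   # blank 1 also prefers candidate 0
])

# A per-blank argmax reuses candidate 0 for both blanks:
print(np.argmax(scores, axis=1).tolist())    # [0, 0]

# Global decoding assigns candidate 0 to blank 0 (0.90 is the highest
# score overall) and the best remaining candidate to blank 1:
print(calc_global_pred_results(scores))      # [0, 1]
```

Since the blanks within a CHID passage draw on a shared candidate list, enforcing a one-to-one assignment uses information that a per-blank argmax throws away.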
@@ -242,9 +285,14 @@ def add_tokens_for_around(tokens, pos, num_tokens):
    num_tokens = max_tokens_for_doc - 5
    num_examples = len(examples.data["candidates"])
    if do_predict:
-        result = {"input_ids": [], "token_type_ids": []}
+        result = {"input_ids": [], "token_type_ids": [], "example_ids": []}
    else:
-        result = {"input_ids": [], "token_type_ids": [], "labels": []}
+        result = {
+            "input_ids": [],
+            "token_type_ids": [],
+            "labels": [],
+            "example_ids": []
+        }
    for idx in range(num_examples):
        candidate = 0
        options = examples.data['candidates'][idx]
@@ -316,6 +364,7 @@ def add_tokens_for_around(tokens, pos, num_tokens):
        # Final shape of input_ids: [batch_size, num_choices, seq_len]
        result["input_ids"].append(new_data["input_ids"])
        result["token_type_ids"].append(new_data["token_type_ids"])
+        result["example_ids"].append(idx)
        if not do_predict:
            label = examples.data["answers"][idx]["candidate_id"][
                candidate]
@@ -350,7 +399,8 @@ def add_tokens_for_around(tokens, pos, num_tokens):
    batchify_fn = lambda samples, fn=Dict({
        'input_ids': Pad(axis=1, pad_val=tokenizer.pad_token_id),  # input
        'token_type_ids': Pad(axis=1, pad_val=tokenizer.pad_token_type_id),  # segment
-        'labels': Stack(dtype="int64")  # label
+        'labels': Stack(dtype="int64"),  # label
+        'example_ids': Stack(dtype="int64"),  # example id
    }): fn(samples)

    train_batch_sampler = paddle.io.DistributedBatchSampler(
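For reference, the `Dict`/`Pad`/`Stack` collators above batch dict-style samples into a tuple of arrays in key order; `Pad(axis=1, ...)` pads each sample's `[num_choices, seq_len]` matrix along the sequence axis. A rough, hypothetical illustration (tiny invented samples, pad ids hard-coded to 0):

```python
from paddlenlp.data import Dict, Pad, Stack

samples = [
    {"input_ids": [[1, 2, 3], [4, 5, 6]],     # 2 choices, seq_len 3
     "token_type_ids": [[0, 0, 0], [0, 0, 0]],
     "labels": 1, "example_ids": 0},
    {"input_ids": [[7, 8], [9, 10]],          # 2 choices, seq_len 2
     "token_type_ids": [[0, 0], [0, 0]],
     "labels": 0, "example_ids": 1},
]

batchify_fn = Dict({
    "input_ids": Pad(axis=1, pad_val=0),      # pad seq_len up to 3
    "token_type_ids": Pad(axis=1, pad_val=0),
    "labels": Stack(dtype="int64"),
    "example_ids": Stack(dtype="int64"),
})

input_ids, token_type_ids, labels, example_ids = batchify_fn(samples)
print(input_ids.shape)       # (2, 2, 3): [batch_size, num_choices, seq_len]
print(example_ids.tolist())  # [0, 1] -- lets evaluate() find passage bounds
```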
@@ -397,15 +447,14 @@ def add_tokens_for_around(tokens, pos, num_tokens):
        grad_clip=grad_clip)

    loss_fct = nn.CrossEntropyLoss()
-    metric = Accuracy()

    model.train()
    global_step = 0
    best_acc = 0.0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        for step, batch in enumerate(train_data_loader):
-            input_ids, segment_ids, labels = batch
+            input_ids, segment_ids, labels, example_ids = batch
            logits = model(input_ids=input_ids, token_type_ids=segment_ids)
            loss = loss_fct(logits, labels)
            if args.gradient_accumulation_steps > 1:
@@ -424,7 +473,7 @@ def add_tokens_for_around(tokens, pos, num_tokens):
                       args.logging_steps / (time.time() - tic_train)))
                tic_train = time.time()
                tic_eval = time.time()
-                acc = evaluate(model, loss_fct, metric, dev_data_loader)
+                acc = evaluate(model, dev_data_loader)
                print("eval acc: %.5f, eval done total : %s s" %
                      (acc, time.time() - tic_eval))
                if paddle.distributed.get_rank() == 0 and acc > best_acc:
@@ -445,13 +494,13 @@ def add_tokens_for_around(tokens, pos, num_tokens):
            batch_size=len(test_ds),
            remove_columns=column_names,
            num_proc=1)
-
        test_batch_sampler = paddle.io.BatchSampler(
            test_ds, batch_size=args.eval_batch_size, shuffle=False)

        batchify_fn = lambda samples, fn=Dict({
            'input_ids': Pad(axis=1, pad_val=tokenizer.pad_token_id),  # input
            'token_type_ids': Pad(axis=1, pad_val=tokenizer.pad_token_type_id),  # segment
+            'example_ids': Stack(dtype="int64"),  # example id
        }): fn(samples)

        test_data_loader = paddle.io.DataLoader(
@@ -462,15 +511,10 @@ def add_tokens_for_around(tokens, pos, num_tokens):

        result = {}
        idx = 623377
-        for step, batch in enumerate(test_data_loader):
-            input_ids, segment_ids = batch
-            with paddle.no_grad():
-                logits = model(input_ids, segment_ids)
-            preds = paddle.argmax(logits, axis=1).numpy().tolist()
-            for pred in preds:
-                result["#idiom" + str(idx)] = pred
-                idx += 1
-
+        preds = evaluate(model, test_data_loader, do_predict=True)
+        for pred in preds:
+            result["#idiom" + str(idx)] = pred
+            idx += 1
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        with open(
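The prediction branch now reuses `evaluate(..., do_predict=True)` instead of the hand-rolled argmax loop, then writes the CHID submission mapping keyed by `#idiom` ids counting up from 623377. Roughly, with invented prediction values:

```python
# Hypothetical predictions for the first three test blanks.
preds = [3, 0, 7]

result = {}
idx = 623377  # id of the first CHID test blank, as in the script
for pred in preds:
    result["#idiom" + str(idx)] = pred
    idx += 1
# result == {'#idiom623377': 3, '#idiom623378': 0, '#idiom623379': 7}
```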

examples/benchmark/clue/mrc/run_cmrc.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def parse_args():
    parser.add_argument(
        "--max_query_length", type=int, default=64, help="Max query length.")
    parser.add_argument(
-        "--max_answer_length", type=int, default=30, help="Max answer length.")
+        "--max_answer_length", type=int, default=50, help="Max answer length.")
    parser.add_argument(
        "--do_lower_case",
        action='store_false',
@@ -184,12 +184,15 @@ def evaluate(model, raw_dataset, data_loader, args, do_eval=True):
        raw_dataset, data_loader.dataset, (all_start_logits, all_end_logits),
        False, args.n_best_size, args.max_answer_length)

-    # Can also write all_nbest_json and scores_diff_json files if needed
-    if args.do_predict:
-        with open('cmrc2018_predict.json', "w", encoding='utf-8') as writer:
-            writer.write(
-                json.dumps(
-                    all_predictions, ensure_ascii=False, indent=4) + "\n")
+    mode = 'validation' if do_eval else 'test'
+    if do_eval:
+        filename = 'prediction_validation.json'
+    else:
+        filename = 'cmrc2018_predict.json'
+    with open(filename, "w", encoding='utf-8') as writer:
+        writer.write(
+            json.dumps(
+                all_predictions, ensure_ascii=False, indent=4) + "\n")
    if do_eval:
        squad_evaluate(
            examples=[raw_data for raw_data in raw_dataset],
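For context, `squad_evaluate` produces the EM/F1 pair reported in the README's CMRC2018 cell. A simplified, hypothetical sketch of the two metrics (single reference, character-level F1, no normalization; helper names are invented and this is not the library's implementation):

```python
def exact_match(pred: str, gold: str) -> float:
    # EM: 1 if the prediction equals the reference exactly.
    return float(pred.strip() == gold.strip())

def char_f1(pred: str, gold: str) -> float:
    # F1 over the multiset intersection of characters, which suits
    # Chinese answers where there is no whitespace tokenization.
    gold_chars = list(gold)
    common = 0
    for ch in pred:
        if ch in gold_chars:
            gold_chars.remove(ch)
            common += 1
    if common == 0:
        return 0.0
    precision = common / len(pred)
    recall = common / len(gold)
    return 2 * precision * recall / (precision + recall)

print(exact_match("北京大学", "北京大学"))          # 1.0
print(round(char_f1("北京大学堂", "北京大学"), 3))  # 0.889
```

Raising `max_answer_length` from 30 to 50 lets the n-best decoding keep longer spans, which mainly affects recall on long CMRC2018 answers.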
