Commit 7a5b002

More accurate score (#1940)
* match tokens to get accurate score
* fix bugs
* fix code style
1 parent 73d5318 commit 7a5b002

File tree

20 files changed: +690 -377 lines changed

examples/model_interpretation/evaluation/accuracy/cal_acc.py

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ def cal_acc(golden_label, pred_label):
     """
     acc = 0.0
     for ids in pred_label:
+        if ids not in golden_label:
+            continue
         if pred_label[ids] == golden_label[ids]:
             acc += 1
     if len(golden_label):
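
For reference, a minimal, runnable sketch of the patched metric: the new guard skips predicted ids that have no golden counterpart instead of letting the lookup raise a KeyError. The hunk ends at the if len(golden_label): line, so the final division by the golden-label count below is an assumption about the unchanged remainder of the function.

def cal_acc(golden_label, pred_label):
    """Accuracy over predicted ids that also appear in the golden labels (sketch)."""
    acc = 0.0
    for ids in pred_label:
        if ids not in golden_label:  # new guard: ignore predictions with no golden entry
            continue
        if pred_label[ids] == golden_label[ids]:
            acc += 1
    if len(golden_label):
        return acc / len(golden_label)  # assumed denominator, not shown in the hunk
    return 0.0

# Example: the prediction for 'c' has no golden label and is simply skipped.
golden = {'a': 1, 'b': 0}
pred = {'a': 1, 'b': 1, 'c': 0}
print(cal_acc(golden, pred))  # 0.5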
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
[New file: a punctuation list, 82 lines, one character per line. The recoverable characters are:
, ] ! ( ÷ ˊ . _ @ ~ | ' × · ° > ; " / < + ^ ? [ * : ) = - \ % & ]
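
Later hunks in this commit pass a punctuation path into CharTokenizer (CharTokenizer(vocab, args.language, '../punctuations')). As a rough, hypothetical illustration only (load_punctuations and tokenize below are not the repo's implementation), such a list can be used to peel punctuation off into standalone tokens so that model tokens and golden rationale tokens line up one-to-one when scored:

# Hypothetical sketch: treat every character listed in the punctuation file
# as its own token when splitting whitespace-separated text.
def load_punctuations(path):
    with open(path, encoding='utf8') as f:
        return {line.rstrip('\n') for line in f if line.rstrip('\n')}

def tokenize(text, puncts):
    tokens = []
    for word in text.split():
        buf = ''
        for ch in word:
            if ch in puncts:
                if buf:
                    tokens.append(buf)
                buf = ''
                tokens.append(ch)  # punctuation becomes a standalone token
            else:
                buf += ch
        if buf:
            tokens.append(buf)
    return tokens

print(tokenize('good, not bad!', {',', '!'}))
# ['good', ',', 'not', 'bad', '!']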

examples/model_interpretation/rationale_extraction/generate_evaluation_data.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def r_data_generation(args, evids, text_dict_list, text_exclusive_dict_list,
             'context_idx']
         if len(temp['rationale']) > 1 and \
             args.inter_mode != 'lime' and \
-            not (args.language == 'en' and args.base_model.startswith('roberta')):
+            not (args.base_model.startswith('roberta')):
             for i in range(len(temp['rationale'][1])):
                 temp['rationale'][1][i] -= len(temp['rationale'][0]) + len(temp[
                     'no_rationale'][0])
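
The surviving loop rebases the second rationale segment's token indices by the number of tokens recorded before it. A toy version of that arithmetic (the layout of temp['rationale'] and temp['no_rationale'] is inferred from the code, not documented here):

# Toy example: shift second-segment indices down by the tokens that precede it.
rationale = [[0, 1], [7, 9]]         # indices in the concatenated sequence
no_rationale = [[2, 3, 4, 5, 6]]
offset = len(rationale[0]) + len(no_rationale[0])   # 2 + 5 = 7
rationale[1] = [i - offset for i in rationale[1]]
print(rationale[1])  # [0, 2] -> indices local to the second segment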

examples/model_interpretation/rationale_extraction/run_2_pred_senti_per.sh

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ do
         CKPT=../task/${TASK}/pretrained_models/saved_model_ch/roberta_large_20220318_170123/model_900/model_state.pdparams
         #CKPT=../../../${TASK}/pretrained_models/saved_model_ch/roberta_large_20211207_143351/model_900/model_state.pdparams
     elif [[ $BASE_MODEL == "lstm" ]]; then
-        VOCAB_PATH=../task/${TASK}/rnn
+        VOCAB_PATH=../task/${TASK}/rnn/vocab.txt
         CKPT=../task/${TASK}/rnn/checkpoints_ch/final.pdparams
     fi
 fi

examples/model_interpretation/rationale_extraction/sentiment_pred.py

Lines changed: 11 additions & 31 deletions
@@ -182,20 +182,10 @@ def truncate_offset(seg, start_offset, end_offset):


 def init_lstm_var(args):
-    #different language has different tokenizer
-    if args.language == "ch":
-        tokenizer = ErnieTokenizer.from_pretrained(args.vocab_path)
-        padding_idx = tokenizer.vocab.get('[PAD]')
-        tokenizer.inverse_vocab = [
-            item[0]
-            for item in sorted(
-                tokenizer.vocab.items(), key=lambda x: x[1])
-        ]
-    else:
-        vocab = Vocab.load_vocabulary(
-            args.vocab_path, unk_token='[UNK]', pad_token='[PAD]')
-        tokenizer = CharTokenizer(vocab)
-        padding_idx = vocab.token_to_idx.get('[PAD]', 0)
+    vocab = Vocab.load_vocabulary(
+        args.vocab_path, unk_token='[UNK]', pad_token='[PAD]')
+    tokenizer = CharTokenizer(vocab, args.language, '../punctuations')
+    padding_idx = vocab.token_to_idx.get('[PAD]', 0)

     trans_fn = partial(
         convert_example,

@@ -299,23 +289,13 @@ def init_roberta_var(args):
                 input_ids[0, 1:-1].tolist())  # list

         elif args.base_model == 'lstm':
-            if args.language == 'ch':
-                input_ids, seq_lens = d
-                #input_ids = paddle.to_tensor([input_ids[0][0]])
-                fwd_args = [input_ids, seq_lens]
-                fwd_kwargs = {}
-                tokens = [
-                    tokenizer.inverse_vocab[input_id]
-                    for input_id in input_ids.tolist()[0]
-                ]
-            else:
-                input_ids, seq_lens = d
-                fwd_args = [input_ids, seq_lens]
-                fwd_kwargs = {}
-                tokens = [
-                    tokenizer.vocab.idx_to_token[input_id]
-                    for input_id in input_ids.tolist()[0]
-                ]
+            input_ids, seq_lens = d
+            fwd_args = [input_ids, seq_lens]
+            fwd_kwargs = {}
+            tokens = [
+                tokenizer.vocab.idx_to_token[input_id]
+                for input_id in input_ids.tolist()[0]
+            ]

         result['id'] = dataloader.dataset.data[step]['id']
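
After this change both languages take the same path in the LSTM branch: ids are mapped back to tokens through the vocab's idx_to_token table instead of a per-language inverse_vocab. A self-contained toy version of that lookup (a plain dict stands in for the paddlenlp Vocab object):

# Toy stand-in for the unified id -> token lookup in the LSTM branch.
idx_to_token = {0: '[PAD]', 1: '[UNK]', 2: 'good', 3: 'movie'}
input_ids = [[2, 3, 0, 0]]  # one padded sequence, as input_ids.tolist() would yield
tokens = [idx_to_token[input_id] for input_id in input_ids[0]]
print(tokens)  # ['good', 'movie', '[PAD]', '[PAD]']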

examples/model_interpretation/rationale_extraction/similarity_pred.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def init_lstm_var(args):
         unk_token='[UNK]',
         pad_token='[PAD]')

-    tokenizer = CharTokenizer(vocab, language=args.language)
+    tokenizer = CharTokenizer(vocab, args.language, '../punctuations')
     model = SimNet(network='lstm', vocab_size=len(vocab), num_classes=2)

     dev_ds = SimilarityData().read(os.path.join(args.data_dir, 'dev'))

examples/model_interpretation/task/mrc/saliency_map/rc_finetune.py

Lines changed: 11 additions & 5 deletions
@@ -150,7 +150,17 @@ def map_fn_DuCheckList_finetune(examples):
             # Start/end character index of the answer in the text.
             start_char = answer_starts[0]
             end_char = start_char + len(answers[0])
-            if args.language == 'ch':
+            if args.language == 'en':
+                # Start token index of the current span in the text.
+                token_start_index = 0
+                while not (offsets[token_start_index] ==
+                           (0, 0) and offsets[token_start_index + 1] == (0, 0)):
+                    token_start_index += 1
+                token_start_index += 2
+
+                # End token index of the current span in the text.
+                token_end_index = len(input_ids) - 2
+            else:
                 # Start token index of the current span in the text.
                 token_start_index = 0
                 while sequence_ids[token_start_index] != 1:

@@ -160,10 +170,6 @@ def map_fn_DuCheckList_finetune(examples):
                 token_end_index = len(input_ids) - 2
                 while sequence_ids[token_end_index] != 1:
                     token_end_index -= 1
-            else:
-
-                token_start_index = tokenized_example['context_start_id']
-                token_end_index = tokenized_example['context_end_id']

             # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
             if not (offsets[token_start_index][0] <= start_char and
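
The new English branch locates the start of the context by scanning the offset mapping for the first pair of consecutive (0, 0) entries and skipping past them. The layout assumed below (special tokens mapped to (0, 0), question tokens before context tokens) is an illustration consistent with that loop, not taken from the repo:

# Toy illustration of the English-branch scan over an assumed offsets layout.
offsets = [(0, 0),          # leading special token
           (0, 4), (5, 9),  # question tokens
           (0, 0), (0, 0),  # two special tokens separating question and context
           (0, 3), (4, 8),  # context tokens
           (0, 0)]          # trailing special token
token_start_index = 0
while not (offsets[token_start_index] == (0, 0)
           and offsets[token_start_index + 1] == (0, 0)):
    token_start_index += 1
token_start_index += 2       # skip the two (0, 0) separators
print(token_start_index)     # 5 -> index of the first context token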
