Skip to content

Commit fd52aff

Browse files
authored
Merge branch 'develop' into deploy
2 parents a1e7dff + a6e56e3 commit fd52aff

File tree

15 files changed

+40
-28
lines changed

15 files changed

+40
-28
lines changed

examples/model_interpretation/README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,17 @@ pip3 install paddlepaddle-gpu
198198
sample_type: 数据的类型,分为原始数据(ori)和扰动数据(disturb);
199199
rel_ids:与原始数据关联的扰动数据的id列表(只有原始数据有);
200200
## 模型运行
201-
### 证据抽取运行
201+
### 模型预测
202202

203203
model_interpretation/task/{task}/run_inter_all.sh (生成所有结果)
204204
model_interpretation/task/{task}/run_inter.sh (生成单个配置的结果,配置可以选择不同的评估模型,以及不同的证据抽取方法、语言)
205205

206206
(注:{task}可取值为["senti","similarity","mrc"],其中senti代表情感分析,similarity代表相似度计算,mrc代表阅读理解)
207+
208+
### 证据抽取:
209+
cd model_interpretation/rationale_extraction
210+
./generate.sh
211+
207212
### 可解释评估:
208213
#### 合理性(plausibility):
209214
model_interpretation/evaluation/plausibility/run_f1.sh

examples/model_interpretation/rationale_extraction/run_2_pred_senti_per.sh

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ do
2626
#CKPT=../../../${TASK}/pretrained_models/saved_model_en/roberta_large_20211207_174631/model_4000/model_state.pdparams
2727
elif [[ $BASE_MODEL == "lstm" ]]; then
2828
VOCAB_PATH=../task/${TASK}/rnn/vocab.sst2_train
29-
CKPT=../task/${TASK}/rnn/checkpoints_en_ori/final.pdparams
30-
#CKPT=../../../${TASK}/rnn/checkpoints_en/final.pdparams
29+
CKPT=../task/${TASK}/rnn/checkpoints_en/final.pdparams
3130
fi
3231

3332
elif [[ $LANGUAGE == "ch" ]]; then
@@ -42,8 +41,7 @@ do
4241
#CKPT=../../../${TASK}/pretrained_models/saved_model_ch/roberta_large_20211207_143351/model_900/model_state.pdparams
4342
elif [[ $BASE_MODEL == "lstm" ]]; then
4443
VOCAB_PATH=../task/${TASK}/rnn
45-
CKPT=../task/${TASK}/rnn/checkpoints_ch_ori/final.pdparams
46-
#CKPT=../../../${TASK}/rnn/checkpoints_ch/final.pdparams
44+
CKPT=../task/${TASK}/rnn/checkpoints_ch/final.pdparams
4745
fi
4846
fi
4947

examples/model_interpretation/task/mrc/run_train_rc.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ export CUDA_VISIBLE_DEVICES=7
66
export PYTHONPATH=.:$PYTHONPATH
77

88
LANGUAGE=ch # LANGUAGE choose in [ch, en]
9-
BASE_MODEL=roberta_large # choices [roberta_base, roberta_large]
9+
BASE_MODEL=roberta_base # choices [roberta_base, roberta_large]
10+
11+
[ -d "logs" ] || mkdir -p "logs"
12+
set -x
1013

1114
if [[ $LANGUAGE == "ch" ]]; then
1215
if [[ $BASE_MODEL == "roberta_base" ]]; then

examples/model_interpretation/task/mrc/saliency_map/rc_finetune.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,10 @@ def map_fn_DuCheckList_finetune(examples):
298298
if step % 1000 == 0:
299299
if args.save_dir is not None:
300300
paddle.save(model.state_dict(),
301-
args.save_dir / 'ckpt.bin')
301+
os.path.join(args.save_dir, 'ckpt.bin'))
302302
log.debug('save model!')
303303

304304
if args.save_dir is not None:
305-
paddle.save(model.state_dict(), args.save_dir / 'ckpt.bin')
305+
paddle.save(model.state_dict(),
306+
os.path.join(args.save_dir, 'ckpt.bin'))
306307
log.debug('save model!')

examples/model_interpretation/task/senti/pretrained_models/run_train.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ elif [[ $LANGUAGE == "en" ]]; then
1616
MAX_SEQ_LENGTH=512
1717
fi
1818

19+
[ -d "logs" ] || mkdir -p "logs"
20+
set -x
21+
1922
python3 ./train.py \
2023
--learning_rate ${LEARNING_RATE} \
2124
--max_seq_length ${MAX_SEQ_LENGTH} \

examples/model_interpretation/task/senti/pretrained_models/train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
from paddlenlp.transformers import LinearDecayWithWarmup
3232
from paddlenlp.transformers.roberta.tokenizer import RobertaTokenizer, RobertaBPETokenizer
3333
sys.path.append('..')
34+
sys.path.append('../../..')
3435
from roberta.modeling import RobertaForSequenceClassification
36+
sys.path.remove('../../..')
3537
sys.path.remove('..')
3638
from utils import convert_example
3739

examples/model_interpretation/task/senti/rnn/lstm_train.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ unset CUDA_VISIBLE_DEVICES
66
LANGUAGE=en
77

88
if [[ $LANGUAGE == 'ch' ]]; then
9-
VOCAB_PATH='./vocab.txt'
9+
VOCAB_PATH='./'
1010
else
1111
VOCAB_PATH='vocab.sst2_train'
1212
fi

examples/model_interpretation/task/senti/rnn/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from model import LSTMModel, SelfInteractiveAttention, BiLSTMAttentionModel
2727
from utils import convert_example, CharTokenizer
28-
from paddlenlp.transformers.ernie.tokenizer import ErnieTokenizer
28+
from ernie.tokenizing_ernie import ErnieTokenizer
2929

3030
# yapf: disable
3131
parser = argparse.ArgumentParser(__doc__)

examples/model_interpretation/task/senti/run_inter.sh

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ if [[ $LANGUAGE == "en" ]]; then
2727
#CKPT=pretrained_models/saved_model_en/roberta_large_20211207_174631/model_4000/model_state.pdparams
2828
elif [[ $BASE_MODEL == "lstm" ]]; then
2929
VOCAB_PATH='rnn/vocab.sst2_train'
30-
CKPT=rnn/checkpoints_en_ori/final.pdparams
31-
#CKPT=rnn/checkpoints_en/final.pdparams
30+
CKPT=rnn/checkpoints_en/final.pdparams
3231
fi
3332

3433
elif [[ $LANGUAGE == "ch" ]]; then
@@ -43,8 +42,7 @@ elif [[ $LANGUAGE == "ch" ]]; then
4342
#CKPT=pretrained_models/saved_model_ch/roberta_large_20211229_105019/model_900/model_state.pdparams
4443
elif [[ $BASE_MODEL == "lstm" ]]; then
4544
VOCAB_PATH='rnn'
46-
CKPT=rnn/checkpoints_ch_ori/final.pdparams
47-
#CKPT=rnn/checkpoints_ch/final.pdparams
45+
CKPT=rnn/checkpoints_ch/final.pdparams
4846
fi
4947
fi
5048

examples/model_interpretation/task/senti/run_inter_all.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ do
3030
#CKPT=pretrained_models/saved_model_en/roberta_large_20211207_174631/model_4000/model_state.pdparams
3131
elif [[ $BASE_MODEL == "lstm" ]]; then
3232
VOCAB_PATH='rnn/vocab.sst2_train'
33-
CKPT=rnn/checkpoints_en_ori/final.pdparams
33+
CKPT=rnn/checkpoints_en/final.pdparams
3434
#CKPT=rnn/checkpoints_en/final.pdparams
3535
fi
3636

@@ -46,7 +46,7 @@ do
4646
#CKPT=pretrained_models/saved_model_ch/roberta_large_20211229_105019/model_900/model_state.pdparams
4747
elif [[ $BASE_MODEL == "lstm" ]]; then
4848
VOCAB_PATH='rnn'
49-
CKPT=rnn/checkpoints_ch_ori/final.pdparams
49+
CKPT=rnn/checkpoints_ch/final.pdparams
5050
#CKPT=rnn/checkpoints_ch/final.pdparams
5151
fi
5252
fi

0 commit comments

Comments
 (0)