
Commit 24eceed

Author: yunfan
Commit message: fix some typo
1 parent: c9238aa

4 files changed
Lines changed: 12 additions & 8 deletions

README.md
Lines changed: 4 additions & 3 deletions
````diff
@@ -47,7 +47,8 @@ Then, use the PTMs as the following example, where `MODEL_NAME` is the correspon
 
 For CPT:
 ```python
-from modeling_cpt import BertTokenizer, CPTForConditionalGeneration
+from modeling_cpt import CPTForConditionalGeneration
+from transformers import BertTokenizer
 tokenizer = BertTokenizer.from_pretrained("MODEL_NAME")
 model = CPTForConditionalGeneration.from_pretrained("MODEL_NAME")
 print(model)
````
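This hunk repairs the import path: `BertTokenizer` is provided by the `transformers` package, while the repo-local `modeling_cpt` presumably defines only the CPT model classes, so the old combined import would fail.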
````diff
@@ -63,9 +64,9 @@ print(model)
 
 After initializing the model, you can use the following lines to generate text.
 ```python
->>> inputs = tokenizer.encode("北京是[MASK]的首都", return_tensors='pt')
+>>> input_ids = tokenizer.encode("北京是[MASK]的首都", return_tensors='pt')
 >>> pred_ids = model.generate(input_ids, num_beams=4, max_length=20)
->>> print(tokenizer.convert_ids_to_tokens(pred_ids[i]))
+>>> print(tokenizer.convert_ids_to_tokens(pred_ids[0]))
 ['[SEP]', '[CLS]', '北', '京', '是', '中', '国', '的', '首', '都', '[SEP]']
 ```
 
````
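Assembled from the two hunks above, the corrected README example reads as follows; a minimal sketch, with `MODEL_NAME` filled in with `fnlp/cpt-base` (the checkpoint the mrc fine-tuning README uses) purely for concreteness:

```python
# Corrected usage from the two README hunks above; requires modeling_cpt.py
# from this repo on the import path. "fnlp/cpt-base" is borrowed from the
# mrc fine-tuning README and stands in for any CPT checkpoint.
from modeling_cpt import CPTForConditionalGeneration
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("fnlp/cpt-base")
model = CPTForConditionalGeneration.from_pretrained("fnlp/cpt-base")

# Encode a masked sentence, generate with beam search, decode the best beam.
input_ids = tokenizer.encode("北京是[MASK]的首都", return_tensors='pt')
pred_ids = model.generate(input_ids, num_beams=4, max_length=20)
print(tokenizer.convert_ids_to_tokens(pred_ids[0]))
```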

finetune/README.md
Lines changed: 4 additions & 1 deletion
```diff
@@ -5,6 +5,8 @@ This repo contains the fine-tuning code for CPT on multiple NLU and NLG tasks, s
 ## Requirements
 - pytorch==1.8.1
 - transformers==4.4.1
+- fitlog
+- fastNLP
 
 ## Run
 The code and running examples are listed in the corresponding folders of the fine-tuning tasks.
```
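Neither new requirement is version-pinned; assuming the standard PyPI distributions, `pip install fitlog fastNLP` covers both (fitlog is the experiment-tracking companion of fastNLP, presumably imported by the fine-tuning scripts, hence the new entries).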
````diff
@@ -18,7 +20,8 @@ The code and running examples are listed in the corresponding folders of the fin
 You can also fine-tune CPT on other tasks by adding `modeling_cpt.py` into your project and using the following code.
 
 ```python
-from modeling_cpt import BertTokenizer, CPTForConditionalGeneration
+from modeling_cpt import CPTForConditionalGeneration
+from transformers import BertTokenizer
 tokenizer = BertTokenizer.from_pretrained("MODEL_NAME")
 model = CPTForConditionalGeneration.from_pretrained("MODEL_NAME")
 print(model)
````

finetune/mrc/README.md
Lines changed: 2 additions & 2 deletions
````diff
@@ -9,7 +9,7 @@ To train and evaluate **CPT$_u$**, **CPT$_g$** and **CPT$_{ug}$**, run the pytho
 ```bash
 export MODEL_TYPE=cpt-base
 export MODEL_NAME=fnlp/cpt-base
-export CLUE_DATA_DIR=/path/to/mrc_data_dir
+export CLUE_DATA_DIR=~/workdir/datasets/CLUEdatasets/
 export TASK_NAME=drcd
 export CLS_MODE=1
 python run_mrc.py \
````
```diff
@@ -22,7 +22,7 @@ python run_mrc.py \
 --gradient_accumulation_steps 4 \
 --lr=3e-5 \
 --dropout=0.2 \
---CLS_MODE=$CLS_MODE \
+--cls_mode=$CLS_MODE \
 --warmup_rate=0.1 \
 --weight_decay_rate=0.01 \
 --max_seq_length=512 \
```
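The second hunk is a real fix rather than cosmetics: argparse option strings are case-sensitive, so `--CLS_MODE` would not match a flag declared as `--cls_mode` (which the run_mrc.py hunk below implies via `args.cls_mode`). A minimal sketch of the failure mode, with a hypothetical parser standing in for the script's real one:

```python
import argparse

# Hypothetical stand-in for run_mrc.py's real argument parser.
parser = argparse.ArgumentParser()
parser.add_argument("--cls_mode", type=int, default=1)

print(parser.parse_args(["--cls_mode=2"]).cls_mode)  # prints 2
# parser.parse_args(["--CLS_MODE=2"])  # exits: error: unrecognized arguments
```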

finetune/mrc/run_mrc.py
Lines changed: 2 additions & 2 deletions
```diff
@@ -210,7 +210,7 @@ def test(model, args, eval_examples, eval_features, device, name):
     torch.distributed.barrier()
 
     # load the bert setting
-    if 'bert' == args.model_type:
+    if 'bert' in args.model_type or 'cpt' in args.model_type:
         if 'large' in args.init_restore_dir or '24' in args.init_restore_dir:
             config_path = 'hfl/chinese-roberta-wwm-ext-large'
         else:
```
```diff
@@ -219,7 +219,7 @@ def test(model, args, eval_examples, eval_features, device, name):
     tokenizer = BertTokenizer.from_pretrained(config_path)
     bert_config.hidden_dropout_prob = args.dropout
     bert_config.attention_probs_dropout_prob = args.dropout
-    if 'arch' in args.init_restore_dir:
+    if 'cpt' in args.init_restore_dir:
         config = CPTConfig.from_pretrained(args.init_restore_dir)
         config.cls_mode = args.cls_mode
         config.attention_dropout = args.dropout
```
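A quick illustration of why the old equality test never fired for the CPT configuration in the mrc README (`MODEL_TYPE=cpt-base`), and why the substring check does:

```python
# MODEL_TYPE value taken from finetune/mrc/README.md above.
model_type = "cpt-base"

print('bert' == model_type)                         # False: old check skipped CPT
print('bert' in model_type or 'cpt' in model_type)  # True: new check matches
```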
