Commit 6f5c287

[speech-cmd] use uie in transformers (#4278)
1 parent 91ab453 commit 6f5c287

3 files changed: 16 additions, 86 deletions

applications/speech_cmd_analysis/finetune.py

Lines changed: 6 additions & 18 deletions
@@ -13,18 +13,16 @@
 # limitations under the License.
 
 import argparse
-import time
 import os
+import time
 from functools import partial
 
 import paddle
-from paddle.utils.download import get_path_from_url
+from utils import convert_example, create_dataloader, evaluate, reader, set_seed
+
 from paddlenlp.datasets import load_dataset
-from paddlenlp.transformers import AutoTokenizer
 from paddlenlp.metrics import SpanEvaluator
-
-from model import UIE
-from utils import set_seed, convert_example, reader, MODEL_MAP, evaluate, create_dataloader
+from paddlenlp.transformers import UIE, AutoTokenizer
 
 
 def do_train():
@@ -35,15 +33,7 @@ def do_train():
 
     set_seed(args.seed)
 
-    encoding_model = MODEL_MAP[args.model]["encoding_model"]
-    resource_file_urls = MODEL_MAP[args.model]["resource_file_urls"]
-
-    for key, val in resource_file_urls.items():
-        file_path = os.path.join(args.model, key)
-        if not os.path.exists(file_path):
-            get_path_from_url(val, args.model)
-
-    tokenizer = AutoTokenizer.from_pretrained(encoding_model)
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
     model = UIE.from_pretrained(args.model)
 
     if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt):
@@ -71,7 +61,6 @@ def do_train():
 
     loss_list = []
     global_step = 0
-    best_step = 0
     best_f1 = 0
     tic_train = time.time()
     for epoch in range(1, args.num_epochs + 1):
@@ -123,8 +112,7 @@ def do_train():
     parser.add_argument("--train_path", default=None, type=str, help="The path of train set.")
     parser.add_argument("--dev_path", default=None, type=str, help="The path of dev set.")
    parser.add_argument("--save_dir", default='./checkpoint', type=str, help="The output directory where the model checkpoints will be written.")
-    parser.add_argument("--max_seq_len", default=512, type=int, help="The maximum input sequence length. "
-        "Sequences longer than this will be truncated, sequences shorter will be padded.")
+    parser.add_argument("--max_seq_len", default=512, type=int, help="The maximum input sequence length. ")
     parser.add_argument("--num_epochs", default=100, type=int, help="Total number of training epochs to perform.")
     parser.add_argument("--seed", default=1000, type=int, help="Random seed for initialization")
     parser.add_argument("--logging_steps", default=10, type=int, help="The interval steps to logging.")

applications/speech_cmd_analysis/model.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

applications/speech_cmd_analysis/utils.py

Lines changed: 10 additions & 29 deletions
@@ -14,35 +14,16 @@
 # limitations under the License.
 
 import json
-import time
 import math
 import random
-import numpy as np
-from tqdm import tqdm
-
-from urllib.request import urlopen
-from urllib.request import Request
+import time
 from urllib.error import URLError
 from urllib.parse import urlencode
+from urllib.request import Request, urlopen
 
+import numpy as np
 import paddle
-
-MODEL_MAP = {
-    "uie-base": {
-        "encoding_model": "ernie-3.0-base-zh",
-        "resource_file_urls": {
-            "model_state.pdparams": "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_state.pdparams",
-            "model_config.json": "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
-        },
-    },
-    "uie-tiny": {
-        "encoding_model": "ernie-3.0-medium-zh",
-        "resource_file_urls": {
-            "model_state.pdparams": "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_state.pdparams",
-            "model_config.json": "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_config.json",
-        },
-    },
-}
+from tqdm import tqdm
 
 
 def set_seed(seed):
@@ -83,12 +64,12 @@ def mandarin_asr_api(api_key, secret_key, audio_file, audio_format="wav"):
         result_str = urlopen(request).read()
     except URLError as error:
         print("token http response http code : " + str(error.code))
-        result_str = err.read()
+        result_str = error.read()
     result_str = result_str.decode()
 
     result = json.loads(result_str)
     if "access_token" in result.keys() and "scope" in result.keys():
-        if SCOPE and (not SCOPE in result["scope"].split(" ")):
+        if SCOPE and (SCOPE not in result["scope"].split(" ")):
             raise ASRError("scope is not correct!")
         token = result["access_token"]
     else:
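
Beyond the cosmetic print cleanups in the later hunks, this hunk fixes a real bug in mandarin_asr_api: the except block read the response body from an undefined name err instead of the bound exception error, and the scope check now uses the idiomatic not in. A self-contained sketch of the corrected pattern, with the token URL and parameters as placeholders rather than the values the application actually uses:

    from urllib.error import URLError
    from urllib.parse import urlencode
    from urllib.request import Request, urlopen

    def fetch_token(token_url, params):
        # Placeholder request construction; the original's request setup is outside this hunk.
        request = Request(token_url, urlencode(params).encode("utf-8"))
        try:
            result_str = urlopen(request).read()
        except URLError as error:
            # Mirror the fixed handler: use the bound name `error`, not the undefined `err`.
            # As in the original, this assumes the failure is an HTTPError that carries
            # a status code and a readable body.
            print("token http response http code : " + str(error.code))
            result_str = error.read()
        return result_str.decode()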
@@ -319,7 +300,7 @@ def convert_ext_examples(raw_examples, negative_ratio):
     entity_name_set = []
     predicate_set = []
 
-    print(f"Converting doccano data...")
+    print("Converting doccano data...")
     with tqdm(total=len(raw_examples)) as pbar:
         for line in raw_examples:
             items = json.loads(line)
@@ -402,13 +383,13 @@ def convert_ext_examples(raw_examples, negative_ratio):
             relation_prompts.append(relation_prompt)
             pbar.update(1)
 
-    print(f"Adding negative samples for first stage prompt...")
+    print("Adding negative samples for first stage prompt...")
     entity_examples = add_negative_example(entity_examples, texts, entity_prompts, entity_label_set, negative_ratio)
     if len(predicate_set) != 0:
-        print(f"Constructing relation prompts...")
+        print("Constructing relation prompts...")
         relation_prompt_set = construct_relation_prompt_set(entity_name_set, predicate_set)
 
-        print(f"Adding negative samples for second stage prompt...")
+        print("Adding negative samples for second stage prompt...")
         relation_examples = add_negative_example(
             relation_examples, texts, relation_prompts, relation_prompt_set, negative_ratio
         )
