Skip to content

Commit 873fa29

Browse files
committed
add new bart and xlnet models
1 parent b6a7783 commit 873fa29

File tree

25 files changed

+56509
-0
lines changed

25 files changed

+56509
-0
lines changed

graph_net/test/nlp_model_getter.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,50 @@ def get_skep_model_and_inputs(model_name, text, dtype):
107107
tokenizer = TokenizerClass.from_pretrained(model_name)
108108
inputs = tokenizer(text, return_tensors="pd")
109109
return model, inputs
110+
111+
112+
def get_bart_model_and_inputs(model_name, text, dtype):
113+
from paddlenlp.transformers import BartModel, BartTokenizer
114+
115+
model = BartModel.from_pretrained(model_name)
116+
model.eval()
117+
118+
tokenizer = BartTokenizer.from_pretrained(model_name)
119+
120+
inputs = tokenizer(
121+
text,
122+
return_tensors="pd",
123+
padding=True,
124+
truncation=True,
125+
max_length=512,
126+
)
127+
inputs.pop("token_type_ids", None)
128+
129+
return model, inputs
130+
131+
132+
def get_xlnet_model_and_inputs(model_name, text, dtype):
133+
import paddle
134+
from paddlenlp.transformers import XLNetModel, XLNetTokenizer, XLNetConfig
135+
136+
config = XLNetConfig.from_pretrained(model_name)
137+
model = XLNetModel(config)
138+
if dtype == "float16":
139+
model = model.astype(paddle.float16)
140+
model.eval()
141+
142+
tokenizer = XLNetTokenizer.from_pretrained(model_name)
143+
144+
enc = tokenizer(
145+
text,
146+
return_tensors="pd",
147+
padding=True,
148+
truncation=True,
149+
# max_length=512,
150+
)
151+
if "attention_mask" not in enc:
152+
input_ids = enc["input_ids"]
153+
pad_id = tokenizer.pad_token_id
154+
enc["attention_mask"] = (input_ids != pad_id).astype("int64")
155+
156+
return model, enc
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"framework": "paddle",
3+
"model_name": "bart-base",
4+
"num_devices_required": 1,
5+
"num_nodes_required": 1
6+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
class Program_weight_tensor_data_0:
2+
name = "data_0"
3+
shape = [1, 21]
4+
dtype = "int64"
5+
data = [
6+
0,
7+
31414,
8+
6,
9+
127,
10+
766,
11+
16,
12+
3045,
13+
4,
14+
38,
15+
524,
16+
2239,
17+
59,
18+
739,
19+
2777,
20+
3092,
21+
8,
22+
49,
23+
41885,
24+
4,
25+
1437,
26+
2,
27+
]

0 commit comments

Comments
 (0)