
Commit adadf10

add new albert and t5 models
1 parent 873fa29 commit adadf10

File tree

25 files changed: +13387 −0 lines changed


graph_net/test/nlp_model_getter.py

Lines changed: 90 additions & 0 deletions
@@ -154,3 +154,93 @@ def get_xlnet_model_and_inputs(model_name, text, dtype):
        enc["attention_mask"] = (input_ids != pad_id).astype("int64")

    return model, enc


def get_t5_model_and_inputs(model_name, text, dtype):
    import paddle
    from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer

    # 1) Tokenizer (built first so the pad/eos ids are easy to read off)
    tokenizer = T5Tokenizer.from_pretrained(model_name)

    # 2) Encode the input (accepts a single string or a batch of texts)
    enc = tokenizer(
        text,
        return_tensors="pd",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # Fill in attention_mask if missing (0 at pad positions, 1 elsewhere)
    if "attention_mask" not in enc:
        input_ids = enc["input_ids"]
        attn_mask = (input_ids != tokenizer.pad_token_id).astype("int64")
        enc["attention_mask"] = attn_mask

    # Build decoder_input_ids:
    # T5 uses pad_token_id as the decoder_start_token_id
    batch_size = enc["input_ids"].shape[0]
    decoder_input_ids = paddle.full(
        shape=[batch_size, 1],
        fill_value=tokenizer.pad_token_id,
        dtype="int64",
    )

    # 3) Load the model
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    if dtype == "float16":
        model = model.astype(paddle.float16)
    model.eval()

    # 4) Assemble the inputs fed to the model
    inputs = {
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "decoder_input_ids": decoder_input_ids,
    }
    return model, inputs


def get_albert_model_and_inputs(model_name, text, dtype):
    """
    Load the ALBERT backbone (AlbertModel) and build its inputs.
    - model_name: e.g. "albert-base-v2", "albert-xxlarge-v1" (built-in PaddleNLP names)
    - dtype: "float32" or "float16"
    Returns: (model, inputs_dict)
    """
    import paddle
    from paddlenlp.transformers import AlbertConfig, AlbertModel, AlbertTokenizer

    # 1) Read the config (does not trigger a weight download)
    config = AlbertConfig.from_pretrained(model_name)

    # 2) Model, built from the config only (network structure, random weights);
    #    use AlbertModel.from_pretrained(model_name) if pretrained weights are needed
    model = AlbertModel(config)
    if dtype == "float16":
        model = model.astype(paddle.float16)
    model.eval()

    # 3) Tokenizer
    tokenizer = AlbertTokenizer.from_pretrained(model_name)

    # If there is no pad_token, fall back to unk_token
    # (ALBERT has no eos_token, so do not set pad = eos)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token

    # 4) Build the inputs (accepts str or List[str])
    enc = tokenizer(
        text,
        return_tensors="pd",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # Explicitly fill in attention_mask (0 at pad positions)
    if "attention_mask" not in enc:
        input_ids = enc["input_ids"]
        enc["attention_mask"] = (input_ids != tokenizer.pad_token_id).astype("int64")

    return model, enc
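
For reference, the two new getters can be exercised as follows. This is a minimal sketch, not part of the commit: the sample texts are made up, and the comments only state what the PaddleNLP T5ForConditionalGeneration and AlbertModel forward passes are expected to return.

# Minimal usage sketch (not part of this commit); sample texts are illustrative.
import paddle

from graph_net.test.nlp_model_getter import (
    get_albert_model_and_inputs,
    get_t5_model_and_inputs,
)

with paddle.no_grad():
    t5, t5_inputs = get_t5_model_and_inputs(
        "t5-small", "translate English to German: Hello!", "float32"
    )
    t5_out = t5(**t5_inputs)  # decoder logits for the single start token

    albert, albert_inputs = get_albert_model_and_inputs(
        "albert-base-v1", ["first sentence", "a second, longer sentence"], "float32"
    )
    albert_out = albert(**albert_inputs)  # sequence output (plus pooled output)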
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
{
    "framework": "paddle",
    "model_name": "albert-base-v1",
    "num_devices_required": 1,
    "num_nodes_required": 1
}
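
A config like this would typically be read by a test harness that dispatches on model_name. A minimal sketch of such a consumer (the filename and the dispatch logic below are assumptions, not code from this commit):

import json

# Sketch of a consumer for this config; the filename is a placeholder,
# and the dispatch is an assumption rather than code from this commit.
with open("graph_net_config.json") as f:
    cfg = json.load(f)

assert cfg["framework"] == "paddle"
name = cfg["model_name"]  # "albert-base-v1"
getter = get_t5_model_and_inputs if name.startswith("t5") else get_albert_model_and_inputs
model, inputs = getter(name, "sample text", "float32")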
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
class Program_weight_tensor_data_0:
    name = "data_0"
    shape = [1, 21]
    dtype = "int64"
    data = [
        2,
        10975,
        15,
        51,
        204,
        25,
        1909,
        9,
        31,
        589,
        2477,
        88,
        370,
        816,
        2761,
        17,
        66,
        2607,
        18,
        9,
        3,
    ]


class Program_weight_tensor_data_1:
    name = "data_1"
    shape = [1, 21]
    dtype = "int64"
    data = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


class Program_weight_tensor_data_2:
    name = "data_2"
    shape = [1, 21]
    dtype = "int64"
    data = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
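
The three tensors look like a single 21-token ALBERT sample: data_0 resembles input_ids, data_1 an all-ones attention_mask, and data_2 all-zero token_type_ids. A minimal sketch of materializing them as paddle tensors (the helper below is illustrative, not part of the commit):

import paddle

def to_tensor(spec):
    # Illustrative helper: build a paddle tensor from one data class above.
    return paddle.to_tensor(spec.data, dtype=spec.dtype).reshape(spec.shape)

feed = {
    spec.name: to_tensor(spec)
    for spec in (
        Program_weight_tensor_data_0,
        Program_weight_tensor_data_1,
        Program_weight_tensor_data_2,
    )
}  # {"data_0": ..., "data_1": ..., "data_2": ...}, each of shape [1, 21]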
