Skip to content

Commit b3961f2

Browse files
authored
[Bug fixes] update input_spec for forward (#7648)
* update input_spec for forward
* update opt casting
1 parent 942865f commit b3961f2

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

paddlenlp/generation/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import copy
18+
import inspect
1819
from typing import Union
1920

2021
import paddle
@@ -1203,10 +1204,13 @@ def sample(
12031204
return input_ids[:, origin_len:], scores
12041205

12051206
def _get_model_inputs_spec(self, dtype: str):
1206-
return {
1207+
spec = {
12071208
"input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
12081209
"attention_mask": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
12091210
}
1211+
if "position_ids" in inspect.getfullargspec(self.forward).args:
1212+
spec["position_ids"] = paddle.static.InputSpec(shape=[None, None], dtype="int64")
1213+
return spec
12101214

12111215
def to_static(self, path: str, config: dict):
12121216
"""export generation model to static

paddlenlp/transformers/opt/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def _expand_mask(mask, tgt_length):
8181
tgt_length = tgt_length if tgt_length is not None else src_length
8282

8383
expanded_mask = ~(paddle.cast(mask[:, None, None, :], "bool"))
84-
expanded_mask = paddle.cast(expanded_mask, dtype=paddle.float32)
84+
expanded_mask = paddle.cast(expanded_mask, dtype=paddle.get_default_dtype())
8585

8686
expanded_mask = expanded_mask.expand([batch_size, 1, tgt_length, src_length])
8787
expanded_mask = expanded_mask * float(finfo(paddle.get_default_dtype()).min)

0 commit comments

Comments (0)