Skip to content

Commit 590a77c

Browse files
committed
fix invalid user commit
1 parent 2d78898 commit 590a77c

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

paddlenlp/ops/fast_transformer/src/fusion_unified_decoding_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
6666
const int& num_layer,
6767
const int& bos_id,
6868
const int& eos_id,
69-
const int64_t& max_len,
69+
const int& max_len,
7070
const float& beam_search_diversity_rate,
7171
const int& unk_id,
7272
const int& mask_id,
@@ -251,7 +251,7 @@ std::vector<std::vector<int64_t>> UnifiedDecodingInferShape(
251251
const int& num_layer,
252252
const int& bos_id,
253253
const int& eos_id,
254-
const int64_t& max_len,
254+
const int& max_len,
255255
const float& beam_search_diversity_rate,
256256
const int& unk_id,
257257
const int& mask_id,
@@ -397,7 +397,7 @@ PD_BUILD_OP(fusion_unified_decoding)
397397
"num_layer: int",
398398
"bos_id: int",
399399
"eos_id: int",
400-
"max_len: int64_t",
400+
"max_len: int",
401401
"beam_search_diversity_rate: float",
402402
"unk_id: int",
403403
"mask_id: int",

paddlenlp/transformers/unimo/modeling.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def forward(
104104
if input_ids is not None:
105105
num_pad = paddle.sum((input_ids == self.pad_token_id).astype("float32"), axis=-1, keepdim=True)
106106
position_ids = F.relu(
107-
paddle.expand_as(paddle.arange(end=inputs_shape[1], dtype="int64"), inputs_shape) - num_pad
107+
paddle.expand_as(paddle.arange(end=inputs_shape[1], dtype="float32"), inputs_shape) - num_pad
108108
).astype("int64")
109109
else:
110110
logger.warning(
@@ -462,6 +462,7 @@ def forward(
462462
def prepare_fast_entry(self, kwargs):
463463
from paddlenlp.ops import FasterMIRO, FasterUNIMOText
464464

465+
decoding_lib = kwargs.get("decoding_lib", None)
465466
use_fp16_decoding = kwargs.get("use_fp16_decoding", False)
466467
decode_strategy = kwargs.get("decode_strategy")
467468
if decode_strategy == "sampling" and kwargs.get("top_k") != 0 and kwargs.get("top_p") != 1:
@@ -480,9 +481,9 @@ def prepare_fast_entry(self, kwargs):
480481
)
481482

482483
if getattr(self.encoder, "norm", None) is None:
483-
self._fast_entry = FasterUNIMOText(self, use_fp16_decoding=use_fp16_decoding).forward
484+
self._fast_entry = FasterUNIMOText(self, use_fp16_decoding=use_fp16_decoding, decoding_lib=decoding_lib).forward
484485
else:
485-
self._fast_entry = FasterMIRO(self, use_fp16_decoding=use_fp16_decoding).forward
486+
self._fast_entry = FasterMIRO(self, use_fp16_decoding=use_fp16_decoding, decoding_lib=decoding_lib).forward
486487
return self._fast_entry
487488

488489
def adjust_logits_during_generation(self, logits):

0 commit comments

Comments
 (0)