Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ std::vector<paddle::Tensor> UnifiedDecodingForward(
const int& num_layer,
const int& bos_id,
const int& eos_id,
const int64_t& max_len,
const int& max_len,
const float& beam_search_diversity_rate,
const int& unk_id,
const int& mask_id,
Expand Down Expand Up @@ -251,7 +251,7 @@ std::vector<std::vector<int64_t>> UnifiedDecodingInferShape(
const int& num_layer,
const int& bos_id,
const int& eos_id,
const int64_t& max_len,
const int& max_len,
const float& beam_search_diversity_rate,
const int& unk_id,
const int& mask_id,
Expand Down Expand Up @@ -397,7 +397,7 @@ PD_BUILD_OP(fusion_unified_decoding)
"num_layer: int",
"bos_id: int",
"eos_id: int",
"max_len: int64_t",
"max_len: int",
"beam_search_diversity_rate: float",
"unk_id: int",
"mask_id: int",
Expand Down
7 changes: 4 additions & 3 deletions paddlenlp/transformers/unimo/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
if input_ids is not None:
num_pad = paddle.sum((input_ids == self.pad_token_id).astype("float32"), axis=-1, keepdim=True)
position_ids = F.relu(
paddle.expand_as(paddle.arange(end=inputs_shape[1], dtype="int64"), inputs_shape) - num_pad
paddle.expand_as(paddle.arange(end=inputs_shape[1], dtype="float32"), inputs_shape) - num_pad
).astype("int64")
else:
logger.warning(
Expand Down Expand Up @@ -462,6 +462,7 @@
def prepare_fast_entry(self, kwargs):
from paddlenlp.ops import FasterMIRO, FasterUNIMOText

decoding_lib = kwargs.get("decoding_lib", None)

Check warning on line 465 in paddlenlp/transformers/unimo/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/unimo/modeling.py#L465

Added line #L465 was not covered by tests
use_fp16_decoding = kwargs.get("use_fp16_decoding", False)
decode_strategy = kwargs.get("decode_strategy")
if decode_strategy == "sampling" and kwargs.get("top_k") != 0 and kwargs.get("top_p") != 1:
Expand All @@ -480,9 +481,9 @@
)

if getattr(self.encoder, "norm", None) is None:
self._fast_entry = FasterUNIMOText(self, use_fp16_decoding=use_fp16_decoding).forward
self._fast_entry = FasterUNIMOText(self, use_fp16_decoding=use_fp16_decoding, decoding_lib=decoding_lib).forward

Check warning on line 484 in paddlenlp/transformers/unimo/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/unimo/modeling.py#L484

Added line #L484 was not covered by tests
else:
self._fast_entry = FasterMIRO(self, use_fp16_decoding=use_fp16_decoding).forward
self._fast_entry = FasterMIRO(self, use_fp16_decoding=use_fp16_decoding, decoding_lib=decoding_lib).forward

Check warning on line 486 in paddlenlp/transformers/unimo/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/unimo/modeling.py#L486

Added line #L486 was not covered by tests
return self._fast_entry

def adjust_logits_during_generation(self, logits):
Expand Down