@@ -104,7 +104,7 @@ def forward(
         if input_ids is not None:
             num_pad = paddle.sum((input_ids == self.pad_token_id).astype("float32"), axis=-1, keepdim=True)
             position_ids = F.relu(
-                paddle.expand_as(paddle.arange(end=inputs_shape[1], dtype="int64"), inputs_shape) - num_pad
+                paddle.expand_as(paddle.arange(end=inputs_shape[1], dtype="float32"), inputs_shape) - num_pad
             ).astype("int64")
         else:
             logger.warning(
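
The line changed above builds position ids for left-padded batches: num_pad is computed in float32, and switching the arange to float32 keeps both operands of the subtraction in the same dtype before ReLU clamps the padded slots and the result is cast back to int64. A minimal sketch of that computation (not part of the diff), assuming Paddle 2.x, a made-up pad_token_id of 0, and plain broadcasting in place of paddle.expand_as for brevity:

import paddle
import paddle.nn.functional as F

pad_token_id = 0  # assumption for the example
input_ids = paddle.to_tensor([[0, 0, 5, 6, 7],
                              [3, 4, 5, 6, 7]])  # row 0 has two left pads

# Count pads per row in float32, matching the patched dtype of the arange.
num_pad = paddle.sum((input_ids == pad_token_id).astype("float32"), axis=-1, keepdim=True)
positions = paddle.arange(end=input_ids.shape[1], dtype="float32")
# ReLU clamps the padded slots to position 0; the final cast restores int64.
position_ids = F.relu(positions - num_pad).astype("int64")
# position_ids -> [[0, 0, 0, 1, 2],
#                  [0, 1, 2, 3, 4]]
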
@@ -462,6 +462,7 @@ def forward(
     def prepare_fast_entry(self, kwargs):
         from paddlenlp.ops import FasterMIRO, FasterUNIMOText

+        decoding_lib = kwargs.get("decoding_lib", None)
         use_fp16_decoding = kwargs.get("use_fp16_decoding", False)
         decode_strategy = kwargs.get("decode_strategy")
         if decode_strategy == "sampling" and kwargs.get("top_k") != 0 and kwargs.get("top_p") != 1:
@@ -480,9 +481,9 @@ def prepare_fast_entry(self, kwargs):
         )

         if getattr(self.encoder, "norm", None) is None:
-            self._fast_entry = FasterUNIMOText(self, use_fp16_decoding=use_fp16_decoding).forward
+            self._fast_entry = FasterUNIMOText(self, use_fp16_decoding=use_fp16_decoding, decoding_lib=decoding_lib).forward
         else:
-            self._fast_entry = FasterMIRO(self, use_fp16_decoding=use_fp16_decoding).forward
+            self._fast_entry = FasterMIRO(self, use_fp16_decoding=use_fp16_decoding, decoding_lib=decoding_lib).forward
         return self._fast_entry

     def adjust_logits_during_generation(self, logits):
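
The remaining changes thread an optional decoding_lib keyword from the generation kwargs through prepare_fast_entry into FasterUNIMOText and FasterMIRO, so a custom decoding library can be supplied at generation time. A hedged usage sketch follows; the use_fast flag name (older releases spell it use_faster), the tokenizer call, the checkpoint name, and the .so path are illustrative assumptions that depend on the PaddleNLP version:

from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer

tokenizer = UNIMOTokenizer.from_pretrained("unimo-text-1.0")
model = UNIMOLMHeadModel.from_pretrained("unimo-text-1.0")
model.eval()

inputs = tokenizer.gen_encode("深度学习", return_tensors=True, add_start_token_for_decoding=True)
outputs, scores = model.generate(
    **inputs,
    max_length=20,
    decode_strategy="sampling",
    top_k=5,                                    # top_p stays at 1.0 so the fast path is eligible
    use_fast=True,                              # route generation through prepare_fast_entry
    use_fp16_decoding=False,
    decoding_lib="/path/to/libdecoding_op.so",  # now forwarded to FasterUNIMOText / FasterMIRO
)
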