Skip to content

Commit 920e1e8

Browse files
authored
Fix dtype bug of text similarity (#3188)
1 parent 5401f01 commit 920e1e8

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

paddlenlp/taskflow/text_similarity.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,11 @@ def _preprocess(self, inputs):
220220
]
221221
if ("rocketqa" in self.model_name):
222222
batchify_fn = lambda samples, fn=Tuple(
223-
Pad(axis=0, pad_val=self._tokenizer.pad_token_id), # input ids
224-
Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id
225-
), # token type ids
223+
Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype='int64'
224+
), # input ids
225+
Pad(axis=0,
226+
pad_val=self._tokenizer.pad_token_type_id,
227+
dtype='int64'), # token type ids
226228
): [data for data in fn(samples)]
227229
else:
228230
batchify_fn = lambda samples, fn=Tuple(

0 commit comments

Comments
 (0)