Skip to content

Commit 77a1111

Browse files
add func get_special_tokens_mask() (#2875)
* add func get_special_tokens_mask() instead of using that in superclass which will lead to conflict with func build_inputs_with_special_tokens() * add func get_special_tokens_mask() instead of using that in superclass which will lead to conflict with func build_inputs_with_special_tokens() Co-authored-by: Guo Sheng <[email protected]>
1 parent 112830b commit 77a1111

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

paddlenlp/transformers/ernie/tokenizer.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,46 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
326326
_sep = [self.sep_token_id]
327327
return _cls + token_ids_0 + _sep + token_ids_1 + _sep
328328

329+
def get_special_tokens_mask(self,
                            token_ids_0,
                            token_ids_1=None,
                            already_has_special_tokens=False):
    r"""
    Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
    special tokens using the tokenizer ``encode`` methods.

    Args:
        token_ids_0 (List[int]):
            List of ids of the first sequence.
        token_ids_1 (List[int], optional):
            Optional second list of IDs for sequence pairs.
            Defaults to `None`.
        already_has_special_tokens (bool, optional):
            Whether or not the token list is already formatted with special tokens for the model.
            Defaults to `False`.

    Returns:
        List[int]:
            The list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.

    Raises:
        ValueError: If ``token_ids_1`` is given while ``already_has_special_tokens`` is True
            (an already-formatted input must be a single combined sequence).
    """

    if already_has_special_tokens:
        if token_ids_1 is not None:
            raise ValueError(
                "You should not supply a second sequence if the provided sequence of "
                "ids is already formatted with special tokens for the model."
            )
        # Build the special-token id set once, then mark matches in a single pass.
        special_ids = {self.sep_token_id, self.cls_token_id}
        return [1 if token_id in special_ids else 0 for token_id in token_ids_0]

    # Mask for inputs built as [CLS] A [SEP] (B [SEP]): 1 marks the added special tokens.
    if token_ids_1 is not None:
        return [1] + ([0] * len(token_ids_0)) + [1] + (
            [0] * len(token_ids_1)) + [1]
    return [1] + ([0] * len(token_ids_0)) + [1]
329369
def build_offset_mapping_with_special_tokens(self,
330370
offset_mapping_0,
331371
offset_mapping_1=None):

0 commit comments

Comments
 (0)