
Commit 5ad0128

ai-edge-bot authored and copybara-github committed
Fix partially broken tokenizer_to_sentencepiece script

- DeepSeek tokenizer has its vocab_file attr set to None.
- SentencePieceProcessor.Encode() may return a SentencePieceText proto instead of a list of token IDs.

PiperOrigin-RevId: 719529425
1 parent: 42eb2ef
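For context on the second bullet: depending on the sentencepiece build and how Encode() is invoked, the result can be a SentencePieceText-style proto rather than a plain list. Below is a minimal sketch of both shapes, assuming a recent sentencepiece release that supports the out_type keyword; the tokenizer.model path is hypothetical.

```python
import sentencepiece as spm

# Hypothetical model path, for illustration only.
sp = spm.SentencePieceProcessor(model_file="tokenizer.model")

ids = sp.Encode("hello world")  # typically a plain List[int]
proto = sp.Encode("hello world", out_type="immutable_proto")  # proto form

# Both carry the same token IDs; the proto nests them under .pieces.
assert ids == [p.id for p in proto.pieces]
```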

File tree

1 file changed: +14 −3 lines changed


ai_edge_torch/generative/tools/tokenizer_to_sentencepiece.py

Lines changed: 14 additions & 3 deletions
```diff
@@ -215,6 +215,17 @@ def _log_not_matched(
   )


+def _encode_by_spm(
+    spm_tokenizer: spm.SentencePieceProcessor, string: str
+) -> List[int]:
+  """Encodes a string by the SentencePiece tokenizer."""
+  ids = spm_tokenizer.Encode(string)
+  if isinstance(ids, list):
+    return ids
+  # SentencePieceText
+  return [p.id for p in ids.pieces]
+
+
 def _verify_spm_tokenizer(
     tokenizer: transformers.PreTrainedTokenizer,
     spm_tokenizer: spm.SentencePieceProcessor,
```
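The helper above makes the rest of the script independent of which shape Encode() returns. A self-contained sketch of the same normalization, exercising the fallback branch with a tiny stand-in object instead of a real SentencePieceText proto (the stand-in types are hypothetical, so this runs without a model file):

```python
from types import SimpleNamespace

# Hypothetical stand-in mimicking a proto with a .pieces list of {id} items.
fake_proto = SimpleNamespace(
    pieces=[SimpleNamespace(id=5), SimpleNamespace(id=42)]
)


def to_ids(out):
  # Mirrors _encode_by_spm's normalization on an already-encoded result.
  if isinstance(out, list):
    return out
  return [p.id for p in out.pieces]


print(to_ids([5, 42]))     # [5, 42]  (list path)
print(to_ids(fake_proto))  # [5, 42]  (proto path)
```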
```diff
@@ -224,7 +235,7 @@ def _verify_spm_tokenizer(
   # as the token IDs encoded by the SentencePiece tokenizer.
   for string in _STRINGS_TO_VERIFY.value:
     ids_by_tokenizer = tokenizer.encode(string)
-    ids_by_spm = spm_tokenizer.Encode(string)
+    ids_by_spm = _encode_by_spm(spm_tokenizer, string)
     logging.info("String to verify: %s", string)
     logging.info("Token IDs by the oringal tokenizer: %s", ids_by_tokenizer)
     logging.info("Token IDs by the SentencePiece tokenizer: %s", ids_by_spm)
```
```diff
@@ -243,7 +254,7 @@ def _verify_spm_tokenizer(
     id_pair = random.sample(list(range(len(tokenizer.vocab))), 2)
     string = tokenizer.decode(id_pair)
     ids_by_tokenizer = tokenizer.encode(string)
-    ids_by_spm = spm_tokenizer.Encode(string)
+    ids_by_spm = _encode_by_spm(spm_tokenizer, string)
     if not _is_same_ids(ids_by_tokenizer, ids_by_spm):
       num_not_matched_strict += 1
       if _is_same_ids(ids_by_tokenizer, id_pair):
```
```diff
@@ -262,7 +273,7 @@ def _verify_spm_tokenizer(

 def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(_CHECKPOINT.value)
-  if hasattr(tokenizer, "vocab_file"):
+  if hasattr(tokenizer, "vocab_file") and tokenizer.vocab_file:
     logging.info("vocab_file exists: %s", tokenizer.vocab_file)
     with open(tokenizer.vocab_file, "rb") as f:
       sp_model = spm_model.ModelProto.FromString(f.read())
```
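Why the truthiness check matters: hasattr() returns True even when the attribute exists but is set to None, which is exactly the DeepSeek case from the commit message. A tiny stand-alone illustration, using a hypothetical stand-in class:

```python
class FakeDeepSeekTokenizer:
  """Hypothetical stand-in: the attribute exists but is None."""

  vocab_file = None


tok = FakeDeepSeekTokenizer()
print(hasattr(tok, "vocab_file"))  # True  -- the old check passed anyway
print(bool(tok.vocab_file))        # False -- the new check skips open(None)
```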
