Skip to content

Commit 88ed893

Browse files
Allow SPieceTokenizer to load model from a byte string.
1 parent 334ba48 commit 88ed893

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed
Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,18 @@
11
import os
22

33
class SPieceTokenizer:
4+
add_eos = True
5+
46
@staticmethod
57
def from_pretrained(path):
68
return SPieceTokenizer(path)
79

810
def __init__(self, tokenizer_path):
911
import sentencepiece
10-
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)
11-
self.end = self.tokenizer.eos_id()
12+
if isinstance(tokenizer_path, bytes):
13+
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
14+
else:
15+
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)
1216

1317
def get_vocab(self):
1418
out = {}
@@ -18,5 +22,4 @@ def get_vocab(self):
1822

1923
def __call__(self, string):
2024
out = self.tokenizer.encode(string)
21-
out += [self.end]
2225
return {"input_ids": out}

0 commit comments

Comments
 (0)