Skip to content

Commit d1df979

Browse files
joerunde authored and njhill committed
✨ Add CLI to convert tokenizer
This PR adds a new convert-to-fast-tokenizer command to the cli to add tokenizer.json files to models. This is also invoked on download-weights if the auto_convert flag is true. Signed-off-by: Joe Runde <[email protected]>
1 parent a06e38a commit d1df979

File tree

1 file changed

+31
-1
lines changed
  • server/text_generation_server

1 file changed

+31
-1
lines changed

server/text_generation_server/cli.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def download_weights(
9393
convert_to_safetensors(model_name, revision)
9494
elif not any(f.endswith(".safetensors") for f in files):
9595
print(".safetensors weights not found on hub, but were found locally. Remove them first to re-convert")
96-
96+
if auto_convert:
97+
convert_to_fast_tokenizer(model_name, revision)
9798

9899
@app.command()
99100
def convert_to_onnx(
@@ -193,5 +194,34 @@ def quantize(
193194
)
194195

195196

197+
# CLI command: save a model's tokenizer via transformers so that a
# "tokenizer.json" file can be placed alongside the model.
#
#   model_name  -- model identifier used for the path lookup and for loading
#                  the tokenizer.
#   revision    -- optional revision forwarded to both the lookup and loader.
#   output_path -- optional directory to save into; it must already exist.
#                  When omitted, the resolved model directory is used.
#
# The command prints a message and returns early (without raising) when the
# model directory already holds a "tokenizer.json" or when the given
# output_path is not an existing directory.
@app.command()
def convert_to_fast_tokenizer(
    model_name: str,
    revision: Optional[str] = None,
    output_path: Optional[str] = None,
):
    # Imported inside the function rather than at module level.
    from text_generation_server import utils

    # Check for existing "tokenizer.json" in the resolved model directory.
    model_path = utils.get_model_path(model_name, revision)
    existing_tokenizer_file = os.path.join(model_path, "tokenizer.json")
    if os.path.exists(existing_tokenizer_file):
        print(f"Model {model_name} already has a fast tokenizer")
        return

    # Resolve the destination: fall back to the model directory, otherwise
    # insist that the caller-supplied directory is already present.
    if output_path is None:
        output_path = model_path
    elif not os.path.isdir(output_path):
        print(f"Output path {output_path} must exist and be a directory")
        return

    # transformers is only imported once the early-exit checks have passed.
    import transformers

    fast_tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name, revision=revision
    )
    fast_tokenizer.save_pretrained(output_path)

    print(f"Saved tokenizer to {output_path}")
224+
225+
196226
if __name__ == "__main__":
    # Invoke the CLI app only when this module is executed as a script.
    app()

0 commit comments

Comments
 (0)