
Commit 79379af

Add tokenizer convert logic for Mistral (#367)

1 parent 0c9edef

1 file changed: +47 −1 lines

loader.py

@@ -327,6 +327,49 @@ def gguf_tokenizer_loader(path, temb_shape):
     del reader
     return torch.ByteTensor(list(spm.SerializeToString()))
 
+def gguf_tekken_tokenizer_loader(path, temb_shape):
+    # convert ggml (hf) tokenizer metadata to tekken/comfy data
+    logging.info("Attempting to recreate tekken tokenizer from GGUF file metadata...")
+    import json
+    import base64
+    from transformers.convert_slow_tokenizer import bytes_to_unicode
+
+    reader = gguf.GGUFReader(path)
+
+    model_str = get_field(reader, "tokenizer.ggml.model", str)
+    if model_str == "gpt2":
+        if temb_shape == (131072, 5120): # probably Mistral
+            data = {
+                "config": {"num_vocab_tokens": 150000, "default_vocab_size": 131072},
+                "vocab": [],
+                "special_tokens": [],
+            }
+        else:
+            raise NotImplementedError("Unknown model, can't set tokenizer!")
+    else:
+        raise NotImplementedError("Unknown model, can't set tokenizer!")
+
+    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
+    toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)
+
+    decoder = {v: k for k, v in bytes_to_unicode().items()}
+    for idx, (token, toktype) in enumerate(zip(tokens, toktypes)):
+        if toktype == 3: # 3 = CONTROL in the gguf token_type enum
+            data["special_tokens"].append(
+                {'rank': idx, 'token_str': token, 'is_control': True}
+            )
+        else:
+            tok = bytes([decoder[char] for char in token])
+            data["vocab"].append({
+                "rank": len(data["vocab"]),
+                "token_bytes": base64.b64encode(tok).decode("ascii"),
+                "token_str": tok.decode("utf-8", errors="replace") # lossy preview; token_bytes is authoritative
+            })
+
+    logging.info(f"Created tekken tokenizer with vocab size of {len(data['vocab'])} (+{len(data['special_tokens'])})")
+    del reader
+    return torch.ByteTensor(list(json.dumps(data).encode('utf-8')))
+
 def gguf_clip_loader(path):
     sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
     if arch in {"t5", "t5encoder"}:
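
The heart of the conversion above is inverting the GPT-2 byte-to-unicode table: GGUF stores each token as a printable string, and the decoder map turns that string back into raw bytes, which are then base64-encoded for the tekken JSON (token type 3 marks control tokens, which keep their literal string instead). A minimal round-trip sketch, assuming transformers is installed; the token "Ġhello" is a hypothetical example, not one read from a real GGUF file:

import base64
from transformers.convert_slow_tokenizer import bytes_to_unicode

encoder = bytes_to_unicode()                   # byte value -> printable character
decoder = {v: k for k, v in encoder.items()}   # printable character -> byte value

token = "Ġhello"                               # hypothetical GPT-2 style token ("Ġ" encodes a space)
raw = bytes(decoder[ch] for ch in token)       # b" hello"
entry = {
    "rank": 0,
    "token_bytes": base64.b64encode(raw).decode("ascii"),  # "IGhlbGxv"
    "token_str": raw.decode("utf-8", errors="replace"),    # " hello"
}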
@@ -342,12 +385,15 @@ def gguf_clip_loader(path):
     # TODO: pass model_options["vocab_size"] to loader somehow
     temb_key = "token_embd.weight"
     if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
+        if arch == "llama" and sd[temb_key].shape == (131072, 5120):
+            # non-standard Comfy-Org tokenizer
+            sd["tekken_model"] = gguf_tekken_tokenizer_loader(path, sd[temb_key].shape)
         # See note above for T5.
         logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
         sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
     sd = sd_map_replace(sd, LLAMA_SD_MAP)
     if arch == "llama":
-        sd = llama_permute(sd, 32, 8) # L3
+        sd = llama_permute(sd, 32, 8) # L3 / Mistral
     if arch == "qwen2vl":
         vsd = gguf_mmproj_loader(path)
         sd.update(vsd)
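
Downstream, the serialized tokenizer rides along in the state dict as a ByteTensor of UTF-8 JSON under the "tekken_model" key. A hedged sketch of how a consumer could turn that tensor back into a dict; the helper name decode_tekken_model is illustrative, not part of this repo:

import json
import torch

def decode_tekken_model(sd):
    # "tekken_model" is the key the commit stores the tokenizer under;
    # the value is a torch.ByteTensor of UTF-8 encoded JSON.
    blob = sd["tekken_model"]
    return json.loads(bytes(blob.tolist()).decode("utf-8"))

# round-trip demo with a stand-in payload
payload = {"config": {"default_vocab_size": 131072}, "vocab": [], "special_tokens": []}
sd = {"tekken_model": torch.ByteTensor(list(json.dumps(payload).encode("utf-8")))}
assert decode_tekken_model(sd)["config"]["default_vocab_size"] == 131072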
