Hunyuan tokenizer #26
@@ -346,6 +346,8 @@
    data_qtype = gguf.GGMLQuantizationType.BF16
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
    data_qtype = gguf.GGMLQuantizationType.Q8_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q4_0:
    data_qtype = gguf.GGMLQuantizationType.Q4_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ1_0:
    data_qtype = gguf.GGMLQuantizationType.TQ1_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
@@ -607,7 +609,7 @@
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
assert max(tokenizer.vocab.values()) < vocab_size

tokpre = self.get_vocab_base_pre(tokenizer)
@@ -616,7 +618,7 @@
added_tokens_decoder = tokenizer.added_tokens_decoder

for i in range(vocab_size):
    if i not in reverse_vocab:
        tokens.append(f"[PAD{i}]")
        toktypes.append(gguf.TokenType.UNUSED)
@@ -842,14 +844,14 @@
def _set_vocab_none(self) -> None:
    self.gguf_writer.add_tokenizer_model("none")

def _set_vocab_gpt2(self) -> None:
def _set_vocab_gpt2(self, load_merges=True) -> None:
    tokens, toktypes, tokpre = self.get_vocab_base()
    self.gguf_writer.add_tokenizer_model("gpt2")
    self.gguf_writer.add_tokenizer_pre(tokpre)
    self.gguf_writer.add_token_list(tokens)
    self.gguf_writer.add_token_types(toktypes)

    special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
    special_vocab = gguf.SpecialVocab(self.dir_model, load_merges)
    special_vocab.add_to_gguf(self.gguf_writer)

def _set_vocab_qwen(self):
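Usage note (not part of the diff): the point of threading `load_merges` through is that a converter class which writes its own merges can stop `gguf.SpecialVocab` from loading them a second time from the tokenizer files. A minimal, hypothetical sketch, assuming it lives inside convert_hf_to_gguf.py next to the classes above:

```python
# Hypothetical subclass, for illustration only: it emits its own (possibly
# "fake") merges and therefore calls the shared helper with load_merges=False
# so SpecialVocab does not try to read merges from the tokenizer files again.
class MyTiktokenModel(TextModel):
    def set_vocab(self):
        self.gguf_writer.add_token_merges(["a b", "ab c"])  # toy merges
        self._set_vocab_gpt2(load_merges=False)
```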
@@ -6394,24 +6396,27 @@
@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(LlamaModel):
class HunYuanMoEModel(TextModel):
    model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
    undo_permute = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # For handling tied embeddings
        self._tok_embd = None

    def set_vocab(self):
        self._set_vocab_gpt2()

    def get_vocab_base(self) -> tuple[list[str], list[int], str]:
        tokens: list[str] = []
        toktypes: list[int] = []

        """
        A self-contained vocab implementation for the HunYuan tiktoken-based tokenizer.
        This method correctly generates tokens, types, and the required "fake" merges
        to satisfy the llama.cpp GGUF loader.
        """
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)

        # merge logic is copied from QwenModel, maybe incorrect
        # 1. Get the pre-tokenizer identifier hash
        tokpre = self.get_vocab_base_pre(tokenizer)

        # 2. Reverse-engineer the merges list from mergeable_ranks
        merges = []
        vocab = {}
        mergeable_ranks = tokenizer.mergeable_ranks
@@ -6420,97 +6425,107 @@
            if len(token) == 1:
                continue
            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
            if len(merged) == 2:
            if len(merged) == 2: #todo this is an assert in Qwen, why?
                merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
        self.gguf_writer.add_token_merges(merges)

        reverse_vocab = tokenizer.decoder
        assert max(reverse_vocab.keys()) < tokenizer.vocab_size

        tokpre = self.get_vocab_base_pre(tokenizer)
        added_vocab = tokenizer.get_added_vocab()

        added_tokens_decoder = tokenizer.added_tokens_decoder

        for i in range(tokenizer.vocab_size):
        # 3. Generate the tokens and toktypes lists
        vocab_size = self.hparams["vocab_size"]
        special_tokens = tokenizer.special_tokens
        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
        tokens: list[str] = []
        toktypes: list[int] = []
        for i in range(vocab_size):
            if i not in reverse_vocab:
                tokens.append(f"[PAD{i}]")
                toktypes.append(gguf.TokenType.UNUSED)
            else:
                token: str = reverse_vocab[i]
                if token in added_vocab:
                    # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
                    # To avoid unexpected issues - we make sure to normalize non-normalized tokens
                    if not added_tokens_decoder[i].normalized:
                        previous_token = token
                        token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
                        if previous_token != token:
                            logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")

                    if added_tokens_decoder[i].special or self.does_token_look_special(token):
                        toktypes.append(gguf.TokenType.CONTROL)
                    else:
                        # NOTE: this was added for Gemma.
                        # Encoding and decoding the tokens above isn't sufficient for this case.
                        token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces
                        toktypes.append(gguf.TokenType.USER_DEFINED)
                token = reverse_vocab[i]
                tokens.append(token)
                if i in special_tokens.values():
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.NORMAL)
                tokens.append(token)

        return tokens, toktypes, tokpre
        # 4. Write all vocab-related fields to the GGUF writer
        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_tokenizer_pre(tokpre)
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)
        self.gguf_writer.add_token_merges(merges)

        # 5. Add special tokens and chat templates
        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
        special_vocab.add_to_gguf(self.gguf_writer)
        # FIX for BOS token: Manually set the correct BOS token ID.
        self.gguf_writer.add_bos_token_id(127959) # <|bos|>

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        hparams = self.hparams

        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
        self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_expert_count(hparams["num_experts"])
        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])

        moe_intermediate_size = self.hparams["moe_intermediate_size"]
        moe_intermediate_size = hparams["moe_intermediate_size"]
        assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
        self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])

        moe_topk = self.hparams["moe_topk"]
        moe_topk = hparams["moe_topk"]
        assert all(topk == moe_topk[0] for topk in moe_topk)
        self.gguf_writer.add_expert_used_count(moe_topk[0])

        moe_shared_expert = hparams["num_shared_expert"]
        assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
        self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])

        self.gguf_writer.add_qk_norm(hparams.get("use_qk_norm", True))

        # Rope
        rope_scaling = hparams.get("rope_scaling", {})
        if rope_scaling.get("type") == "dynamic":
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])

        self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["max_position_embeddings"])

    _experts: list[dict[str, Tensor]] | None = None

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # process the experts separately
        if name == "model.embed_tokens.weight":
            self._tok_embd = data_torch.clone()
        if name == "lm_head.weight":
            if self.hparams.get("tie_word_embeddings", False):
                logger.info("Skipping tied output layer 'lm_head.weight'")
                return []
        if name.find("mlp.experts") != -1:
            n_experts = self.hparams["num_experts"]
            assert bid is not None

            tensors: list[tuple[str, Tensor]] = []

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:
                # merge the experts into a single 3d tensor
                tensors: list[tuple[str, Tensor]] = []
                for w_name in ["down_proj", "gate_proj", "up_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
                        datas.append(self._experts[bid][ename])
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)

                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

                    new_name = self.map_tensor_name(merged_name)

                    tensors.append((new_name, data_torch))

                return tensors
            else:
                return []

        return [(self.map_tensor_name(name), data_torch)]

    def prepare_tensors(self):
        super().prepare_tensors()
        if self._experts is not None:
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")

###### CONVERSION LOGIC ######
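A standalone illustration of the expert-merging step in `modify_tensors` above: the converter buffers each layer's per-expert 2D projection weights and, once all of them have arrived, stacks them into one 3D tensor per projection, which is the merged layout the GGUF tensor names expect. Shapes below are made up for the example:

```python
# Standalone sketch (toy shapes, not from the diff): stack per-expert 2D
# weights into a single 3D tensor, as modify_tensors does per layer/projection.
import torch

n_experts, n_ff, n_embd = 4, 8, 16
per_expert = [torch.randn(n_ff, n_embd) for _ in range(n_experts)]  # e.g. each expert's gate_proj

merged = torch.stack(per_expert, dim=0)
print(merged.shape)  # torch.Size([4, 8, 16]) -> [n_experts, n_ff, n_embd]
```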
@@ -6600,7 +6615,7 @@
    help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
    "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
    "--outtype", type=str, choices=["f32", "f16", "bf16", "q4_0", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
    help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
@@ -6732,6 +6747,7 @@
    "f32": gguf.LlamaFileType.ALL_F32,
    "f16": gguf.LlamaFileType.MOSTLY_F16,
    "bf16": gguf.LlamaFileType.MOSTLY_BF16,
    "q4_0": gguf.LlamaFileType.MOSTLY_Q4_0,
    "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
    "tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0,
    "tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0,
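The two `q4_0` hunks in this file work together: the new `--outtype q4_0` choice maps the CLI string to `gguf.LlamaFileType.MOSTLY_Q4_0`, and the `elif` chain added near the top of the file then selects `gguf.GGMLQuantizationType.Q4_0` for tensor data. A small sketch of that two-step lookup using the gguf-py enums; the helper function itself is illustrative, not taken from the script:

```python
# Illustrative two-step lookup mirroring the diff: outtype string -> file type
# -> per-tensor quantization type. Requires the gguf-py package.
import gguf

ftype_map = {
    "q4_0": gguf.LlamaFileType.MOSTLY_Q4_0,
    "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
}

def default_qtype(outtype: str) -> gguf.GGMLQuantizationType:
    ftype = ftype_map[outtype]
    if ftype == gguf.LlamaFileType.MOSTLY_Q4_0:
        return gguf.GGMLQuantizationType.Q4_0
    if ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
        return gguf.GGMLQuantizationType.Q8_0
    raise ValueError(f"unhandled outtype: {outtype}")

print(default_qtype("q4_0").name)  # Q4_0
```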
@@ -148,6 +148,7 @@ class Attention:
    VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
    SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
    SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
    QK_NORM = "{arch}.attention.qk_norm"


class Rope:
    DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -792,6 +792,9 @@ def add_group_norm_groups(self, value: int) -> None:
def add_causal_attention(self, value: bool) -> None:
    self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)

def add_qk_norm(self, value: bool) -> None:
    self.add_bool(Keys.Attention.QK_NORM.format(arch=self.arch), value)

def add_q_lora_rank(self, length: int) -> None:
    self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length)
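Putting the three qk-norm pieces together (the `QK_NORM` key constant, the new `add_qk_norm()` writer method above, and the `add_qk_norm(hparams.get("use_qk_norm", True))` call in the converter): the flag ends up as a boolean GGUF metadata key formatted per architecture. A tiny sketch of the resulting key name; the arch string "hunyuan-moe" is an assumption used only for illustration:

```python
# Sketch: the metadata key that add_qk_norm() writes its boolean under.
# "hunyuan-moe" is an assumed architecture name for illustration.
QK_NORM = "{arch}.attention.qk_norm"
print(QK_NORM.format(arch="hunyuan-moe"))  # -> hunyuan-moe.attention.qk_norm
```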
@@ -705,13 +705,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
        ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
    cb(weights, "ffn_moe_weights", il);

    if (arch == LLM_ARCH_HUNYUAN_MOE) {
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_expert_used, n_tokens]
        weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [1, n_tokens]
        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); // [1, n_expert_used, n_tokens]
        cb(weights, "ffn_moe_weights_scaled", il);
    }

Review comment on lines -708 to -714: haha good catch, I didn't notice that I reinvented the

    if (norm_w) {
        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
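For context on the removed block and the review comment: both that block and the `norm_w` path right after it rescale the selected expert weights so they sum to 1 for each token, so the HunYuan special case duplicated existing behaviour. A standalone sketch of that renormalization in Python (toy shapes, not ggml code):

```python
# Standalone sketch (toy shapes): renormalize the top-k expert weights per
# token so they sum to 1 -- the effect of both the removed block and norm_w.
import torch

n_tokens, n_expert_used = 3, 8
weights = torch.rand(n_tokens, n_expert_used)           # selected expert scores
weights = weights / weights.sum(dim=-1, keepdim=True)   # each row now sums to 1
print(weights.sum(dim=-1))                               # ~tensor([1., 1., 1.])
```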
Review comment (on the hard-coded BOS token id): We can overwrite this in hparams["bos_token_id"]

Reply: I'm not sure if setting it in hparams would override the id that gguf.SpecialVocab reads from the config. I've left this as is for now, but that can be tested later.
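To make the suggestion concrete: a hypothetical, untested sketch of the hparams-based override being discussed. Whether `gguf.SpecialVocab` would actually honour `hparams["bos_token_id"]` over what it reads from the tokenizer config is exactly the open question left in the thread:

```python
# Hypothetical, untested sketch (it would live in convert_hf_to_gguf.py).
# Instead of hard-coding add_bos_token_id(127959), set the id in hparams
# before the vocab/special tokens are written and let the normal path use it.
class HunYuanMoEModel(TextModel):
    def set_vocab(self):
        self.hparams["bos_token_id"] = 127959  # <|bos|>
        ...  # then run the existing vocab-writing logic from the diff above
```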