
Commit 8cd2d7c

model : add grok-2 support (#782)
Co-authored-by: firecoperana <firecoperana>
1 parent 18f0435 commit 8cd2d7c

File tree

10 files changed: +266, -77 lines

common/common.h

Lines changed: 3 additions & 3 deletions

@@ -116,9 +116,9 @@ struct gpt_params {
     float rope_freq_base   = 0.0f;  // RoPE base frequency
     float rope_freq_scale  = 0.0f;  // RoPE frequency scaling factor
     float yarn_ext_factor  = -1.0f; // YaRN extrapolation mix factor
-    float yarn_attn_factor = 1.0f;  // YaRN magnitude scaling factor
-    float yarn_beta_fast   = 32.0f; // YaRN low correction dim
-    float yarn_beta_slow   = 1.0f;  // YaRN high correction dim
+    float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
+    float yarn_beta_fast   = -1.0f; // YaRN low correction dim
+    float yarn_beta_slow   = -1.0f; // YaRN high correction dim
     int32_t yarn_orig_ctx  = 0;     // YaRN original context length
     float defrag_thold     = -1.0f; // KV cache defragmentation threshold
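
Note: the three YaRN parameters now default to -1.0f, the same "unset" sentinel already used by yarn_ext_factor, presumably so that values stored in the model's GGUF metadata win unless the user overrides them explicitly. A minimal sketch of that fallback logic (the helper name and resolution order are illustrative assumptions, not code from this commit):

    # Hypothetical sketch: a negative CLI value means "unset", so the
    # value from the model's metadata (or a computed default) is used.
    def resolve_yarn_param(cli_value: float, model_value: float) -> float:
        return model_value if cli_value < 0.0 else cli_value

    print(resolve_yarn_param(-1.0, 32.0))  # 32.0 -> model metadata used
    print(resolve_yarn_param(16.0, 32.0))  # 16.0 -> explicit user override wins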

convert_hf_to_gguf.py

Lines changed: 78 additions & 23 deletions

@@ -555,6 +555,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         # NOTE: if you get an error here, you need to update the convert_hf_to_gguf_update.py script
         # or pull the latest version of the model from Huggingface
         # don't edit the hashes manually!
+        if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
+            # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
+            res = "grok-2"
         if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
             # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
             res = "llama-bpe"
@@ -1905,57 +1908,109 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return tensors


-@Model.register("GrokForCausalLM")
+@Model.register("GrokForCausalLM", "Grok1ForCausalLM")
 class GrokModel(Model):
     model_arch = gguf.MODEL_ARCH.GROK

     def set_vocab(self):
-        self._set_vocab_sentencepiece()
+        if (self.dir_model / 'tokenizer.model').is_file():
+            self._set_vocab_sentencepiece()
+            return
+
+        if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file():
+            logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer')
+            sys.exit(1)
+
+        self._set_vocab_gpt2()

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     def set_gguf_parameters(self):
         super().set_gguf_parameters()

-    _experts: list[dict[str, Tensor]] | None = None
+        self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
+        self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
+        if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+
+        # Treat "original" as "yarn", seems to have been a mistake
+        if self.hparams.get("rope_type") in ("yarn", "original"):
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"])
+
+        if temp_len := self.hparams.get("attn_temperature_len"):
+            self.gguf_writer.add_attn_temperature_length(temp_len)
+
+        self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5))
+        self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
+        self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])
+
+    _experts: list[dict[str, list[Tensor]]] | None = None
+    _cur_expert = ""

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors: list[tuple[str, Tensor]] = []
+        is_expert = ".moe." in name or ".block_sparse_moe.experts." in name
+
+        if not is_expert:
+            tensors.append((self.map_tensor_name(name), data_torch))
+
         # process the experts separately
-        if name.find(".moe.") != -1:
+        if is_expert or self._cur_expert:
             n_experts = self.hparams["num_local_experts"]

             assert bid is not None

             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]

-            self._experts[bid][name] = data_torch
-
-            if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []
+            # concatenate split tensors
+            if name in self._experts[bid]:
+                self._cur_expert = name
+                self._experts[bid][name].append(data_torch)
+                return []
+            elif is_expert:
+                self._cur_expert = name
+                self._experts[bid][name] = [data_torch]
+                return []
+            else:
+                self._cur_expert = ""

-                # merge the experts into a single 3d tensor
-                for wid in ["linear", "linear_1", "linear_v"]:
-                    datas: list[Tensor] = []
+            for bid in range(self.block_count):
+                if len(self._experts[bid]) >= n_experts * 3:
+                    # merge the experts into a single 3d tensor
+                    for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]:
+                        datas: list[Tensor] = []

-                    for xid in range(n_experts):
-                        ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
-                        datas.append(self._experts[bid][ename])
-                        del self._experts[bid][ename]
+                        for xid in range(n_experts):
+                            ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
+                            if ename not in self._experts[bid]:
+                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
+                            tensor_list = self._experts[bid][ename]
+                            datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0])
+                            del self._experts[bid][ename]

-                    data_torch = torch.stack(datas, dim=0)
+                        data_torch = torch.stack(datas, dim=0)

-                    merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
+                        merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight"

-                    new_name = self.map_tensor_name(merged_name)
+                        new_name = self.map_tensor_name(merged_name)

-                    tensors.append((new_name, data_torch))
-                return tensors
-            else:
-                return []
+                        yield (new_name, data_torch)

-        return [(self.map_tensor_name(name), data_torch)]
+        yield from tensors


 @Model.register("DbrxForCausalLM")

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions

@@ -99,6 +99,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2", },
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902", },
     {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890", },
+    {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
 ]
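
Note: as I read the update script, the chkhsh fingerprint is the sha256 of the token ids the tokenizer produces for a fixed probe string, so two tokenizers collide only if they split that text identically. Rough sketch (the real probe text lives in convert_hf_to_gguf_update.py and is elided here):

    # Sketch of how chkhsh is derived; the placeholder probe text below is
    # NOT the real one, so this prints a different hash than the table entry.
    from hashlib import sha256
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("alvarobartt/grok-2-tokenizer")
    chktxt = "<fixed probe text from convert_hf_to_gguf_update.py>"  # elided
    print(sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest())
    # with the real probe text, expected:
    # 66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273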

gguf-py/gguf/constants.py

Lines changed: 16 additions & 8 deletions

@@ -97,6 +97,7 @@ class LLM:
         DECODER_START_TOKEN_ID   = "{arch}.decoder_start_token_id"
         ATTN_LOGIT_SOFTCAPPING   = "{arch}.attn_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING  = "{arch}.final_logit_softcapping"
+        ROUTER_LOGIT_SOFTCAPPING = "{arch}.router_logit_softcapping"

     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
@@ -112,16 +113,22 @@ class Attention:
         KV_LORA_RANK       = "{arch}.attention.kv_lora_rank"
         REL_BUCKETS_COUNT  = "{arch}.attention.relative_buckets_count"
         SLIDING_WINDOW     = "{arch}.attention.sliding_window"
+        OUTPUT_SCALE       = "{arch}.attention.output_scale"
+        TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"

     class Rope:
-        DIMENSION_COUNT      = "{arch}.rope.dimension_count"
-        FREQ_BASE            = "{arch}.rope.freq_base"
-        SCALING_TYPE         = "{arch}.rope.scaling.type"
-        SCALING_FACTOR       = "{arch}.rope.scaling.factor"
-        SCALING_ATTN_FACTOR  = "{arch}.rope.scaling.attn_factor"
-        SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
-        SCALING_FINETUNED    = "{arch}.rope.scaling.finetuned"
-        SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
+        DIMENSION_COUNT          = "{arch}.rope.dimension_count"
+        FREQ_BASE                = "{arch}.rope.freq_base"
+        SCALING_TYPE             = "{arch}.rope.scaling.type"
+        SCALING_FACTOR           = "{arch}.rope.scaling.factor"
+        SCALING_ATTN_FACTOR      = "{arch}.rope.scaling.attn_factor"
+        SCALING_ORIG_CTX_LEN     = "{arch}.rope.scaling.original_context_length"
+        SCALING_FINETUNED        = "{arch}.rope.scaling.finetuned"
+        SCALING_YARN_LOG_MUL     = "{arch}.rope.scaling.yarn_log_multiplier"
+        SCALING_YARN_EXT_FACTOR  = "{arch}.rope.scaling.yarn_ext_factor"
+        SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
+        SCALING_YARN_BETA_FAST   = "{arch}.rope.scaling.yarn_beta_fast"
+        SCALING_YARN_BETA_SLOW   = "{arch}.rope.scaling.yarn_beta_slow"

     class Split:
         LLM_KV_SPLIT_NO = "split.no"
@@ -540,6 +547,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_POST_NORM,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
     MODEL_ARCH.GPTNEOX: [
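
Note: these Keys entries are str.format templates keyed on the architecture name, so each new constant yields one GGUF metadata key per arch. For example (assuming "grok" is the arch string the GROK architecture maps to):

    # The {arch} placeholder expands with str.format.
    SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast"
    print(SCALING_YARN_BETA_FAST.format(arch="grok"))
    # -> grok.rope.scaling.yarn_beta_fast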

gguf-py/gguf/gguf_writer.py

Lines changed: 21 additions & 0 deletions

@@ -656,6 +656,9 @@ def add_logit_scale(self, value: float) -> None:
     def add_attn_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

+    def add_router_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
     def add_final_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

@@ -701,6 +704,12 @@ def add_relative_attn_buckets_count(self, value: int) -> None:
     def add_sliding_window(self, value: int) -> None:
         self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)

+    def add_attn_output_scale(self, value: float) -> None:
+        self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
+
+    def add_attn_temperature_length(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

@@ -728,6 +737,18 @@ def add_rope_scaling_finetuned(self, value: bool) -> None:
     def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
         self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)

+    def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
+
     def add_ssm_conv_kernel(self, value: int) -> None:
         self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
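
Note: a hypothetical usage sketch of the new writer methods; the parameter values are illustrative placeholders, not grok-2's actual hyperparameters:

    # Illustrative only: exercising the methods added in this commit.
    import gguf

    writer = gguf.GGUFWriter("grok-2-example.gguf", arch="grok")
    writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
    writer.add_rope_scaling_yarn_ext_factor(1.0)   # placeholder value
    writer.add_rope_scaling_yarn_attn_factor(1.0)  # placeholder value
    writer.add_rope_scaling_yarn_beta_fast(8.0)    # placeholder value
    writer.add_rope_scaling_yarn_beta_slow(1.0)    # placeholder value
    writer.add_attn_output_scale(0.08)             # placeholder value
    writer.add_attn_temperature_length(1024)       # placeholder value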

gguf-py/gguf/tensor_mapping.py

Lines changed: 4 additions & 0 deletions

@@ -24,6 +24,7 @@ class TensorNameMap:
         "backbone.embedding",               # mamba
         "backbone.embeddings",              # mamba-hf
         "transformer.in_out_embed",         # Grok
+        "model.layers.{bid}.pre_attn_norm", # grok-2
         "embedding.word_embeddings",        # chatglm
         "transformer.token_embeddings",     # openelm
         "shared",                           # t5
@@ -202,6 +203,7 @@ class TensorNameMap:
         "encoder.layer.{bid}.attention.output.LayerNorm",  # bert
         "encoder.layers.{bid}.norm1",                      # nomic-bert
         "transformer.decoder_layer.{bid}.rms_norm_1",      # Grok
+        "model.layers.{bid}.post_attn_norm",               # grok-2
         "transformer.blocks.{bid}.norm_attn_norm.norm_2",  # dbrx
     ),

@@ -230,6 +232,7 @@ class TensorNameMap:
         "h.{bid}.ln_2",                                  # gpt2
         "model.layers.{bid}.ffn_norm",                   # internlm2
         "transformer.decoder_layer.{bid}.rms_norm_2",    # Grok
+        "model.layers.{bid}.pre_moe_norm",               # grok-2
         "encoder.layers.{bid}.post_attention_layernorm", # chatglm
         "transformer.layers.{bid}.ffn_norm",             # openelm
     ),

@@ -242,6 +245,7 @@ class TensorNameMap:
     # Post feed-forward norm
     MODEL_TENSOR.FFN_POST_NORM: (
         "model.layers.{bid}.post_feedforward_layernorm", # gemma2
+        "model.layers.{bid}.post_moe_norm",              # grok-2
     ),

     MODEL_TENSOR.FFN_GATE_INP: (
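
Note: TensorNameMap expands each {bid} template once per block, so these grok-2 entries map per-layer HF checkpoint names onto GGUF tensor names. A simplified toy version of that expansion (not the real class; the "blk.{bid}.ffn_norm" target assumes the usual GGUF naming for FFN_NORM):

    # Toy illustration of {bid} template expansion into a lookup table.
    templates = {"model.layers.{bid}.pre_moe_norm": "blk.{bid}.ffn_norm"}

    def build_mapping(n_blocks: int) -> dict[str, str]:
        return {
            src.format(bid=bid): dst.format(bid=bid)
            for src, dst in templates.items()
            for bid in range(n_blocks)
        }

    mapping = build_mapping(2)
    print(mapping["model.layers.1.pre_moe_norm"])  # blk.1.ffn_norm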

src/llama-arch.h

Lines changed: 8 additions & 1 deletion

@@ -99,6 +99,7 @@ enum llm_kv {
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_ROUTER_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
     LLM_KV_SWIN_NORM,
     LLM_KV_RESCALE_EVERY_N_LAYERS,
@@ -123,7 +124,8 @@ enum llm_kv {
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
-
+    LLM_KV_ATTENTION_OUTPUT_SCALE,
+    LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
     LLM_KV_ROPE_SCALE_LINEAR,
@@ -134,6 +136,11 @@ enum llm_kv {
     LLM_KV_ROPE_SCALING_FINETUNED,
     LLM_KV_ROPE_SCALING_YARN_LOG_MUL,

+    LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_BETA_FAST,
+    LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,
+
     LLM_KV_SPLIT_NO,
     LLM_KV_SPLIT_COUNT,
     LLM_KV_SPLIT_TENSORS_COUNT,

src/llama-vocab.cpp

Lines changed: 12 additions & 0 deletions

@@ -433,6 +433,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                 "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
             };
             break;
+        case LLAMA_VOCAB_PRE_TYPE_GROK_2:
+            regex_exprs = {
+                // original regex from tokenizer.json
+                // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+            };
+            break;
         default:
             // default regex for BPE tokenization pre-processing
             regex_exprs = {
@@ -1973,6 +1980,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "kimi-k2") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
             clean_spaces = false;
+        }
+        else if (
+                tokenizer_pre == "grok-2") {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
+            clean_spaces = false;
         } else {
             throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
         }
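
Note: the ported pattern spells out the case-insensitive contraction group explicitly instead of using (?i:...), and \p{N} carries no quantifier, so numbers split digit by digit. One way to see the splitting behaviour is Python's third-party "regex" module, which supports the \p{...} classes (illustrative only; llama.cpp applies its own regex engine here):

    # Exercising the grok-2 pre-tokenizer pattern (pip install regex).
    import regex

    pat = (r"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])"
           r"|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*"
           r"|\s*[\r\n]+|\s+(?!\S)|\s+")

    print(regex.findall(pat, "I'LL split 123 tokens!"))
    # ['I', "'LL", ' split', ' ', '1', '2', '3', ' tokens', '!']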

src/llama-vocab.h

Lines changed: 1 addition & 0 deletions

@@ -47,6 +47,7 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN       = 36,
     LLAMA_VOCAB_PRE_TYPE_KIMI_K2       = 37,
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38,
+    LLAMA_VOCAB_PRE_TYPE_GROK_2        = 39,
 };

 struct LLM_KV;
