
Commit 551a64f

add grok-2 support
1 parent c9a24fb

8 files changed: +107 -12 lines

convert_hf_to_gguf.py

Lines changed: 65 additions & 2 deletions
@@ -2635,19 +2635,82 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield (new_name, data_torch)
 
 
-@ModelBase.register("GrokForCausalLM")
+@ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM")
 class GrokModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GROK
 
     def set_vocab(self):
-        self._set_vocab_sentencepiece()
+        if (self.dir_model / 'tokenizer.model').is_file():
+            self._set_vocab_sentencepiece()
+            return
+
+        tokenizer_path = self.dir_model / 'tokenizer.tok.json'
+        with open(tokenizer_path, "r", encoding="utf-8") as f:
+            tokenizer = json.load(f)
+
+        vocab_size = tokenizer["vocab_size"]
+        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
+        scores: list[float] = [-10000.0] * vocab_size
+        toktypes: list[int] = [gguf.TokenType.UNUSED] * vocab_size
+
+        def decode_grok_token(token: dict, toktype: gguf.TokenType) -> tuple[gguf.TokenType, int, str]:
+            tokid = token["token"]
+            tokb = token["bytes"]
+            try:
+                tokc = bytes(tokb).decode("utf-8")
+            except:
+                tokc = None
+            if len(tokb) == 1 or not tokc:
+                return gguf.TokenType.BYTE, tokid, "<0x{:02X}>".format(tokb[0])
+            else:
+                return toktype, tokid, tokc
+
+        for token in tokenizer["special_tokens"]:
+            toktype, tokid, tokc = decode_grok_token(token, gguf.TokenType.CONTROL)
+            tokens[tokid] = tokc
+            toktypes[tokid] = toktype
+            scores[tokid] = 0.0
+
+        score = -0.0
+        for token in tokenizer["regular_tokens"]:
+            toktype, tokid, tokc = decode_grok_token(token, gguf.TokenType.NORMAL)
+            tokens[tokid] = tokc
+            toktypes[tokid] = toktype
+            scores[tokid] = score
+            score -= 1.0
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        self.gguf_writer.add_add_bos_token(False)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.special_token_ids["pad"] = 0
+        special_vocab.special_token_ids["sep"] = 1
+        special_vocab.special_token_ids["eos"] = 2
+        special_vocab.add_to_gguf(self.gguf_writer)
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
+        self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
+        self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
+        if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+
+        self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5))
+        self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
+        self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])
+
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
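Note on the new set_vocab path: when no SentencePiece tokenizer.model is present, the converter reads Grok's tokenizer.tok.json, where every entry carries its raw byte sequence; single-byte entries and sequences that are not valid UTF-8 are stored as <0xNN> byte tokens, everything else keeps its decoded text. A minimal standalone sketch of that decoding rule (sample entries are invented for illustration; the real converter additionally assigns ids, scores and token types as shown above):

# Sketch of the byte-fallback rule used by decode_grok_token; sample entries are invented.
def decode_token(entry: dict) -> tuple[str, str]:
    tok_bytes = entry["bytes"]
    try:
        text = bytes(tok_bytes).decode("utf-8")
    except UnicodeDecodeError:
        text = None
    if len(tok_bytes) == 1 or not text:
        # single bytes and undecodable sequences become <0xNN> byte tokens
        return "BYTE", "<0x{:02X}>".format(tok_bytes[0])
    return "NORMAL", text

print(decode_token({"token": 5, "bytes": [72, 105]}))  # ('NORMAL', 'Hi')
print(decode_token({"token": 6, "bytes": [0xC3]}))     # ('BYTE', '<0xC3>')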

gguf-py/gguf/constants.py

Lines changed: 2 additions & 0 deletions
@@ -110,6 +110,7 @@ class LLM:
         LOGIT_SCALE              = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID   = "{arch}.decoder_start_token_id"
         ATTN_LOGIT_SOFTCAPPING   = "{arch}.attn_logit_softcapping"
+        ROUTER_LOGIT_SOFTCAPPING = "{arch}.router_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING  = "{arch}.final_logit_softcapping"
         SWIN_NORM                = "{arch}.swin_norm"
         RESCALE_EVERY_N_LAYERS   = "{arch}.rescale_every_n_layers"

@@ -145,6 +146,7 @@ class Attention:
         REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
         SLIDING_WINDOW    = "{arch}.attention.sliding_window"
         SCALE             = "{arch}.attention.scale"
+        OUTPUT_SCALE      = "{arch}.attention.output_scale"
         KEY_LENGTH_MLA    = "{arch}.attention.key_length_mla"
         VALUE_LENGTH_MLA  = "{arch}.attention.value_length_mla"
         SHARED_KV_LAYERS  = "{arch}.attention.shared_kv_layers"

gguf-py/gguf/gguf_writer.py

Lines changed: 6 additions & 0 deletions
@@ -730,6 +730,9 @@ def add_logit_scale(self, value: float) -> None:
     def add_attn_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
 
+    def add_router_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
     def add_final_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
 

@@ -826,6 +829,9 @@ def add_sliding_window(self, value: int) -> None:
     def add_attention_scale(self, value: float) -> None:
         self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
 
+    def add_attn_output_scale(self, value: float) -> None:
+        self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
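The two new helpers mirror the existing soft-capping setters: each writes a single float32 metadata key under the model's architecture prefix. A hedged usage sketch (file name and values are placeholders; a real conversion would also write tensors and the rest of the metadata):

import gguf

writer = gguf.GGUFWriter("grok2-example.gguf", "grok")  # placeholder output path
writer.add_attn_logit_softcapping(30.0)
writer.add_router_logit_softcapping(30.0)           # new in this commit
writer.add_attn_output_scale(0.08838834764831845)   # new in this commit
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()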

src/llama-arch.cpp

Lines changed: 2 additions & 0 deletions
@@ -135,6 +135,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_LOGIT_SCALE,                      "%s.logit_scale"              },
     { LLM_KV_DECODER_START_TOKEN_ID,           "%s.decoder_start_token_id"   },
     { LLM_KV_ATTN_LOGIT_SOFTCAPPING,           "%s.attn_logit_softcapping"   },
+    { LLM_KV_ROUTER_LOGIT_SOFTCAPPING,         "%s.router_logit_softcapping" },
     { LLM_KV_FINAL_LOGIT_SOFTCAPPING,          "%s.final_logit_softcapping"  },
     { LLM_KV_SWIN_NORM,                        "%s.swin_norm"                },
     { LLM_KV_RESCALE_EVERY_N_LAYERS,           "%s.rescale_every_n_layers"   },

@@ -165,6 +166,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
     { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
     { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
+    { LLM_KV_ATTENTION_OUTPUT_SCALE,           "%s.attention.output_scale"           },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,         "%s.attention.key_length_mla"         },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,       "%s.attention.value_length_mla"       },

src/llama-arch.h

Lines changed: 2 additions & 0 deletions
@@ -139,6 +139,7 @@ enum llm_kv {
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_ROUTER_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
     LLM_KV_SWIN_NORM,
     LLM_KV_RESCALE_EVERY_N_LAYERS,

@@ -169,6 +170,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
+    LLM_KV_ATTENTION_OUTPUT_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,

src/llama-graph.cpp

Lines changed: 3 additions & 3 deletions
@@ -1290,13 +1290,13 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
     if (arch == LLM_ARCH_GROK) {
         // need to do the following:
-        // multiply by attn_output_multiplyer of 0.08838834764831845
+        // multiply by attn_output_multiplier
         // and then :
         // kq = 30 * tanh(kq / 30)
         // before the softmax below
 
-        kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
-        kq = ggml_scale(ctx0, kq, 30);
+        kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
+        kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
     }
 
     if (hparams.attn_soft_cap) {
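The Grok branch above folds the attention output multiplier into the soft-cap, computing cap * tanh(kq * scale / cap), which equals scaling kq first and then applying the usual cap * tanh(x / cap). A small numeric sketch of that identity and of the bound it enforces (values chosen for illustration; the constants match the grok-1 defaults in this commit):

import math

def softcap(x: float, cap: float) -> float:
    # cap * tanh(x / cap): near-identity for |x| << cap, saturates at +/- cap
    return cap * math.tanh(x / cap)

scale, cap = 0.08838834764831845, 30.0
for kq in (1.0, 100.0, 10000.0):
    fused    = cap * math.tanh(kq * scale / cap)  # what the fused graph ops compute
    stepwise = softcap(kq * scale, cap)           # scale first, then soft-cap
    assert math.isclose(fused, stepwise)
    print(f"kq={kq:>8}: capped={fused:.4f} (|capped| < {cap})")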

src/llama-hparams.h

Lines changed: 4 additions & 2 deletions
@@ -80,8 +80,9 @@ struct llama_hparams {
     float f_norm_rms_eps;
     float f_norm_group_eps;
 
-    float f_attn_logit_softcapping  = 50.0f;
-    float f_final_logit_softcapping = 30.0f;
+    float f_attn_logit_softcapping   = 50.0f;
+    float f_router_logit_softcapping = 30.0f;
+    float f_final_logit_softcapping  = 30.0f;
 
     // for RWKV
     uint32_t rescale_every_n_layers = 0;

@@ -133,6 +134,7 @@ struct llama_hparams {
     float f_residual_scale  = 0.0f;
     float f_embedding_scale = 0.0f;
     float f_attention_scale = 0.0f;
+    float f_attn_out_scale  = 0.0f;
 
     bool causal_attn = true;
     bool use_alibi   = false;

src/llama-model.cpp

Lines changed: 23 additions & 5 deletions
@@ -684,7 +684,22 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GROK:
             {
+                // defaults for old GGUFs
+                hparams.f_logit_scale              = 0.5773502691896257f;
+                hparams.f_embedding_scale          = 78.38367176906169f;
+                hparams.f_attn_out_scale           = 0.08838834764831845f;
+                hparams.f_attn_logit_softcapping   = 30.0f;
+                hparams.f_router_logit_softcapping = 30.0f;
+                // no final_logit_softcapping in grok-1
+                hparams.f_final_logit_softcapping  = 0.0f;
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LOGIT_SCALE,              hparams.f_logit_scale, false);
+                ml.get_key(LLM_KV_EMBEDDING_SCALE,          hparams.f_embedding_scale, false);
+                ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE,   hparams.f_attn_out_scale, false);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,   hparams.f_attn_logit_softcapping, false);
+                ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,  hparams.f_final_logit_softcapping, false);
 
                 switch (hparams.n_layer) {
                     case 64: type = LLM_TYPE_314B; break;

@@ -6886,8 +6901,7 @@ struct llm_build_grok : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        // multiply by embedding_multiplier_scale of 78.38367176906169
-        inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
+        inpL = ggml_scale(ctx0, inpL, hparams.f_embedding_scale);
 
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();

@@ -7024,10 +7038,14 @@
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
-        // Grok
-        // multiply logits by output_multiplier_scale of 0.5773502691896257
+        cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);
 
-        cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
+        // final logit soft-capping
+        if (hparams.f_final_logit_softcapping) {
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+            cur = ggml_tanh(ctx0, cur);
+            cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+        }
 
         cb(cur, "result_output", -1);
         res->t_logits = cur;
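The load_hparams change keeps old Grok-1 GGUFs working: the constants that used to be hard-coded in the graph become defaults, and newer files (grok-2) can override them through optional keys (the trailing false makes ml.get_key tolerate a missing key). A rough Python illustration of the same defaults-then-override pattern, with a plain dict standing in for the GGUF metadata (key names follow the KV strings used here and are assumptions for the example):

# Illustration only: a dict stands in for the GGUF key/value metadata.
GROK1_DEFAULTS = {
    "grok.logit_scale":              0.5773502691896257,
    "grok.embedding_scale":          78.38367176906169,
    "grok.attention.output_scale":   0.08838834764831845,
    "grok.attn_logit_softcapping":   30.0,
    "grok.router_logit_softcapping": 30.0,
    "grok.final_logit_softcapping":  0.0,   # grok-1 has no final soft-cap
}

def load_grok_hparams(metadata: dict) -> dict:
    hparams = dict(GROK1_DEFAULTS)           # seed with grok-1 defaults
    for key in GROK1_DEFAULTS:
        if key in metadata:                  # optional key, like ml.get_key(..., false)
            hparams[key] = metadata[key]
    return hparams

print(load_grok_hparams({})["grok.attention.output_scale"])  # old GGUF -> default
newer = {"grok.final_logit_softcapping": 50.0}
print(load_grok_hparams(newer)["grok.final_logit_softcapping"])  # newer GGUF -> its own value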
