
Commit 6b3f775

concatenate split tensors and add more metadata
1 parent 92266e9 commit 6b3f775

7 files changed, +55 -17 lines changed

convert_hf_to_gguf.py

Lines changed: 37 additions & 10 deletions
@@ -2717,27 +2717,56 @@ def set_gguf_parameters(self):
         if (rope_dim := self.hparams.get("head_dim")) is None:
             rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]

+        # Treat "original" as "yarn", seems to have been a mistake
+        if self.hparams.get("rope_type") in ("yarn", "original"):
+            # config.json values differ from standard, we may have to add metadata for these:
+            # extrapolation_factor = 1.0
+            # attn_factor = 1.0
+            # beta_fast = 8
+            # beta_slow = 1
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
+
+        if temp_len := self.hparams.get("attn_temperature_len"):
+            self.gguf_writer.add_attn_temperature_length(temp_len)
+
         self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5))
         self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
         self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])

-    _experts: list[dict[str, Tensor]] | None = None
+    _experts: list[dict[str, list[Tensor]]] | None = None
+    _cur_expert = ""

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors: list[tuple[str, Tensor]] = []
+        is_expert = ".moe." in name or ".block_sparse_moe.experts." in name
+
+        if not is_expert:
+            tensors.append((self.map_tensor_name(name), data_torch))
+
         # process the experts separately
-        if name.find(".moe.") != -1 or name.find(".block_sparse_moe.experts.") != -1:
+        if is_expert or self._cur_expert:
             n_experts = self.hparams["num_local_experts"]

             assert bid is not None

             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]

-            self._experts[bid][name] = data_torch
+            # concatenate split tensors
+            if name in self._experts[bid]:
+                self._cur_expert = name
+                self._experts[bid][name].append(data_torch)
+                return []
+            elif is_expert:
+                self._cur_expert = name
+                self._experts[bid][name] = [data_torch]
+                return []
+            else:
+                self._cur_expert = ""

             if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []
-
                 # merge the experts into a single 3d tensor
                 for wid in [("linear", "w1"), ("linear_1", "w2"), ("linear_v", "w3")]:
                     datas: list[Tensor] = []
@@ -2746,7 +2775,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
                         ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
                         if ename not in self._experts[bid]:
                             ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
-                        datas.append(self._experts[bid][ename])
+                        tensor_list = self._experts[bid][ename]
+                        datas.append(torch.hstack(tensor_list) if len(tensor_list) > 1 else tensor_list[0])
                         del self._experts[bid][ename]

                     data_torch = torch.stack(datas, dim=0)
@@ -2756,11 +2786,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
                     new_name = self.map_tensor_name(merged_name)

                     tensors.append((new_name, data_torch))
-            return tensors
-        else:
-            return []

-        return [(self.map_tensor_name(name), data_torch)]
+        return tensors


 @ModelBase.register("DbrxForCausalLM")
gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
@@ -147,6 +147,7 @@ class Attention:
         SLIDING_WINDOW = "{arch}.attention.sliding_window"
         SCALE = "{arch}.attention.scale"
         OUTPUT_SCALE = "{arch}.attention.output_scale"
+        TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
         KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
         VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
         SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
@@ -832,6 +832,9 @@ def add_attention_scale(self, value: float) -> None:
     def add_attn_output_scale(self, value: float) -> None:
         self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)

+    def add_attn_temperature_length(self, value: float) -> None:
+        self.add_float32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
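For reference, a minimal sketch of writing the new key through gguf-py (not part of the commit; the output file name and value are made up, and "grok" is assumed as the architecture string, so the key expands to "grok.attention.temperature_length"):

from gguf import GGUFWriter

# hypothetical standalone writer; the real converter calls this from set_gguf_parameters()
w = GGUFWriter("grok-2-test.gguf", "grok")
w.add_attn_temperature_length(8192.0)  # stored as a float32 KV entry
w.write_header_to_file()
w.write_kv_data_to_file()
w.write_tensors_to_file()
w.close()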

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
@@ -167,6 +167,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
     { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
     { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
+    { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -171,6 +171,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_OUTPUT_SCALE,
+    LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,

src/llama-hparams.h

Lines changed: 3 additions & 0 deletions
@@ -134,7 +134,10 @@ struct llama_hparams {
     float f_residual_scale = 0.0f;
     float f_embedding_scale = 0.0f;
     float f_attention_scale = 0.0f;
+
+    // grok-2
     float f_attn_out_scale = 0.0f;
+    float f_attn_temp_len = 0.0f;

     bool causal_attn = true;
     bool use_alibi = false;

src/llama-model.cpp

Lines changed: 9 additions & 7 deletions
@@ -693,13 +693,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 // no final_logit_softcapping in grok-1
                 hparams.f_final_logit_softcapping = 0.0f;

-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false);
-                ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false);
-                ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false);
-                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
-                ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false);
-                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,  hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LOGIT_SCALE,                  hparams.f_logit_scale, false);
+                ml.get_key(LLM_KV_EMBEDDING_SCALE,              hparams.f_embedding_scale, false);
+                ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE,       hparams.f_attn_out_scale, false);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,       hparams.f_attn_logit_softcapping, false);
+                ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING,     hparams.f_router_logit_softcapping, false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,      hparams.f_final_logit_softcapping, false);
+
+                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.f_attn_temp_len, false);

                 switch (hparams.n_layer) {
                     case 64: type = LLM_TYPE_314B; break;
