Skip to content

Commit 950b401

Browse files
pwilkin and CISC authored
Apply suggestions from code review
Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent a387e36 commit 950b401

File tree

4 files changed

+8
-13
lines changed

4 files changed

+8
-13
lines changed

convert_hf_to_gguf.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2781,7 +2781,8 @@ def set_gguf_parameters(self):
27812781
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
27822782
num_heads = self.hparams["num_attention_heads"]
27832783
num_kv_heads = self.hparams["num_key_value_heads"]
2784-
head_dim = self.hparams["hidden_size"] // num_heads
2784+
if (head_dim := self.hparams.get("head_dim")) is None:
2785+
head_dim = self.hparams["hidden_size"] // num_heads
27852786

27862787
if "ernie." in name:
27872788
name = name.replace("ernie.", "model.")
@@ -2834,11 +2835,6 @@ def set_gguf_parameters(self):
28342835
if (shared_expert_intermediate_size := self.hparams.get('intermediate_size')) is not None and (num_key_value_heads := self.hparams.get('num_key_value_heads')) is not None:
28352836
self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size // num_key_value_heads)
28362837

2837-
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
2838-
if "exps" in new_name:
2839-
return gguf.GGMLQuantizationType.F16
2840-
return super().tensor_force_quant(name, new_name, bid, n_dims)
2841-
28422838
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
28432839
# Modify correction bias name as in DeepseekV2
28442840
if name.endswith("e_score_correction_bias"):
@@ -2863,7 +2859,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
28632859
return []
28642860

28652861
# process the experts separately
2866-
if name.find("experts.") != -1 and name.find("shared") == -1:
2862+
if name.find("mlp.experts") != -1:
28672863
n_experts = self.hparams["moe_num_experts"]
28682864
assert bid is not None
28692865

gguf-py/gguf/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ class MODEL_TENSOR(IntEnum):
678678
MODEL_ARCH.DOTS1: "dots1",
679679
MODEL_ARCH.ARCEE: "arcee",
680680
MODEL_ARCH.ERNIE4_5: "ernie4_5",
681-
MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5_moe",
681+
MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
682682
MODEL_ARCH.FALCON_H1: "falcon-h1",
683683
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
684684
MODEL_ARCH.SMOLLM3: "smollm3",

src/llama-arch.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
8181
{ LLM_ARCH_DOTS1, "dots1" },
8282
{ LLM_ARCH_ARCEE, "arcee" },
8383
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
84-
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5_moe" },
84+
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
8585
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
8686
{ LLM_ARCH_SMOLLM3, "smollm3" },
8787
{ LLM_ARCH_LFM2, "lfm2" },

src/llama-model.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8365,6 +8365,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
83658365

83668366
ggml_tensor * inp_out_ids = build_inp_out_ids();
83678367

8368+
GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Ernie 4.5 MoE requires n_moe_layer_step > 0");
83688369
for (int il = 0; il < n_layer; ++il) {
83698370
ggml_tensor * inpSA = inpL;
83708371
// norm
@@ -8403,17 +8404,15 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
84038404
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
84048405
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
84058406

8406-
const float freq_base_l = model.get_rope_freq_base (cparams, il);
8407-
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
84088407
Qcur = ggml_rope_ext(
84098408
ctx0, Qcur, inp_pos, nullptr,
8410-
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
8409+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
84118410
ext_factor, attn_factor, beta_fast, beta_slow
84128411
);
84138412

84148413
Kcur = ggml_rope_ext(
84158414
ctx0, Kcur, inp_pos, nullptr,
8416-
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
8415+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
84178416
ext_factor, attn_factor, beta_fast, beta_slow
84188417
);
84198418

0 commit comments

Comments (0)