
Commit a6d6eaf

support 20b softmax, 4b no sliding window
1 parent efe27eb commit a6d6eaf

2 files changed: 47 additions, 11 deletions


convert_hf_to_gguf.py

Lines changed: 13 additions & 1 deletion
@@ -6552,15 +6552,27 @@ def set_gguf_parameters(self):
         if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None:
             self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size)
             logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}")
+        if (self.hparams.get('moe_primary_router_apply_softmax')):
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
         # YaRN is not enabled by default
         # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
         rope_scaling = self.hparams.get("rope_scaling") or {}
         if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
             self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
             self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
             self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
         sliding_window = self.hparams.get("sliding_window")
-        self.gguf_writer.add_sliding_window(sliding_window)
+        sliding_window_layout = self.hparams.get("sliding_window_layout")
+        if sliding_window and sliding_window_layout:
+            for i in sliding_window_layout:
+                if i != 0:
+                    self.gguf_writer.add_sliding_window(sliding_window)
+                    break
+        elif sliding_window:
+            self.gguf_writer.add_sliding_window(sliding_window)

         intermediate_size = self.hparams.get("ffn_hidden_size")
         moe_intermediate_size = self.hparams.get("moe_ffn_hidden_size")
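
Taken on its own, the conversion change boils down to two decisions: pick the expert gating function from the moe_primary_router_apply_softmax flag, and only write a sliding-window size when the window is actually in use (either no sliding_window_layout is present, or the layout has at least one nonzero entry). Below is a minimal standalone sketch of that logic, not the converter itself: FakeWriter stands in for the real gguf writer, and the two sample hparams dicts are made up for illustration.

# Sketch of the decision logic added above; FakeWriter and the sample
# hparams are illustrative stand-ins, not real SmallThinker configs.
class FakeWriter:
    def add_expert_gating_func(self, func: str) -> None:
        print("expert_gating_func =", func)

    def add_sliding_window(self, n: int) -> None:
        print("sliding_window =", n)

def write_smallthinker_metadata(hparams: dict, writer: FakeWriter) -> None:
    # Softmax vs sigmoid routing, keyed on moe_primary_router_apply_softmax.
    if hparams.get("moe_primary_router_apply_softmax"):
        writer.add_expert_gating_func("SOFTMAX")
    else:
        writer.add_expert_gating_func("SIGMOID")

    # Only record a sliding window if at least one layer actually uses it.
    sliding_window = hparams.get("sliding_window")
    layout = hparams.get("sliding_window_layout")
    if sliding_window and layout:
        if any(i != 0 for i in layout):
            writer.add_sliding_window(sliding_window)
    elif sliding_window:
        writer.add_sliding_window(sliding_window)

# Softmax routing, window used on some layers -> both keys written.
write_smallthinker_metadata(
    {"moe_primary_router_apply_softmax": True,
     "sliding_window": 4096,
     "sliding_window_layout": [0, 1, 1, 1]},
    FakeWriter())

# Sigmoid routing, layout of all zeros -> no sliding-window key written.
write_smallthinker_metadata(
    {"moe_primary_router_apply_softmax": False,
     "sliding_window": 4096,
     "sliding_window_layout": [0, 0, 0, 0]},
    FakeWriter())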

src/llama-model.cpp

Lines changed: 34 additions & 10 deletions
@@ -1551,17 +1551,25 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_SMALLTHINKER:
             {
-                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.n_swa = 4096;
-                hparams.set_dense_start_swa_pattern(4);
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);

-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+                if (found_swa && hparams.n_swa > 0) {
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                    hparams.n_swa = 4096;
+                    hparams.set_dense_start_swa_pattern(4);
+                } else {
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+                    hparams.n_no_rope_layer_step = hparams.n_layer;
+                }

+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);

                 switch (hparams.n_layer) {
-                    default:
-                        type = LLM_TYPE_UNKNOWN;
+                    case 32: type = LLM_TYPE_4B; break;
+                    case 52: type = LLM_TYPE_20B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
         default: throw std::runtime_error("unsupported model architecture");
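
In short, this hunk stops hard-coding SmallThinker's attention setup and derives it from the GGUF metadata: standard sliding-window attention (window 4096, dense-start pattern 4) only when a nonzero sliding-window key was written at conversion time, full attention otherwise, with the layer count distinguishing the 4B (32 layers) from the 20B (52 layers). A rough Python restatement of that selection, purely illustrative and not llama.cpp API:

# Illustrative restatement of the hparams selection above (assumes a missing
# sliding-window key reads back as 0).
def smallthinker_hparams(n_swa_from_gguf: int, n_layer: int) -> dict:
    hp = {}
    if n_swa_from_gguf > 0:
        hp["swa_type"] = "STANDARD"
        hp["n_swa"] = 4096
        hp["dense_start_swa_pattern"] = 4
    else:
        hp["swa_type"] = "NONE"
    hp["type"] = {32: "4B", 52: "20B"}.get(n_layer, "UNKNOWN")
    return hp

print(smallthinker_hparams(4096, 52))  # 20B converted with a sliding window
print(smallthinker_hparams(0, 32))     # 4B converted without one
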
@@ -4854,6 +4862,7 @@ void llama_model::print_info() const {

     if (arch == LLM_ARCH_SMALLTHINKER) {
         LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
     }

     vocab.print_info();
@@ -14736,7 +14745,12 @@ struct llm_build_smallthinker : public llm_graph_context{
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();

-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        llm_graph_input_i * inp_attn = nullptr;
+        if (hparams.is_swa_any()) {
+            inp_attn = build_attn_inp_kv_unified_iswa();
+        } else {
+            inp_attn = build_attn_inp_kv_unified();
+        }

         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -14747,7 +14761,11 @@ struct llm_build_smallthinker : public llm_graph_context{
                 ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens]
                 cb(logits, "ffn_moe_logits", il);

-                probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+                if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
+                    probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
+                } else {
+                    probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+                }
                 cb(probs, "ffn_moe_probs", il);
             }
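
The practical difference between the two gating functions: softmax normalizes the router logits across experts, so the scores compete and sum to 1, while sigmoid scores each expert independently in (0, 1). A small numeric illustration with made-up logits for four experts:

import math

logits = [2.0, 1.0, 0.5, -1.0]  # made-up router logits for one token

# Softmax gating: scores compete across experts and sum to 1.
exps = [math.exp(x) for x in logits]
softmax = [e / sum(exps) for e in exps]

# Sigmoid gating: each expert is scored independently in (0, 1).
sigmoid = [1.0 / (1.0 + math.exp(-x)) for x in logits]

print("softmax:", [round(p, 3) for p in softmax], "sum:", round(sum(softmax), 3))
print("sigmoid:", [round(p, 3) for p in sigmoid], "sum:", round(sum(sigmoid), 3))

This is why the converter records the gating function per model: reusing sigmoid routing for a checkpoint trained with softmax routing would change the expert weights.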

@@ -14782,15 +14800,21 @@ struct llm_build_smallthinker : public llm_graph_context{
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);

-                cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur,
-                                 nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+                if (hparams.is_swa_any()) {
+                    cur = build_attn(static_cast<llm_graph_input_attn_kv_unified_iswa *>(inp_attn), gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur,
+                                     nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+                } else {
+                    cur = build_attn(static_cast<llm_graph_input_attn_kv_unified *>(inp_attn), gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur,
+                                     nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+                }
             }

             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
                 ggml_tensor * inp_out_ids = build_inp_out_ids();
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                if (probs != nullptr) { probs = ggml_get_rows(ctx0, probs, inp_out_ids); }
             }

             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
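
The extra ggml_get_rows on probs mirrors what already happens to cur and inpSA on the last layer: once only the rows for tokens that need output are kept, every per-token tensor still consumed afterwards has to be sliced with the same indices, or the shapes stop matching. A small NumPy sketch of that invariant (shapes and indices are made up, and rows stand in for ggml's column-per-token layout):

import numpy as np

n_tokens, n_embd, n_expert = 6, 8, 4
cur = np.random.rand(n_tokens, n_embd)      # per-token hidden states
probs = np.random.rand(n_tokens, n_expert)  # per-token router probabilities

# Suppose only the last two tokens are needed for output.
out_ids = np.array([4, 5])

# Slice both tensors with the same row indices...
cur = cur[out_ids]
probs = probs[out_ids]

# ...so later per-token operations (e.g. weighting experts) still line up.
assert cur.shape[0] == probs.shape[0] == len(out_ids)
print(cur.shape, probs.shape)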
