Commit 72c98b0

Merge pull request #1 from ggml-org/xsn/qwen3next_experiment
ngxson's fixes
2 parents 9832f29 + e83ef74 commit 72c98b0

6 files changed, +43 -31 lines changed

convert_hf_to_gguf.py

Lines changed: 5 additions & 1 deletion
@@ -3767,8 +3767,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
         elif "conv1d" in name:
             data_torch = data_torch.squeeze()
+        elif "q_proj.weight" in name:
+            q_proj, gate = data_torch.chunk(2, dim=0)
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_GATE, bid), gate)
+            data_torch = q_proj
 
-        return Qwen2MoeModel.modify_tensors(self, data_torch, name, bid)
+        yield from Qwen2MoeModel.modify_tensors(self, data_torch, name, bid)
 
 
 @ModelBase.register("GPT2LMHeadModel")
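
Note: the new `q_proj.weight` branch splits the fused query/gate projection in half along its output rows, keeps the first half as the query weight and emits the second half as the new ATTN_GATE tensor. A minimal sketch of that split, using made-up dimensions rather than values from any real config:

    import torch

    # Illustrative sizes only, not taken from the model config.
    n_embd, n_head, head_dim = 2048, 16, 128

    # The converter treats the fused weight as query rows followed by gate rows.
    fused_q_proj = torch.randn(2 * n_head * head_dim, n_embd)

    q_proj, gate = fused_q_proj.chunk(2, dim=0)  # same call as in modify_tensors above
    assert q_proj.shape == gate.shape == (n_head * head_dim, n_embd)

The gate half is yielded under the ATTN_GATE name, while the remaining query half continues through the regular Qwen2MoeModel.modify_tensors path.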

gguf-py/gguf/constants.py

Lines changed: 3 additions & 0 deletions
@@ -433,6 +433,7 @@ class MODEL_TENSOR(IntEnum):
     ATTN_NORM_2 = auto()
     ATTN_OUT_NORM = auto()
     ATTN_POST_NORM = auto()
+    ATTN_GATE = auto()
     ATTN_ROT_EMBD = auto()
     ATTN_SINKS = auto()
     FFN_GATE_INP = auto()
@@ -776,6 +777,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
     MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
     MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm",
+    MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
     MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
     MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
     MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
@@ -1478,6 +1480,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_GATE,
         MODEL_TENSOR.FFN_GATE_INP,
         MODEL_TENSOR.FFN_GATE_INP_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
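
For context, the "blk.{bid}.attn_gate" template registered above is what turns the new enum entry into concrete per-layer tensor names; the block index is substituted in at export time. A trivial illustration of that substitution (a plain string mirroring the diff, not an import of the real gguf module):

    # Mirrors the TENSOR_NAMES entry added above; illustrative only.
    ATTN_GATE_TEMPLATE = "blk.{bid}.attn_gate"

    names = [ATTN_GATE_TEMPLATE.format(bid=bid) for bid in range(3)]
    assert names == ["blk.0.attn_gate", "blk.1.attn_gate", "blk.2.attn_gate"]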

src/llama-arch.cpp

Lines changed: 2 additions & 0 deletions
@@ -769,6 +769,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
         { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
         { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+        { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
         { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
         { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
         { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
@@ -2245,6 +2246,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -381,6 +381,7 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
+    LLM_TENSOR_ATTN_GATE,
     LLM_TENSOR_FFN_SUB_NORM,
     LLM_TENSOR_DEC_ATTN_NORM,
     LLM_TENSOR_DEC_ATTN_Q,

src/llama-model.cpp

Lines changed: 31 additions & 30 deletions
@@ -2435,9 +2435,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
         layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
 
-        if ((i + 1) % 4 == 0) { // TODO: magic 4
-            // Attention layers
-            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_ff }, 0);
+        if (!hparams.is_recurrent(i)) {
+            // Attention layers
+            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
             layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
             layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
             layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
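
Replacing the hard-coded `(i + 1) % 4 == 0` check with `!hparams.is_recurrent(i)` moves the full-attention/linear-attention decision into the per-layer hyperparameters instead of assuming a fixed four-layer cycle. A small sketch of the idea, where the illustrative flags simply reproduce the pattern the old modulo rule encoded:

    # Hypothetical per-layer flags standing in for hparams.is_recurrent(i);
    # here they just re-create the old "(i + 1) % 4 == 0" rule for a 12-layer model.
    n_layer = 12
    is_recurrent = [(i + 1) % 4 != 0 for i in range(n_layer)]

    for i in range(n_layer):
        kind = "linear attention (gated delta net)" if is_recurrent[i] else "full attention"
        print(f"layer {i:2d}: {kind}")
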
@@ -2446,6 +2446,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
             layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
 
+            // attn gate
+            layer.wq_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
+
         } else {
             // Linear attention (gated delta net) specific tensors
             // Create tensors with calculated dimensions
@@ -2455,7 +2458,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
             layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_projection_size }, 0);
             layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
-            layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { n_ff, n_embd }, 0);
+            layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
         }
 
         layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
@@ -19034,30 +19037,27 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
             const int64_t n_embd_head,
             const int il) {
 
-        // QKV projection with gating
-        ggml_tensor * qkv_g = build_lora_mm(model.layers[il].wq, cur);
-        cb(qkv_g, "qkv_g", il);
-
-        // Split into Q and gate
-        const int64_t n_embd_q = hparams.n_head(il) * n_embd_head;
-        ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv_g, n_embd_head, hparams.n_head(il), n_tokens,
-            n_embd_head * sizeof(float), qkv_g->nb[1], 0);
-        ggml_tensor * gate = ggml_view_3d(ctx0, qkv_g, n_embd_head, hparams.n_head(il), n_tokens,
-            n_embd_head * sizeof(float), qkv_g->nb[1], n_embd_q * ggml_element_size(qkv_g));
-
-        // K and V projections
-        ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-        ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+        ggml_tensor * gate = build_lora_mm(model.layers[il].wq_gate, cur);
+
+        // compute Q and K and RoPE them
+        struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+        cb(Qcur, "Qcur", il);
+
+        struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
         cb(Kcur, "Kcur", il);
+
+        struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
         cb(Vcur, "Vcur", il);
 
-        Qcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Qcur), n_embd_head, hparams.n_head(il), n_tokens);
-        Kcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Kcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
-        Vcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Vcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
+        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
         // Apply Q/K normalization
         Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
         Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+        cb(Qcur, "Qcur_normed", il);
+        cb(Kcur, "Kcur_normed", il);
 
         // Apply RoPE
         Qcur = ggml_rope_ext(
@@ -19081,7 +19081,6 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
                 Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
 
         // Apply gating
-        gate = ggml_reshape_2d(ctx0, ggml_cont(ctx0, gate), n_embd_q, n_tokens);
         cur = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
         cb(cur, "attn_gated", il);
 
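With the fused projection split at conversion time, the gate no longer has to be carved out of a combined Q/gate tensor; it is a plain matmul against `wq_gate`, and it modulates the attention output elementwise through a sigmoid. A torch sketch of that gating step, with illustrative shapes only:

    import torch

    n_tokens, n_embd, n_gate = 4, 64, 64      # illustrative sizes only

    x        = torch.randn(n_tokens, n_embd)  # layer input ("cur" before attention)
    attn_out = torch.randn(n_tokens, n_gate)  # attention output being gated
    w_gate   = torch.randn(n_embd, n_gate)    # stands in for layer.wq_gate

    gate  = x @ w_gate                        # build_lora_mm(model.layers[il].wq_gate, cur)
    gated = attn_out * torch.sigmoid(gate)    # ggml_mul(cur, ggml_sigmoid(gate))
    assert gated.shape == attn_out.shape
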
@@ -19184,16 +19183,10 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
 
         GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
 
-        // Softplus would be nice...
-        ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); // a + dt_bias
-        ggml_tensor * alpha_exp = ggml_exp(ctx0, alpha_biased); // exp(a + dt_bias)
-        ggml_tensor * one_tensor = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); // Create scalar tensor
-        ggml_exp(ctx0, one_tensor); // make it a 1
-        ggml_tensor * one_plus_exp = ggml_add1(ctx0, alpha_exp, one_tensor); // 1 + exp(a + dt_bias)
-        ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp); // log(1 + exp(...))
+        ggml_tensor * alpha_softplus = softplus(alpha, model.layers[il].ssm_dt);
         ggml_tensor * A_log_exp = ggml_exp(ctx0, model.layers[il].ssm_a); // A_log.exp()
         ggml_tensor * gate_scaled = ggml_mul(ctx0, alpha_softplus, A_log_exp); // A_log.exp() * softplus
-        ggml_tensor * gate = ggml_neg(ctx0, gate_scaled); // - (A_log.exp() * softplus)
+        ggml_tensor * gate = ggml_scale(ctx0, gate_scaled, -1.0f); // - (A_log.exp() * softplus)
 
         // Get convolution weights and bias
         ggml_tensor * conv_weight = model.layers[il].ssm_conv1d;
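
The removed block built softplus by hand, including a throwaway scalar tensor just to add 1; it is now a single call to the `softplus()` helper added at the end of this file, which uses `ggml_scale_bias(alpha_exp, 1.0f, 1.0f)` for the `1 + exp(...)` step. The decay gate itself is `-exp(A_log) * softplus(alpha + dt_bias)`, per the comments above. A torch reference for the same arithmetic, with illustrative sizes:

    import torch

    n_heads = 8                     # illustrative size only
    alpha   = torch.randn(n_heads)  # alpha half of mixed_ba
    dt_bias = torch.randn(n_heads)  # stands in for layer.ssm_dt
    A_log   = torch.randn(n_heads)  # stands in for layer.ssm_a

    softplus = torch.log1p(torch.exp(alpha + dt_bias))  # log(1 + exp(a + dt_bias))
    gate = -torch.exp(A_log) * softplus                 # - (A_log.exp() * softplus)

    reference = -torch.exp(A_log) * torch.nn.functional.softplus(alpha + dt_bias)
    assert torch.allclose(gate, reference, atol=1e-5)
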
@@ -19326,6 +19319,14 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
 
         return cur;
     }
+
+    ggml_tensor * softplus(ggml_tensor * alpha, ggml_tensor * dt_bias) {
+        ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, dt_bias); // a + dt_bias
+        ggml_tensor * alpha_exp = ggml_exp(ctx0, alpha_biased); // exp(a + dt_bias)
+        ggml_tensor * one_plus_exp = ggml_scale_bias(ctx0, alpha_exp, 1.0f, 1.0f); // 1 + exp(a + dt_bias)
+        ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp); // log(1 + exp(...))
+        return alpha_softplus;
+    }
 };
 
 
src/llama-model.h

Lines changed: 1 addition & 0 deletions
@@ -228,6 +228,7 @@ struct llama_layer {
     struct ggml_tensor * wk_enc = nullptr;
     struct ggml_tensor * wv_enc = nullptr;
     struct ggml_tensor * wo_enc = nullptr;
+    struct ggml_tensor * wq_gate = nullptr;
 
     // attention bias
     struct ggml_tensor * bq = nullptr;
