
Conversation

pwilkin
Collaborator

@pwilkin pwilkin commented Sep 18, 2025

It's been a real learning experience, not gonna lie, but if someone with hybrid model implementation experience (@gabe-l-hart ?) has some quick tips, I'd be grateful.

Currently at the stage of "graph builds, but first decode complains about wrong memory model", probably not building the inputs correctly.

Resolves #15940

@github-actions github-actions bot added the python (python script changes) and ggml (changes relating to the ggml tensor library for machine learning) labels Sep 18, 2025
@gabe-l-hart
Collaborator

I'll try to get into it in more detail soon, but here are a few general thoughts after quickly skimming the PR:

  1. The structure of what you've got smells correct, so it's likely close, but missing something small yet critical
  2. A full repro with the error it's raising would definitely help debug
  3. My debugging process for this would be:
    1. Make sure tokenization is solid (print statements as necessary to compare tokens before input)
    2. Use llama-eval-callback to dump tensors for a single prefill step
    3. Run an identical single prefill with the reference impl (transformers or otherwise), and inject prints as needed to dump tensors along the way
    4. Visually comb through them (particularly the sum at each point) to see where things start diverging significantly

@bugparty
Contributor

It's been a real learning experience, not gonna lie, but if someone with hybrid model implementation experience (@gabe-l-hart ?) has some quick tips, I'd be grateful.

Currently at the stage of "graph builds, but first decode complains about wrong memory model", probably not building the inputs correctly.

Resolves #15940

interesting, maybe we can learn together

ggml/src/ggml.c Outdated
Comment on lines 5467 to 5483
if (use_qk_l2norm) {
q_norm = ggml_l2_norm(ctx, q, 1e-6f);
k_norm = ggml_l2_norm(ctx, k, 1e-6f);
}

// Apply scaling to query
q_norm = ggml_scale(ctx, q_norm, scale);

// Apply sigmoid to beta for gating
struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 1);
mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);

u_int32_t dim = (S_v * H_v) + 2 * (H_k * S_k);

mixed_qkv = ggml_reshape_3d(ctx, mixed_qkv, 1, dim, n_tokens);
struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, 3, 0, 0, 0);
Collaborator

This part of the code has many magic numbers and configs (like l2norm, sigmoid, silu). It will be a headache if a future model reuses this delta net idea with some tweaks. It's better to just move all of this part to ggml-model and make ggml_delta_net a thin wrapper around GGML_OP_DELTA_NET, like all other ops.
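For illustration, a minimal sketch of the "thin wrapper" shape being suggested, modeled on how other ggml ops are declared; the operand list and result shape for GGML_OP_DELTA_NET here are placeholders, not the PR's actual signature:

    // hedged sketch only: operands and result shape are illustrative
    struct ggml_tensor * ggml_delta_net(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,
            struct ggml_tensor  * k,
            struct ggml_tensor  * v,
            struct ggml_tensor  * beta,
            struct ggml_tensor  * state) {
        // result takes the shape of v in this sketch
        struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,
            v->ne[0], v->ne[1], v->ne[2], v->ne[3]);

        result->op     = GGML_OP_DELTA_NET;
        result->src[0] = q;
        result->src[1] = k;
        result->src[2] = v;
        result->src[3] = beta;
        result->src[4] = state;

        return result;
    }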

int64_t ne3) {
GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
Collaborator

Suggested change (drop the duplicated assert, keeping a single copy):

GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);

ggml/src/ggml.c Outdated
Comment on lines 5562 to 5563
q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, repeat_factor, H_k, n_tokens);
k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, repeat_factor, H_k, n_tokens);
Collaborator

maybe repeat_factor can be a param for GGML_OP_DELTA_NET, so it can internally do the broadcast without using extra memory

ggml/src/ggml.c Outdated
k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
v_conv = ggml_permute(ctx, v_conv, 0, 2, 1, 3);

q_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, q_conv), S_k * H_k, 1, n_tokens);
Collaborator

ggml_cont_3d is the combination of reshape and cont
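i.e. the line above could become something like (a sketch, assuming the same target shape):

    q_conv = ggml_cont_3d(ctx, q_conv, S_k * H_k, 1, n_tokens);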

}

// Apply sigmoid to beta
float * beta_sigmoid = (float *)alloca(n_tokens * sizeof(float));
Collaborator

using working data (params->wdata) can be a better choice
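A rough sketch of that (assuming the op reserves n_tokens floats per thread when the graph's work size is computed; the offsets here are illustrative):

    // hedged sketch: requires a matching work-size reservation for this op elsewhere
    float * beta_sigmoid = (float *) params->wdata + params->ith * n_tokens;
    GGML_ASSERT((size_t)(params->ith + 1) * n_tokens * sizeof(float) <= params->wsize);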

// Apply sigmoid to beta
float * beta_sigmoid = (float *)alloca(n_tokens * sizeof(float));
for (int64_t t = 0; t < n_tokens; ++t) {
beta_sigmoid[t] = 1.0f / (1.0f + expf(-beta_ptr[t * nb42 / sizeof(float)]));
Collaborator

Isn't beta already sigmoid-ed before being passed to this op? You're applying sigmoid a second time here, IIUC.


// ggml_compute_forward_delta_net

static void ggml_compute_forward_delta_net(
Collaborator

I feel like this op can be implemented using other ggml ops like mul, mul_mat, sum. Which part of the calculation do you think can't be constructed using existing ops?

@pwilkin pwilkin marked this pull request as draft September 19, 2025 08:07
@pwilkin
Collaborator Author

pwilkin commented Sep 19, 2025

  1. A full repro with the error it's raising would definitely help debug

Running llama-cli -m reference/qwen3_next_500m/Qwen3_Next_500M-8x417M-BF16.gguf -ngl 999 -p "Who are " yields this weird memory error:

#0  __syscall_cancel_arch () at ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S:56
56      in ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S
#1  0x000070552b29eb63 in __internal_syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=0, a6=0, nr=61) at ./nptl/cancellation.c:49
warning: 49     ./nptl/cancellation.c: No such file or directory
#2  __syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:75
75      in ./nptl/cancellation.c
#3  0x000070552b31afdf in __GI___wait4 (pid=<optimized out>, stat_loc=<optimized out>, options=<optimized out>, usage=<optimized out>) at ../sysdeps/unix/sysv/linux/wait4.c:30
warning: 30     ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory
#4  0x000070552bb45c31 in ggml_print_backtrace () at /devel/tools/llama.cpp/ggml/src/ggml.c:196
warning: Source file is more recent than executable.
196             waitpid(child_pid, NULL, 0);
#5  0x000070552bb45de5 in ggml_abort (file=0x70552bbcdac8 "/devel/tools/llama.cpp/ggml/src/ggml-backend.cpp", line=189, fmt=0x70552bbcd8af "GGML_ASSERT(%s) failed") at /devel/tools/llama.cpp/ggml/src/ggml.c:230
230             ggml_print_backtrace();
#6  0x000070552bb6091e in ggml_backend_buffer_get_type (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:189
189         GGML_ASSERT(buffer);
#7  0x000070552bb6080e in ggml_backend_buffer_is_host (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:170
170         return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
#8  0x000070552c07a114 in llm_graph_input_rs::set_input (this=0x5f11bdf6aea0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:241
241             GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
#9  0x000070552c07b03c in llm_graph_input_mem_hybrid::set_input (this=0x5f11bdf6aee0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:437
437         inp_rs->set_input(ubatch);
#10 0x000070552c07b549 in llm_graph_result::set_inputs (this=0x5f11be01ddf0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:480
480             input->set_input(ubatch);
#11 0x000070552c01ddb3 in llama_context::process_ubatch (this=0x5f11c05b5b50, ubatch=..., gtype=LLM_GRAPH_TYPE_DECODER, mctx=0x5f11be00ff00, ret=@0x7fff74d22ea4: 538976288) at /devel/tools/llama.cpp/src/llama-context.cpp:779
779             res->set_inputs(&ubatch);
#12 0x000070552c01f367 in llama_context::decode (this=0x5f11c05b5b50, batch_inp=...) at /devel/tools/llama.cpp/src/llama-context.cpp:1088
1088            const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
#13 0x000070552c025e49 in llama_decode (ctx=0x5f11c05b5b50, batch=...) at /devel/tools/llama.cpp/src/llama-context.cpp:2726
2726        const int ret = ctx->decode(batch);
#14 0x00005f11a2021559 in common_init_from_params (params=...) at /devel/tools/llama.cpp/common/common.cpp:1066
1066                llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
#15 0x00005f11a1e4a3c0 in main (argc=7, argv=0x7fff74d25968) at /devel/tools/llama.cpp/tools/main/main.cpp:140
140         common_init_result llama_init = common_init_from_params(params);

I'll try to merge the op into the ggml_delta_net function call as @ngxson suggested.

@CISC
Collaborator

CISC commented Sep 19, 2025

  1. A full repro with the error it's raising would definitely help debug

Running llama-cli -m reference/qwen3_next_500m/Qwen3_Next_500M-8x417M-BF16.gguf -ngl 999 -p "Who are " yields this weird memory error:

...
#6  0x000070552bb6091e in ggml_backend_buffer_get_type (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:189
189         GGML_ASSERT(buffer);
#7  0x000070552bb6080e in ggml_backend_buffer_is_host (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:170
170         return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
...

The backend buffer is NULL.

@ngxson
Collaborator

ngxson commented Sep 19, 2025

#9  0x000070552c07b03c in llm_graph_input_mem_hybrid::set_input (this=0x5f11bdf6aee0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:437
437         inp_rs->set_input(ubatch);

The model doesn't seem to have any recurrent layers. This makes the set input fail because the input node is not present in the cgraph.

I'll try to merge the op into the ggml_delta_net function call as @ngxson suggested.

Hmm I think I said the reverse: not to merge it but make the op simple

I feel like this op can be implemented using other ggml ops like mul, mul_mat, sum. Which part of the calculation do you think that can't be constructed using existing ops?

This is the more important question: should we try to implement it using existing ops, or add a new op and spend even more time optimizing it across all backends?

@pwilkin
Collaborator Author

pwilkin commented Sep 19, 2025

Now this is an error I hadn't expected to encounter:

GGML_ABORT("not enough space in the context's memory pool");

@pwilkin
Collaborator Author

pwilkin commented Sep 19, 2025

The model doesn't seem to have any recurrence layers. This makes the set input fails due to input node not being present in cgraph.

How do I allocate the memory for the linear layers then? I seem to have misunderstood how build_inp_mem_hybrid() works...

@yarikdevcom

@pwilkin any chance to buy you a coffee? (Patreon etc.) So the community is able to donate for your efforts. Thank you!

@pwilkin
Collaborator Author

pwilkin commented Sep 19, 2025

@pwilkin any chance to buy you a coffee? (Patreon etc.) So the community is able to donate for your efforts. Thank you!

Added a buymeacoffee link to my profile (do consider first funding the Llama.cpp project itself, though!)

@ServeurpersoCom
Collaborator

ServeurpersoCom commented Sep 19, 2025

@pwilkin any chance to buy you a coffee? (Patreon etc.) So the community is able to donate for your efforts. Thank you!

Added a buymeacoffee link to my profile (do consider first funding the Llama.cpp project itself, though!)

I sent a coffee also.

@ngxson
Collaborator

ngxson commented Sep 20, 2025

GGML_ABORT("not enough space in the context's memory pool");

Probably there are too many nodes in the cgraph; try increasing the limit via llama_context::graph_max_nodes()
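A hypothetical illustration of that knob (the actual formula in src/llama-context.cpp may differ; the only assumption here is that the budget is derived from the model's tensor count):

    // hypothetical tweak, not the real implementation: raise the node budget so a
    // graph with thousands of delta-net nodes still fits
    uint32_t llama_context::graph_max_nodes() const {
        return std::max<uint32_t>(65536u, 8u*model.n_tensors());
    }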

Comment on lines 19054 to 19056
Qcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Qcur), n_embd_head, hparams.n_head(il), n_tokens);
Kcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Kcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
Vcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Vcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
Collaborator

These ggml_cont calls can be removed if Q/gate are separated. ggml_cont is not recommended when dealing with big tensors.

Collaborator

@CISC CISC Sep 20, 2025


Actually none of these need ggml_cont, Q is 3D already, Q/K are RoPEd so can be views and V can also be a 3D view now.

Edit: sorry, not quite true about V, only if QKV is fused, the weird gate fuse threw me off. Nevertheless, K/V are already contiguous at this point.
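As a sketch of that simplification (assuming K/V really are contiguous at this point, as noted above, so the ggml_cont is a redundant copy):

    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);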

Collaborator

the problem is that Q is non-contiguous and ggml_rope(_ext) does not work very well with non-cont tensors, it's still buggy on certain backends

Collaborator

@CISC CISC Sep 20, 2025


the problem is that Q is non-contiguous and ggml_rope(_ext) does not work very well with non-cont tensors, it's still buggy on certain backends

Are you sure? AFAIK those issues are fixed.

Edit: Also, if there still are issues they will never get fixed if we work around them. :)

Member

the problem is that Q is non-contiguous and ggml_rope(_ext) does not work very well with non-cont tensors, it's still buggy on certain backends

I think all of these cases are fixed now.

Collaborator

This was an impl of 2D rope that relies on ggml_view: https://github.com/ngxson/ggml-easy/blob/f56e5e499b1f21a4aae73010e9d9582840428457/demo/2d-rope.cpp

It works on CPU and Metal, but doesn't work on CUDA/Vulkan. Couldn't test on other backends, but feel free to make a PR to address this issue.


Collaborator

Yes, that seems to work. Sorry @pwilkin, you will need to manually revert the change where I split Q/gate. The tensor shape for Q will be:

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);

layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_projection_size }, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { n_ff, n_embd }, 0);
Collaborator

Shape of LLM_TENSOR_ATTN_Q and LLM_TENSOR_SSM_OUT should not contain n_ff

@ngxson
Collaborator

ngxson commented Sep 20, 2025

^ proposed fix for the 3 comments above: 46110e0

return cur;
};

ggml_tensor * llm_build_qwen3next::softplus(ggml_tensor * alpha, ggml_tensor * dt_bias) {


You probably get numerical instability here without a threshold, like log(1+exp(1000)) = Inf.

Collaborator Author

No idea which "here" you're referring to, but RMS norm is clamped by rms_eps (10e-6 in case of Qwen3Next)


I mean the softplus implementation here.


@theo77186 theo77186 Oct 14, 2025


It's for large activations, as exp(89) is already bigger than any f32. But if that happens, there would be NaNs everywhere. Another potential issue is when calculating log(1+x) for small x, but unfortunately GGML doesn't have a log1p primitive (nor expm1, by the way).
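If this ever needs hardening, a numerically stable variant can still be composed from existing ggml ops via the identity softplus(x) = relu(x) + log(1 + exp(-|x|)), which never overflows since exp(-|x|) <= 1. A hedged sketch (illustrative only, not the PR's implementation):

    // hedged sketch, not the PR's code: stable softplus of (alpha + dt_bias)
    ggml_tensor * x   = ggml_add(ctx0, alpha, dt_bias);                        // a + dt_bias
    ggml_tensor * pos = ggml_relu(ctx0, x);                                    // max(x, 0)
    ggml_tensor * e   = ggml_exp(ctx0, ggml_neg(ctx0, ggml_abs(ctx0, x)));     // exp(-|x|), always <= 1
    ggml_tensor * sp  = ggml_add(ctx0, pos,
                           ggml_log(ctx0, ggml_scale_bias(ctx0, e, 1.0f, 1.0f))); // relu(x) + log(1 + exp(-|x|))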

@theo77186

Note for those testing this: the latest commit introduced changes to GGUF conversion, which means GGUF files need to be reconverted.

@ServeurpersoCom
Collaborator

ServeurpersoCom commented Oct 14, 2025

Note for those testing this: the latest commit introduced changes to GGUF conversion, which means GGUF files need to be reconverted.

2.6 tok/s -> 6.x tok/s !!! (-dev none, CPU)

Finally got it to give only one answer. I 'fixed' the crash. Introduced emotional damage. Model keeps saying: ‘I have too many errors.'

@ServeurpersoCom
Collaborator

It seems the ssm_conv1d.bias tensor might be missing or not loaded correctly: the GGUF includes it, but the model graph doesn't seem to use it yet?

@pwilkin
Collaborator Author

pwilkin commented Oct 14, 2025

Eh, we're not going to avoid a GGML_OP_DELTA_RECURRENT after all. A graph with 7000 nodes is kind of cute, but not very practical (and the errors add up really quickly).

@JohannesGaessler I found a first potential culprit: the RMS norm CUDA implementation diverges from the CPU implementation (and the PyTorch implementation):

Reference (PyTorch):
ggml_debug: model.layers.0.input_layernorm.forward.out = (f32)  ... = {torch.Size([1, 4, 256])}
                                     [
                                      [
                                       [      0.9355,       0.3888,       0.9045, ...,       1.5365,       1.6356,       1.0099]
                                       [      1.3267,      -2.2407,      -2.9482, ...,      -1.5095,      -1.8161,       1.9576]
                                       [     -1.1757,       0.7235,       1.5109, ...,      -1.3619,       1.3832,      -2.3728]
                                       [     -0.7574,      -0.7800,      -4.4085, ...,      -1.0626,      -0.7234,      -1.1360]
                                      ],
                                     ]
                                     sum = -55.342392

ggml_debug: model.layers.0.linear_attn.in_proj_qkvz.forward.out = (f32)  ... = {torch.Size([1, 4, 768])}
                                     [
                                      [
                                       [     -0.5138,       0.1238,      -0.4653, ...,       0.1122,      -0.1639,       0.4393]
                                       [      0.7761,       0.2072,       0.3530, ...,       0.7795,      -0.0871,      -1.0032]
                                       [      0.3069,       0.0373,       1.8069, ...,      -0.2349,       0.4950,       0.1955]
                                       [      0.6911,      -0.4617,       0.2483, ...,       0.4156,       0.7836,       0.1736]
                                      ],
                                     ]
                                     sum = -34.217533
=========
GGML (CPU):
ggml_debug:              attn_norm-0 = (f32)        MUL(norm-0{256, 4, 1, 1}, blk.0.attn_norm.weight{256, 1, 1, 1}}) = {256, 4, 1, 1}
                                     [
                                      [
                                       [      0.9355,       0.3888,       0.9045,       2.3667,      -0.7837,       1.1586,      -5.2043,      -2.4782, ...,      -2.2428,       3.1349,      -3.5934,      -1.3568,      -1.3444,       1.5365,       1.6356,       1.0099],
                                       [      1.3267,      -2.2407,      -2.9482,       2.2760,       0.8019,      -0.9375,      -1.5449,       1.5095, ...,       1.5095,       2.5826,       1.9812,       0.7783,       2.7595,      -1.5095,      -1.8161,       1.9576],
                                       [     -1.1757,       0.7235,       1.5109,       2.7239,      -1.7982,      -1.8620,      -3.3623,       2.1174, ...,      -3.0644,      -1.3619,      -1.3619,      -1.7237,       1.3619,      -1.3619,       1.3832,      -2.3728],
                                       [     -0.7574,      -0.7800,      -4.4085,      -4.0467,       0.7517,       1.6390,      -5.6067,       0.8252, ...,       1.4186,      -0.7404,      -1.2152,       0.8647,       0.8421,      -1.0626,      -0.7234,      -1.1360],
                                      ],
                                     ]
                                     sum = -55.342415

ggml_debug: linear_attn_mixed_qkvz-0 = (f32)    MUL_MAT(blk.0.ssm_in.weight{256, 768, 1, 1}, attn_norm-0{256, 4, 1, 1}}) = {768, 4, 1, 1}
                                     [
                                      [
                                       [     -0.5138,       0.1238,      -0.4653,       1.2103,       0.3890,      -0.2325,      -0.0746,      -1.0886, ...,      -0.3519,      -0.2448,      -0.9264,       0.3835,       0.8214,       0.1122,      -0.1639,       0.4393],
                                       [      0.7761,       0.2072,       0.3530,       0.6750,      -0.7315,       0.2660,      -0.0720,      -0.2165, ...,       0.0448,      -0.9420,      -1.0138,      -0.5751,      -0.3338,       0.7795,      -0.0871,      -1.0032],
                                       [      0.3069,       0.0373,       1.8069,       0.8821,       0.4421,      -0.1695,      -0.2114,      -0.3937, ...,       0.8654,       0.4693,       0.0252,       0.0804,       0.5899,      -0.2349,       0.4950,       0.1955],
                                       [      0.6911,      -0.4617,       0.2483,      -0.1024,      -1.4867,       0.4988,       0.8328,       0.7382, ...,       0.0097,       0.1196,       0.7835,       0.1981,      -0.6273,       0.4156,       0.7836,       0.1736],
                                      ],
                                     ]
                                     sum = -34.217514
=========
GGML (CUDA):
ggml_debug:              attn_norm-0 = (f32)        MUL(norm-0{256, 4, 1, 1}, blk.0.attn_norm.weight{256, 1, 1, 1}}) = {256, 4, 1, 1}
                                     [
                                      [
                                       [      0.9355,       0.3888,       0.9045,       2.3667,      -0.7837,       1.1586,      -5.2043,      -2.4782, ...,      -2.2428,       3.1349,      -3.5934,      -1.3568,      -1.3444,       1.5365,       1.6356,       1.0099],
                                       [      1.3267,      -2.2407,      -2.9482,       2.2760,       0.8019,      -0.9375,      -1.5449,       1.5095, ...,       1.5095,       2.5826,       1.9812,       0.7783,       2.7595,      -1.5095,      -1.8161,       1.9576],
                                       [     -1.1757,       0.7235,       1.5109,       2.7239,      -1.7982,      -1.8620,      -3.3623,       2.1174, ...,      -3.0644,      -1.3619,      -1.3619,      -1.7237,       1.3619,      -1.3619,       1.3832,      -2.3728],
                                       [     -0.7574,      -0.7800,      -4.4085,      -4.0467,       0.7517,       1.6390,      -5.6067,       0.8252, ...,       1.4186,      -0.7404,      -1.2152,       0.8647,       0.8421,      -1.0626,      -0.7234,      -1.1360],
                                      ],
                                     ]
                                     sum = -55.342426
                                     
ggml_debug: linear_attn_mixed_qkvz-0 = (f32)    MUL_MAT(blk.0.ssm_in.weight{256, 768, 1, 1}, attn_norm-0{256, 4, 1, 1}}) = {768, 4, 1, 1}
                                     [
                                      [
                                       [     -0.5136,       0.1238,      -0.4652,       1.2096,       0.3888,      -0.2325,      -0.0746,      -1.0884, ...,      -0.3518,      -0.2448,      -0.9262,       0.3834,       0.8213,       0.1121,      -0.1639,       0.4393],
                                       [      0.7757,       0.2072,       0.3529,       0.6744,      -0.7311,       0.2658,      -0.0719,      -0.2166, ...,       0.0448,      -0.9416,      -1.0135,      -0.5748,      -0.3337,       0.7793,      -0.0870,      -1.0025],
                                       [      0.3069,       0.0372,       1.8060,       0.8818,       0.4423,      -0.1692,      -0.2114,      -0.3936, ...,       0.8652,       0.4693,       0.0250,       0.0804,       0.5894,      -0.2346,       0.4949,       0.1958],
                                       [      0.6908,      -0.4617,       0.2482,      -0.1024,      -1.4864,       0.4985,       0.8326,       0.7379, ...,       0.0094,       0.1195,       0.7830,       0.1980,      -0.6268,       0.4156,       0.7832,       0.1738],
                                      ],
                                     ]
                                     sum = -34.227242

@pwilkin
Collaborator Author

pwilkin commented Oct 14, 2025

It seems the ssm_conv1d.bias tensor might be missing or not loaded correctly the GGUF includes it, but the model graph doesn’t seem to use it yet ?

Wdym?

ggml_tensor * llm_build_qwen3next::softplus(ggml_tensor * alpha, ggml_tensor * dt_bias) {
    ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, dt_bias);               // a + dt_bias
    ggml_tensor * alpha_exp      = ggml_exp(ctx0, alpha_biased);                  // exp(a + dt_bias)
    ggml_tensor * one_plus_exp   = ggml_scale_bias(ctx0, alpha_exp, 1.0f, 1.0f);  // 1 + exp(a + dt_bias)
    ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp);                  // log(1 + exp(...))
    return alpha_softplus;
}

@pwilkin
Collaborator Author

pwilkin commented Oct 14, 2025

You might be confused because of this little detail:

                            layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);

i.e. there's no dt weight tensor in this model.

@JohannesGaessler
Collaborator

@pwilkin the biggest difference between the CPU and CUDA backends is that the CPU backend uses double precision for the summation of the squared values while CUDA uses single precision. You can test whether the difference is actually meaningful with this patch:

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index da312992c..1bb2afb65 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3550,7 +3550,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
             break;
         case GGML_OP_NORM:
+            return true;
         case GGML_OP_RMS_NORM:
+            return false;
         case GGML_OP_L2_NORM:
             return true;
         case GGML_OP_RMS_NORM_BACK:

It will force the RMS norm to run on CPU. In terms of performance, it would be vastly preferable if double precision arithmetic is not needed, because NVIDIA consumer GPUs have gimped FP64 performance.
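For context, the CPU path sums the squared values at higher precision before normalizing, roughly like this (a simplified per-row sketch, not the actual ggml source):

    #include <math.h>
    #include <stdint.h>

    // simplified sketch: the squares are accumulated in double, everything else stays float
    static void rms_norm_row_f32(const float * x, float * y, int64_t n, float eps) {
        double sum = 0.0;
        for (int64_t i = 0; i < n; ++i) {
            sum += (double) x[i] * (double) x[i];
        }
        const float scale = 1.0f / sqrtf((float)(sum / n) + eps);
        for (int64_t i = 0; i < n; ++i) {
            y[i] = x[i] * scale;
        }
    }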

@CoruNethron

I love how @pwilkin comes up with the idea of printing something 5-dimensional:
[screenshot]

Sorry for spam.

@pwilkin
Collaborator Author

pwilkin commented Oct 14, 2025

It's not me, it's Qwen, they are using 5D tensors in their chunked delta_net function :)

@gopinath87607

Hi guys, is this done? Any GGUF available? BTW, I just tested this GGUF but it's not doing great.

https://huggingface.co/AesSedai/Qwen3-Next-80B-A3B-Instruct-GGUF

@k3d3

k3d3 commented Oct 16, 2025

@gopinath87607 It will be done when it is done. Be patient, let pwilkin do their incredible magic, and please don't spam this issue as a ton of people are watching it. That goes for anyone else who asks for status updates too - if you want those, subscribe like the rest of us. :)

Ideally, don't even respond to me - just leave a react or something.

Thank you. (and apologies for myself adding to the noise)

@github-actions github-actions bot added the testing (Everything test related) and Nvidia GPU (Issues specific to Nvidia GPUs) labels Oct 16, 2025
@gabe-l-hart
Collaborator

@pwilkin FYI, probably not that useful right now, but I've got Metal and CUDA implementations of your CUMSUM and TRI kernels up in #16623. From what I can tell, I don't think these have been tackled on those backends yet anywhere.



Successfully merging this pull request may close these issues.

Feature Request: Qwen3-Next support