Add GLM-4-0414 Model Support #344
Conversation
Does it also happen when you use [...]? |
|
Hrrm, unfortunately no, using [...]. Without [...] I also tried city96's patch to force [...]. Could be that I made a mistake in the [...].

EDIT: mainline was compiled for CPU only so this is to be expected: [...]

Last observations are that mainline seems to work fine with or without [...]. Not sure what to try next other than dig in deeper to how [...] |
If you made a mistake with building the graph, this invocation wouldn't be working at all. If it works for all layers offloaded to the GPU except attention tensors and KV cache, it means there is a precision issue in the attention calculation on CUDA (on the CPU everything is computed with F32). |
|
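To make the precision point concrete: below is a tiny standalone C++ sketch, not llama.cpp code, with made-up numbers. It illustrates one plausible version of the failure: a K*Q value that exceeds the fp16 range overflows to inf and then poisons the softmax with NaNs, which is the kind of breakage that surfaces as garbage tokens.

```cpp
// Illustrative only: shows how fp16 overflow in K*Q can turn the softmax into NaNs.
// The values below are made up; real K*Q magnitudes depend on the model and scaling.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // pretend these are raw K*Q dot products for one attention row
    std::vector<float> kq = { 70000.0f, 70050.0f, 12.0f };

    // emulate storing the result in fp16: values beyond ~65504 overflow to +inf
    const float FP16_MAX = 65504.0f;
    for (float & v : kq) {
        if (std::fabs(v) > FP16_MAX) v = INFINITY;
    }

    // standard max-subtracted softmax, as an attention kernel would do
    float mx = -INFINITY;
    for (float v : kq) mx = std::max(mx, v);

    std::vector<float> p(kq.size());
    float sum = 0.0f;
    for (size_t i = 0; i < kq.size(); ++i) {
        p[i] = std::exp(kq[i] - mx);   // inf - inf = NaN for the overflowed entries
        sum += p[i];
    }
    for (size_t i = 0; i < kq.size(); ++i) {
        std::printf("p[%zu] = %f\n", i, p[i] / sum);   // prints nan, nan, nan
    }
    return 0;
}
```

Accumulating K*Q in F32 (which is what `ggml_mul_mat_set_prec(kq, GGML_PREC_F32)` requests) avoids exactly this kind of range problem.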
I just noticed one more odd thing trying [...]. Running mainline with [...] |
|
Try this: in the function where the attention graph for this model is built, add `ggml_mul_mat_set_prec(kq, GGML_PREC_F32)` right after the K*Q multiplication. This will set the precision of the K*Q matmul to F32. |
|
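For reference, a minimal standalone sketch of what that suggestion amounts to at the ggml API level. It assumes `ggml.h` from this repository is on the include path, the tensor shapes are invented, and it only constructs the graph node rather than running real attention:

```cpp
// Minimal sketch of forcing F32 precision on a K*Q matmul node via the public ggml API.
// Shapes are illustrative; this only builds the graph node, it does not run inference.
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64u * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // K is [head_dim, n_kv], Q is [head_dim, n_tokens] (ggml's ne[0] is the contiguous dim)
    struct ggml_tensor * k = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 128, 512);
    struct ggml_tensor * q = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 32);

    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);   // result is [n_kv, n_tokens]
    ggml_mul_mat_set_prec(kq, GGML_PREC_F32);             // the suggested call: accumulate K*Q in F32

    std::printf("kq: %lld x %lld\n", (long long) kq->ne[0], (long long) kq->ne[1]);
    ggml_free(ctx);
    return 0;
}
```

The actual change ends up being the per-architecture guard shown in the review snippets further down; this sketch only isolates the `ggml_mul_mat_set_prec` call itself.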
I see in mainline:

```cpp
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

// note: this op tends to require high floating point range
// while for some models F16 is enough, for others it is not, so we default to F32 here
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
```

This is why mainline may be working for this model. I still refuse to set that generically for all models, as it hurts performance for long contexts quite a bit. The downside is that one needs to explicitly enable F32 precision for the architectures that need it. |
Yes, this fixed the issue, I can fully offload now! I'll push this up. Remaining questions: [...]
|
|
You have only enabled [...]. You don't need the latest PR in mainline that sets [...]. |
|
Okay, so now without [...] I'll look for a reference, I thought I've seen others mentioning this kind of output before. Here is a reference where they suggest using a different batch size, e.g. [...]. Another reference seems to suggest a recent python conversion update. So maybe I'll double check my existing GGUF or try to convert my own GGUF using the most recent patch that updates some special tokens and sets [...]. Seems like bartowski used a version of mainline to convert that did include this PR, hrmm... |
src/llama.cpp (Outdated)
```diff
  auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12);
  auto kq_i = ggml_mul_mat(ctx, k_i, q_i);
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_GLM4) {
```
Add

```cpp
if (model.arch == LLM_ARCH_GLM4) {
    ggml_mul_mat_set_prec(kqv_i, GGML_PREC_F32);
}
```

after line 9515.
```cpp
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_GLM4) {
    // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
    // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
    ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
```
Add

```cpp
if (model.arch == LLM_ARCH_GLM4) {
    ggml_mul_mat_set_prec(kqv, GGML_PREC_F32);
}
```

after line 9475.
|
I don't think any of the suggestions you are finding around the Internet are going to help. Just think about it: [...]

The only logical conclusion from these 3 observations is that you also need to set the precision of the kqv multiplication to F32. |
|
Thanks, I appreciate you helping me learn on this. Just to be clear, I'm getting the gibberish output without `-fa`. I tried setting the precision to fp32 as you describe, but still get the same gibberish.

The patch you suggested above: I went ahead and tried this and it seems to be taking the `kqv` path and not the `kqv_i` path, but it still gives the same gibberish.

I'll dig into the differences between mainline's non-flash-attention path and this fork's non-flash-attention path to see if anything else sticks out to me. |
Sorry, I missed the fact that it is not working on the CPU without FA. If I had paid better attention, I would have diagnosed the problem much earlier. Simply remove this line (line 15686 in the version I just cloned from your repository):

```cpp
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
```

In mainline they have reorganized how attention is built. Reshaping `Vcur` to 3D works with their reorganized graph, but here it breaks the non-FA attention path. |
|
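For context on why a single reshape matters: `ggml_reshape_3d` does not move any data, it only changes the logical shape that every subsequent op sees. Below is a small standalone sketch (again assuming `ggml.h` is available; the sizes are invented) of the 2D-vs-3D `Vcur` difference, which is presumably why the extra reshape broke the non-FA path here:

```cpp
// Shows that ggml_reshape_3d is a view: same data, different logical shape.
// Sizes are illustrative (n_embd_head = 128, n_head_kv = 2, n_tokens = 32).
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16u * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 2D Vcur: [n_embd_head * n_head_kv, n_tokens]
    struct ggml_tensor * vcur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128 * 2, 32);

    // the removed line produced this 3D view instead
    struct ggml_tensor * vcur3d = ggml_reshape_3d(ctx, vcur, 128, 2, 32);

    std::printf("2D: %lld x %lld\n", (long long) vcur->ne[0], (long long) vcur->ne[1]);
    std::printf("3D: %lld x %lld x %lld (same %zu bytes of data)\n",
                (long long) vcur3d->ne[0], (long long) vcur3d->ne[1], (long long) vcur3d->ne[2],
                ggml_nbytes(vcur3d));
    ggml_free(ctx);
    return 0;
}
```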
Here is a quick CPU-only comparison.

Mainline: [...]

ik_llama.cpp: [...] (but I needed the changes in PR #349 to make FA work on the CPU).
|
|
Sweeet, that fixes up the non-flash-attention case! This model is quite efficient, I just ran it with 128k context and it only used a small amount of memory for the KV cache. For now I'll fix up this PR, put it after your recent cohere2 additions, and set it to review ready afterwards. Thanks again, really appreciate your time looking at this! Cheers! |
Yes, it has a very high GQA factor of 24, so the KV entries per token are very small. This makes the attention portion very efficient, so the decline of TG speed as the context grows is very slow (less than 10% when going from 0 to 8k tokens, as per the table above). So, it is a model worth having. Please make it ready and let's merge it. |
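To put a rough number on that, here is a back-of-the-envelope sketch. The GQA factor of 24 is taken from the comment above; the head count, head size, layer count, and f16 cache type are illustrative assumptions, not values read from the GGUF:

```cpp
// Back-of-the-envelope KV-cache sizing showing why a GQA factor of 24 keeps the cache small.
// All model parameters below are assumed/illustrative, not read from the actual GGUF.
#include <cstdio>

int main() {
    const int n_head     = 48;           // assumed number of query heads
    const int gqa_factor = 24;           // from the discussion above
    const int n_head_kv  = n_head / gqa_factor;  // -> 2 KV heads
    const int head_dim   = 128;          // assumed
    const int n_layer    = 61;           // assumed
    const int n_ctx      = 128 * 1024;   // the 128k context mentioned above
    const double bytes_per_elem = 2.0;   // f16 K and V cache

    const double per_token = 2.0 /* K and V */ * n_head_kv * head_dim * n_layer * bytes_per_elem;
    const double total_gib = per_token * n_ctx / (1024.0 * 1024.0 * 1024.0);

    std::printf("KV cache: %.0f bytes per token, ~%.2f GiB at %d tokens\n", per_token, total_gib, n_ctx);
    std::printf("without GQA (one KV head per query head) it would be ~%.2f GiB\n", total_gib * gqa_factor);
    return 0;
}
```

Quantizing the cache (e.g. `-ctk q8_0 -ctv q8_0` as mentioned in the PR description) roughly halves the f16 figure again.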
Based on zRzRzRzRzRzRzR's PR on mainline llama.cpp. Still some issues where it doesn't work:
* offloading >= 60 layers to GPU
* no flash attention
Both of these seem unused, and LLM_TENSOR_ATTN_POST_NORM already existed, which seems pretty similar? Don't think they were used in the python code either... So removed these as possibly just cruft:
* LLM_TENSOR_POST_ATTN_NORM
* LLM_TENSOR_POST_MLP_NORM
This fixes the non-flash-attention inferencing on both CPU and CUDA.
|
Okay got it rebased, gonna force push it up after quick final test!!! |
Force-pushed from 10ec675 to 6ef4fba.
|
Yaay!! Feels good to finally get that model working haha... Thanks again for your patience and guidance! Have a g'night! |
I followed your lead and ran ik's CPU-only test plus my own CPU-only test (logs collapsed).
|
This caught my eye, so I looked into it and found they have prior work dedicated to long-context training of LLMs: they say "(Cf LongAlign: A Recipe for Long Context Alignment of Large Language Models for technical details)" in the GQA part of their technical report. |
|
I found a report where someone uses NoLiMa to test the long-context performance, and they did notice lower performance (which I believe is because of the very high GQA factor). |
This is my second attempt, which still has some issues. The original attempt was #333. This one is based on ggml-org/llama.cpp#12867. However, this PR does not bring over any of the python stuff.
In limited testing of bartowski/THUDM_GLM-Z1-32B-0414-GGUF on CPU-only and CUDA backends it seems to work as long as:

- `-fa` is used

Example Command

This is one way to run it on CUDA that seems to work: [...]

If I increase `--n-gpu-layers` to 60 or higher, it outputs `GGGGGGGGGGGGGGG`. It also seems okay to add `-amb 512 -ctk q8_0 -ctv q8_0`...

fwiw there seems to be some issues still on the mainline implementation, possibly related: [...]
So I'll mark this as draft for now and see how things are looking soon.