[model] add support for qwen3vl series #16780
Conversation
Co-authored-by: Thireus ☠ <[email protected]>
Co-authored-by: yairpatch <[email protected]>
Co-authored-by: LETS-BEE <[email protected]>
| Thank you @JJJYmmm! Test builds: https://github.com/Thireus/llama.cpp/releases - tagged  | 
| For some reason, this version's OCR capability is not as good as the previous LETS-BEE version; it noticeably misses characters and exhibits infinite repetition. |
Integrates Qwen3-VL and Qwen3VL-MoE architecture support from upstream. Implements IMROPE (interleaved MRoPE) for vision models. Adds deepstack layer support for visual feature processing. Changes include:
- New architecture types: LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE
- IMROPE rope type for vision position encoding
- Deepstack visual feature handling in clip.cpp
- GGML CUDA kernels for IMROPE
- Tensor mappings for the Qwen3VL architecture
Upstream PR: ggml-org/llama.cpp#16780
Contributors: @JJJYmmm @yairpatch @Thireus @LETS-BEE
| The question is: are the fixes in #16745 included in this PR? If not, the full performance of the model will only be reached once #16745 is merged. | 
| I'm still getting an unknown model architecture error here? | 
| They are not, as @FMayran and @rujialiu are still figuring out the best way to implement a fix properly, once and for all :). You can cherry-pick the changes from #16745 without any problems, though, and then just build it yourself as a temporary solution; just make sure to check the issues raised in the last 24-48 hours about why it's not a real 100% fix. | 
| I have managed to get Qwen3-VL-30B-A3B-Instruct running on Ubuntu just now (specifically with a Ryzen AI Max+ 395 and Vulkan). Did you compile your own GGUF/mmproj.gguf using ...? How I prepared mine is below.
 
 
 
 No GGUFs I found off the shelf were working right until I did this. Hope this helps. | 
| Thank you. I was using the GGUFs from NexaAI. May I add, though, that I think the architecture is different for each model (30B/8B/4B). I will try this, though. Thanks again. | 
    deepstack_features = feat;
} else {
    // concat along the feature dimension
    deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0);
Not very important to optimize this right now, but doing ggml_concat on multiple layers can increase memory usage. One trick is to allocate one big tensor, then use ggml_set_rows to copy the intermediate results into the allocated tensor.
cc @ggerganov, do you think this would be a good idea for concatenating multiple tensors?
Oh, I just follow the style of llava
Lines 1278 to 1285 in 1c1409e
    // If feature layers are explicitly set, stack them (if we have multiple)
    if (!embedding_stack.empty()) {
        embeddings = embedding_stack[0];
        for (size_t i = 1; i < embedding_stack.size(); i++) {
            embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
        }
    }
}
Yes, but llava has a fixed number of tokens (no dynamic resolution), so the memory usage is predictable.
Got it, I’ll optimize it later. 🫡
done!
Add a TODO comment with a reference to this thread so we don't forget to improve this later.
Suggested change:
-    deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0);
+    // TODO: pre-allocate memory and use ggml_set_rows, see: https://github.com/ggml-org/llama.cpp/pull/16780/files#r2465886647
+    deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0);
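For illustration, here is a minimal sketch of what that TODO points at: pre-allocate one destination tensor sized for all deepstack layers and copy each layer's features into a view of it, instead of growing a chain of ggml_concat temporaries. This is not the clip.cpp implementation; the helper name, the shapes, and the use of ggml_cpy into views (rather than the suggested ggml_set_rows) are illustrative assumptions.

```cpp
// Illustrative sketch only -- not the clip.cpp code. It shows the pre-allocation
// idea from the thread above using ggml_cpy into views of a single destination
// tensor; names and shapes (feats, n_deepstack) are made up.
#include "ggml.h"

static struct ggml_tensor * concat_deepstack_preallocated(
        struct ggml_context * ctx,
        struct ggml_cgraph  * gf,
        struct ggml_tensor ** feats,       // n_deepstack tensors, each [n_embd, n_tokens]
        int                   n_deepstack,
        int64_t               n_embd,
        int64_t               n_tokens) {
    // one destination buffer instead of a chain of ggml_concat temporaries
    struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd * n_deepstack, n_tokens);

    for (int i = 0; i < n_deepstack; i++) {
        // view over the i-th slab of every row of dst
        struct ggml_tensor * slot = ggml_view_2d(ctx, dst,
                n_embd, n_tokens,
                dst->nb[1],                                        // row stride of dst
                (size_t) i * n_embd * ggml_element_size(dst));     // byte offset within a row
        // copy this layer's features into its pre-allocated slot
        ggml_build_forward_expand(gf, ggml_cpy(ctx, feats[i], slot));
    }
    return dst;
}
```

Either way, the destination is sized once from the known number of deepstack layers, so intermediate memory no longer grows with each concatenation.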
| @PaymonHossaini,  | 
| While it's true that the 30B is MoE and the 8B is dense, I was unable to recreate this issue. Make sure your local checkout tracks the PR branch, as there were some changes to that script to make it compatible with these models. My instructions for using the 8B model are below.



 I don't believe this issue is a result of the code changes. | 
| I think merging PR #16745 will likely restore the model's original performance. | 
| @JJJYmmm Adding Vulkan too? :) | 
| @CISC I’ve updated the corresponding file, but haven’t tested it yet since I don’t have a vulkan env at the moment. | 
| GLSL cannot automatically convert integers to bool, so you need the full condition, for example  | 
self.is_deepstack_layers = [False] * int(self.hparams_vision["num_hidden_layers"] or 0)
for idx in self.hparams_vision.get("deepstack_visual_indexes", []):
    self.is_deepstack_layers[idx] = True
(No action is needed, just a side note here.)
The is_deepstack_layers metadata is no longer being used in clip.cpp, as I want to make the code simpler to maintain. We now use the same logic as MoE in llama.cpp: if the tensor is not present, it will be nullptr, and this triggers the code branch for deepstack layers.
But we will still keep this metadata in GGUF for future use.
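To make that convention concrete, here is a tiny hedged sketch of the "missing tensor selects the branch" pattern; the struct and field names are hypothetical, not the actual clip.cpp members.

```cpp
// Hypothetical illustration of the pattern described above; deepstack_fc1_w
// is a made-up field name, not the actual clip.cpp member.
#include "ggml.h"

struct example_vision_layer {
    // optional tensor: loaded as not-required, so it stays nullptr
    // when the checkpoint has no deepstack weights for this layer
    struct ggml_tensor * deepstack_fc1_w = nullptr;
};

static bool is_deepstack_layer(const example_vision_layer & layer) {
    // same idea as the MoE checks in llama.cpp: the presence of the tensor
    // itself selects the code branch, with no extra metadata lookup
    return layer.deepstack_fc1_w != nullptr;
}
```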
Co-authored-by: Sigbjørn Skjæret <[email protected]>
| I see an error in the mmproj creation:  | 
| As a reminder, you can also add support for other backends in follow-up PRs, to avoid pulling too many reviewers into one PR (preferably one PR per backend). | 
| I'm merging this in the next 30 min - 1 hr, as the CI for  | 
| Thank you all for the detailed review! 🙏 | 
| I've never watched a llama.cpp PR thread before and never realized how well-organized and dedicated you all are. Just wanted to chime in to say: you all rock, and your effort is appreciated. | 
| Ditto. I've been watching this PR like a hawk. Great contributors and great maintainers all around. | 
| I believe requirements.txt needs to be updated; the current transformers version does not have support for the qwen3-vl architecture. Not a problem for inference, but for quantizing it will not recognize the arch. | 
This PR adds support for the Qwen3-VL series, including both the dense and MoE variants.
The original implementation was contributed by @yairpatch and @Thireus (see #16207). @LETS-BEE also helped address issues such as weight loading.
In this PR, I’ve fixed several algorithmic implementation details (e.g., deepstack), added support for MRoPE-Interleave, and performed final code cleanup.
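For readers unfamiliar with the term, here is a toy sketch of the difference between block-wise and interleaved section assignment in multimodal RoPE. It is only a conceptual illustration with made-up section sizes, not the actual ggml kernel logic for MRoPE-Interleave: multimodal RoPE drives different rotary dimensions from the temporal, height, and width position components, and the interleaved variant cycles through those components across dimensions instead of assigning contiguous blocks.

```cpp
// Toy illustration of block-wise vs interleaved section assignment for
// multimodal RoPE. NOT the ggml/llama.cpp kernel; it only shows which position
// component (0 = temporal, 1 = height, 2 = width) drives a given rotary pair
// under each layout. Section sizes are made up.
#include <cstdio>

// block-wise: contiguous runs of pairs per component, e.g. [t t t t h h h w w w]
static int component_blockwise(int pair, const int sections[3]) {
    if (pair < sections[0])               return 0; // temporal
    if (pair < sections[0] + sections[1]) return 1; // height
    return 2;                                       // width
}

// interleaved: cycle through the components pair by pair, e.g. [t h w t h w ...]
static int component_interleaved(int pair, const int sections[3]) {
    (void) sections; // a real kernel also has to respect the section sizes
    return pair % 3;
}

int main() {
    const int sections[3] = {4, 3, 3}; // made-up split over 10 rotary pairs
    for (int pair = 0; pair < 10; pair++) {
        std::printf("pair %2d: block-wise -> %d, interleaved -> %d\n",
                    pair,
                    component_blockwise(pair, sections),
                    component_interleaved(pair, sections));
    }
    return 0;
}
```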