Commit 7abe495

Add Qwen3-VL support from llama.cpp PR #16780
Integrates Qwen3-VL and Qwen3-VL-MoE architecture support from upstream. Implements IMROPE (interleaved M-RoPE, the multimodal rotary position embedding variant) for vision models. Adds deepstack layer support for visual feature processing.

Changes include:
- New architecture types: LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE
- IMROPE rope type for vision position encoding
- Deepstack visual feature handling in clip.cpp
- GGML CUDA kernels for IMROPE
- Tensor mappings for the Qwen3VL architecture

Upstream PR: ggml-org/llama.cpp#16780
Contributors: @JJJYmmm @yairpatch @Thireus @LETS-BEE
1 parent edae0c2 commit 7abe495
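
The GGML kernel changes that implement IMROPE are part of this commit but not among the diffs shown below. For orientation, here is a minimal sketch of the idea, assuming the same four-section position layout (time, height, width, extra) that M-RoPE uses; the function names and the exact boundary handling are illustrative, not the upstream kernel code:

    #include <array>

    // Illustrative only: pick which position channel (0 = time, 1 = height,
    // 2 = width, 3 = extra) a given rotary dimension pair ("sector") uses.
    // sections[] holds how many dimension pairs each channel owns.

    // Contiguous layout (M-RoPE style): channels change in contiguous blocks.
    static int mrope_channel(int sector, const std::array<int, 4> & sections) {
        int acc = 0;
        for (int c = 0; c < 4; ++c) {
            acc += sections[c];
            if (sector < acc) {
                return c;
            }
        }
        return 3;
    }

    // Interleaved layout (IMROPE-style sketch): time/height/width alternate
    // sector by sector, spreading each channel across the rotary dimensions.
    static int imrope_channel(int sector, const std::array<int, 4> & sections) {
        const int c = sector % 3;          // cycle time -> height -> width
        if (sector < 3 * sections[c]) {
            return c;
        }
        return 3;                          // leftover pairs use the extra channel
    }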

14 files changed: +2005 / -49 lines changed


llama/llama.cpp/include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ extern "C" {
         LLAMA_ROPE_TYPE_NORM   = 0,
         LLAMA_ROPE_TYPE_NEOX   = GGML_ROPE_TYPE_NEOX,
         LLAMA_ROPE_TYPE_MROPE  = GGML_ROPE_TYPE_MROPE,
+        LLAMA_ROPE_TYPE_IMROPE = GGML_ROPE_TYPE_IMROPE,
         LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
     };
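
Since the new value is part of the public llama_rope_type enum, application code can branch on it the same way it already does for M-RoPE. A small sketch, assuming the llama_model_rope_type() accessor from the public API; useful, for example, to decide how many position values to supply per token:

    #include "llama.h"

    // Sketch: does this model expect multi-channel (4-per-token) rope positions?
    static bool uses_multichannel_rope(const struct llama_model * model) {
        switch (llama_model_rope_type(model)) {
            case LLAMA_ROPE_TYPE_MROPE:
            case LLAMA_ROPE_TYPE_IMROPE:
                return true;
            default:
                return false;
        }
    }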

llama/llama.cpp/src/llama-arch.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN2VL,          "qwen2vl"    },
     { LLM_ARCH_QWEN3,            "qwen3"      },
     { LLM_ARCH_QWEN3MOE,         "qwen3moe"   },
+    { LLM_ARCH_QWEN3VL,          "qwen3vl"    },
+    { LLM_ARCH_QWEN3VLMOE,       "qwen3vlmoe" },
     { LLM_ARCH_PHI2,             "phi2"       },
     { LLM_ARCH_PHI3,             "phi3"       },
     { LLM_ARCH_PHIMOE,           "phimoe"     },
@@ -142,6 +144,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERTS_PER_GROUP,         "%s.experts_per_group"       },
     { LLM_KV_MOE_EVERY_N_LAYERS,        "%s.moe_every_n_layers"      },
     { LLM_KV_NEXTN_PREDICT_LAYERS,      "%s.nextn_predict_layers"    },
+    { LLM_KV_NUM_DEEPSTACK_LAYERS,      "%s.n_deepstack_layers"      },
     { LLM_KV_POOLING_TYPE,              "%s.pooling_type"            },
     { LLM_KV_LOGIT_SCALE,               "%s.logit_scale"             },
     { LLM_KV_DECODER_START_TOKEN_ID,    "%s.decoder_start_token_id"  },
@@ -753,6 +756,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_QWEN3VL,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_QWEN3MOE,
         {
@@ -773,6 +795,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_QWEN3VLMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_PHI2,
         {
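
The tensor-name tables above are printf-style templates: entries containing "blk.%d" are expanded with the layer index when the loader looks up per-layer weights in the GGUF file. A tiny illustrative helper (not a function from llama.cpp) showing the expansion:

    #include <cstdio>
    #include <string>

    // Expand a per-layer pattern such as "blk.%d.attn_q" into the concrete
    // GGUF tensor name for layer il, e.g. layer_tensor_name("blk.%d.attn_q", 3)
    // returns "blk.3.attn_q".
    static std::string layer_tensor_name(const char * pattern, int il) {
        char buf[128];
        std::snprintf(buf, sizeof(buf), pattern, il);
        return std::string(buf);
    }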

llama/llama.cpp/src/llama-arch.h

Lines changed: 3 additions & 0 deletions
@@ -35,6 +35,8 @@ enum llm_arch {
     LLM_ARCH_QWEN2VL,
     LLM_ARCH_QWEN3,
     LLM_ARCH_QWEN3MOE,
+    LLM_ARCH_QWEN3VL,
+    LLM_ARCH_QWEN3VLMOE,
     LLM_ARCH_PHI2,
     LLM_ARCH_PHI3,
     LLM_ARCH_PHIMOE,
@@ -146,6 +148,7 @@ enum llm_kv {
     LLM_KV_EXPERTS_PER_GROUP,
     LLM_KV_MOE_EVERY_N_LAYERS,
     LLM_KV_NEXTN_PREDICT_LAYERS,
+    LLM_KV_NUM_DEEPSTACK_LAYERS,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,

llama/llama.cpp/src/llama-hparams.cpp

Lines changed: 1 addition & 1 deletion
@@ -148,7 +148,7 @@ bool llama_hparams::is_recurrent(uint32_t il) const {
 }
 
 uint32_t llama_hparams::n_pos_per_embd() const {
-    return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+    return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
 }
 
 bool llama_hparams::n_bskcn(uint32_t n, uint32_t il) const {
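
With this change, n_pos_per_embd() reports 4 for both M-RoPE and IMROPE models, meaning each token carries four position values (time, height, width, and an extra channel) rather than one. A minimal sketch of what that implies for buffer sizing; the flat allocation below is an assumption for illustration, not the exact layout llama.cpp uses internally:

    #include <cstdint>
    #include <vector>

    // Allocate a zero-initialized position buffer: one value per token for
    // normal rope, four per token for M-RoPE / IMROPE models.
    static std::vector<int32_t> alloc_pos_buffer(uint32_t n_tokens, uint32_t n_pos_per_embd) {
        return std::vector<int32_t>(static_cast<size_t>(n_tokens) * n_pos_per_embd, 0);
    }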

llama/llama.cpp/src/llama-hparams.h

Lines changed: 3 additions & 0 deletions
@@ -183,6 +183,9 @@ struct llama_hparams {
     std::array<float, LLAMA_MAX_LAYERS> xielu_beta;
     std::array<float, LLAMA_MAX_LAYERS> xielu_eps;
 
+    // qwen3vl deepstack
+    uint32_t n_deepstack_layers = 0;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
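
The new n_deepstack_layers field pairs with the "%s.n_deepstack_layers" GGUF key added in llama-arch.cpp; the call site that populates it during model loading is not part of the diffs shown here. As a hedged illustration, the key can also be read directly with the GGUF C API, assuming the converter stores it as a 32-bit unsigned value under the "qwen3vl" architecture prefix:

    #include "gguf.h"   // GGUF C API shipped with ggml

    #include <cstdint>

    // Sketch: read the deepstack layer count from a converted Qwen3-VL
    // checkpoint; returns 0 if the key is absent (e.g. non-VL Qwen3 models).
    static uint32_t read_n_deepstack_layers(const struct gguf_context * ctx) {
        const auto key_id = gguf_find_key(ctx, "qwen3vl.n_deepstack_layers");
        if (key_id < 0) {
            return 0;
        }
        return gguf_get_val_u32(ctx, key_id);
    }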

llama/llama.cpp/src/llama-kv-cache.cpp

Lines changed: 1 addition & 1 deletion
@@ -1340,7 +1340,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
     const auto & yarn_beta_slow = cparams.yarn_beta_slow;
 
     const auto & n_rot     = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
+    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
         // @ngxson : this is a workaround
         // for M-RoPE, we want to rotate the whole vector when doing KV shift
         // a normal RoPE should work, we just need to use the correct ordering
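
The rest of the expression is cut off by the hunk boundary, but per the in-diff comment the intent is that M-RoPE and IMROPE models fall back to plain NEOX ordering when shifting the KV cache, since the whole vector should be rotated. A self-contained sketch of that fallback (not the verbatim upstream code):

    #include "llama.h"

    // Sketch: rope type to use for a KV-cache shift. Multi-channel rope
    // variants rotate the whole vector with NEOX ordering; everything else
    // keeps the model's own rope type.
    static llama_rope_type rope_type_for_kv_shift(llama_rope_type rope_type) {
        const bool is_mrope = rope_type == LLAMA_ROPE_TYPE_MROPE ||
                              rope_type == LLAMA_ROPE_TYPE_IMROPE;
        return is_mrope ? LLAMA_ROPE_TYPE_NEOX : rope_type;
    }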
