
Commit bd75d0f

Merge pull request #32 from JJJYmmm/add_qwen3vl
qwen3vl - code clean + use fused qkv in clip
2 parents 5f3b32f + 2be9279 commit bd75d0f

6 files changed (+47, −54 lines)

convert_hf_to_gguf.py

Lines changed: 0 additions & 18 deletions
@@ -4147,24 +4147,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]

         if name.startswith("visual."):
-            if ".qkv." in name:
-                if data_torch.ndim == 2:
-                    c3, _ = data_torch.shape
-                else:
-                    c3 = data_torch.shape[0]
-                if c3 % 3 != 0:
-                    raise ValueError(f"Unexpected QKV shape for {name}: {data_torch.shape}")
-                c = c3 // 3
-                wq = data_torch[:c]
-                wk = data_torch[c: c * 2]
-                wv = data_torch[c * 2:]
-                base = name.replace("qkv", "{placeholder}")
-                return [
-                    (self.map_tensor_name(base.format(placeholder="q")), wq),
-                    (self.map_tensor_name(base.format(placeholder="k")), wk),
-                    (self.map_tensor_name(base.format(placeholder="v")), wv),
-                ]
             return [(self.map_tensor_name(name), data_torch)]

         # Fall back to parent class for other tensors
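With this commit the converter no longer splits the fused tensor: the visual.*.qkv weights are written to the GGUF as-is, and the split moves to graph-build time in clip.cpp. As a minimal sketch of what the removed code did (numpy standing in for torch, toy sizes of my choosing):

import numpy as np

c, n_embd = 4, 8                                  # toy sizes, not the real model's
qkv = np.arange(3 * c * n_embd, dtype=np.float32).reshape(3 * c, n_embd)

assert qkv.shape[0] % 3 == 0                      # same sanity check the old code made
wq, wk, wv = qkv[:c], qkv[c:2 * c], qkv[2 * c:]   # equal thirds along dim 0

# stacking the thirds back together recovers the fused tensor exactly
assert np.array_equal(np.vstack([wq, wk, wv]), qkv)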

gguf-py/gguf/constants.py

Lines changed: 6 additions & 3 deletions
@@ -278,7 +278,7 @@ class ClipVision:
         USE_GELU             = "clip.use_gelu"
         USE_SILU             = "clip.use_silu"
         N_WA_PATTERN         = "clip.vision.n_wa_pattern" # used by qwen2.5vl
-        IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers"
+        IS_DEEPSTACK_LAYERS  = "clip.vision.is_deepstack_layers"

         class Attention:
             HEAD_COUNT = "clip.vision.attention.head_count"

@@ -614,6 +614,7 @@ class MODEL_TENSOR(IntEnum):
     V_ENC_EMBD_PATCH = auto()
     V_ENC_EMBD_POS = auto()
     V_ENC_INPUT_NORM = auto()
+    V_ENC_ATTN_QKV = auto()
     V_ENC_ATTN_Q = auto()
     V_ENC_ATTN_Q_NORM = auto()
     V_ENC_ATTN_K = auto()

@@ -646,8 +647,8 @@ class MODEL_TENSOR(IntEnum):
     V_TOK_EMBD_IMG_BREAK = auto() # pixtral
     V_MM_PATCH_MERGER = auto() # mistral small 3.1
     V_DS_NORM = auto() # qwen3vl
-    V_DS_FC1 = auto() # qwen3vl
-    V_DS_FC2 = auto() # qwen3vl
+    V_DS_FC1  = auto() # qwen3vl
+    V_DS_FC2  = auto() # qwen3vl
     # audio (mtmd)
     A_ENC_EMBD_POS = auto()
     A_ENC_CONV1D = auto()

@@ -964,6 +965,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
     MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
     MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
+    MODEL_TENSOR.V_ENC_ATTN_QKV: "v.blk.{bid}.attn_qkv",
     MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
     MODEL_TENSOR.V_ENC_ATTN_Q_NORM: "v.blk.{bid}.attn_q_norm",
     MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k",

@@ -1036,6 +1038,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.V_ENC_EMBD_PATCH,
         MODEL_TENSOR.V_ENC_EMBD_POS,
         MODEL_TENSOR.V_ENC_INPUT_NORM,
+        MODEL_TENSOR.V_ENC_ATTN_QKV,
         MODEL_TENSOR.V_ENC_ATTN_Q,
         MODEL_TENSOR.V_ENC_ATTN_Q_NORM,
         MODEL_TENSOR.V_ENC_ATTN_K,

gguf-py/gguf/tensor_mapping.py

Lines changed: 4 additions & 0 deletions
@@ -1188,6 +1188,10 @@ class TensorNameMap:
             "visual.pos_embed", # qwen3vl
         ),

+        MODEL_TENSOR.V_ENC_ATTN_QKV: (
+            "visual.blocks.{bid}.attn.qkv", # qwen3vl
+        ),
+
         MODEL_TENSOR.V_ENC_ATTN_Q: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
             "model.vision_tower.encoder.layer.{bid}.attention.q_proj", # Intern-S1

src/llama-model.cpp

Lines changed: 17 additions & 17 deletions
@@ -1027,7 +1027,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_QWEN3VL:
             {
-                ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, 0);
+                ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);
                 ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {

@@ -1036,8 +1036,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     case 64: type = LLM_TYPE_32B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
-                // for deepstack patch, we consider the embd to be [main_embd, deepstack_embd_1, deepstack_embd_2, ...]
-                hparams.n_embd = hparams.n_embd * (hparams.n_deepstack_layers + 1);
+                // since the vision model stacks deepstack features along the feature dim,
+                // we also create a fake "n_embd" for the text model: the main embd plus the deepstack embds
+                hparams.n_embd *= hparams.n_deepstack_layers + 1;
             } break;
         case LLM_ARCH_QWEN3MOE:
             {

@@ -1052,17 +1053,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_QWEN3VLMOE:
             {
-                ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, 0);
+                ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);
                 ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 48: type = LLM_TYPE_30B_A3B; break;
                     case 94: type = LLM_TYPE_235B_A22B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
-                // for deepstack patch, we consider the embd to be [main_embd, deepstack_embd_1, deepstack_embd_2, ...]
-                hparams.n_embd = hparams.n_embd * (hparams.n_deepstack_layers + 1);
+                // since the vision model stacks deepstack features along the feature dim,
+                // we also create a fake "n_embd" for the text model: the main embd plus the deepstack embds
+                hparams.n_embd *= hparams.n_deepstack_layers + 1;
             } break;
         case LLM_ARCH_PHI2:
             {

@@ -3307,11 +3309,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         case LLM_ARCH_QWEN3:
         case LLM_ARCH_QWEN3VL:
             {
-                int64_t n_embd = hparams.n_embd;
-                // for deepstack features, we consider the embd to be [main_embd, deepstack_embd_1, deepstack_embd_2, ...]
-                if (arch == LLM_ARCH_QWEN3VL) {
-                    n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1);
-                }
+                // at load time the weights only carry the main embd,
+                // so divide n_embd by the number of deepstack layers + 1;
+                // n_embd is const, so declare a new local variable
+                int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1);
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                 // output

@@ -3347,11 +3348,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         case LLM_ARCH_QWEN3MOE:
         case LLM_ARCH_QWEN3VLMOE:
             {
-                // for deepstack features, we consider the embd to be [main_embd, deepstack_embd_1, deepstack_embd_2, ...]
-                int64_t n_embd = hparams.n_embd;
-                if (arch == LLM_ARCH_QWEN3VLMOE) {
-                    n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1);
-                }
+                // at load time the weights only carry the main embd,
+                // so divide n_embd by the number of deepstack layers + 1;
+                // n_embd is const, so declare a new local variable
+                int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1);
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                 // output
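The arithmetic behind the fake n_embd, as a quick sketch (widths are illustrative, not taken from a real checkpoint): load_hparams scales the text-model width up so an embedding row can hold the main features plus one slice per deepstack layer, and load_tensors divides back down because the weights on disk only store the main width. For plain Qwen3, n_deepstack_layers is presumably 0, so the shared division is a no-op there.

n_embd_main        = 2048   # width stored in the weights (illustrative)
n_deepstack_layers = 3      # illustrative; 0 would make both steps no-ops

n_embd_stacked = n_embd_main * (n_deepstack_layers + 1)        # load_hparams
n_embd_loaded  = n_embd_stacked // (n_deepstack_layers + 1)    # load_tensors
assert n_embd_loaded == n_embd_main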

tools/mtmd/clip-impl.h

Lines changed: 5 additions & 4 deletions
@@ -39,7 +39,7 @@
 #define KEY_FEATURE_LAYER       "clip.vision.feature_layer"
 #define KEY_PROJ_SCALE_FACTOR   "clip.vision.projector.scale_factor"
 #define KEY_SPATIAL_MERGE_SIZE  "clip.vision.spatial_merge_size"
-#define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
+#define KEY_IS_DEEPSTACK_LAYERS  "clip.vision.is_deepstack_layers"

 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"

@@ -64,6 +64,7 @@
 #define TN_PATCH_EMBD      "v.patch_embd.weight" // tensor not renamed with ".0" postfix for backward compat
 #define TN_PATCH_EMBD_1    "v.patch_embd.weight.1"
 #define TN_PATCH_BIAS      "v.patch_embd.bias"
+#define TN_ATTN_QKV        "%s.blk.%d.attn_qkv.%s"
 #define TN_ATTN_K          "%s.blk.%d.attn_k.%s"
 #define TN_ATTN_Q          "%s.blk.%d.attn_q.%s"
 #define TN_ATTN_V          "%s.blk.%d.attn_v.%s"

@@ -94,9 +95,9 @@
 #define TN_TOK_IMG_BREAK   "v.token_embd.img_break" // pixtral
 #define TN_TOK_GLM_BOI     "adapter.boi" // glm-edge (these embeddings are not in the text model)
 #define TN_TOK_GLM_EOI     "adapter.eoi" // glm-edge (these embeddings are not in the text model)
-#define TN_DEEPSTACK_NORM  "v.deepstack.%d.norm.%s" // qwen3vl deepstack
-#define TN_DEEPSTACK_FC1   "v.deepstack.%d.fc1.%s" // qwen3vl deepstack
-#define TN_DEEPSTACK_FC2   "v.deepstack.%d.fc2.%s" // qwen3vl deepstack
+#define TN_DEEPSTACK_NORM  "v.deepstack.%d.norm.%s"  // qwen3vl deepstack
+#define TN_DEEPSTACK_FC1   "v.deepstack.%d.fc1.%s"   // qwen3vl deepstack
+#define TN_DEEPSTACK_FC2   "v.deepstack.%d.fc2.%s"   // qwen3vl deepstack

 // minicpmv
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"

tools/mtmd/clip.cpp

Lines changed: 15 additions & 12 deletions
@@ -216,6 +216,8 @@ struct clip_layer {
     ggml_tensor * q_b = nullptr;
     ggml_tensor * v_w = nullptr;
     ggml_tensor * v_b = nullptr;
+    ggml_tensor * qkv_w = nullptr;
+    ggml_tensor * qkv_b = nullptr;

     ggml_tensor * o_w = nullptr;
     ggml_tensor * o_b = nullptr;
@@ -927,16 +929,15 @@

         // self-attention
         {
-            ggml_tensor * Qcur = ggml_add(ctx0,
-                ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
-            ggml_tensor * Kcur = ggml_add(ctx0,
-                ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
-            ggml_tensor * Vcur = ggml_add(ctx0,
-                ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
+            cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+            cur = ggml_add(ctx0, cur, layer.qkv_b);

-            Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
-            Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
-            Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);
+            ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
+                cur->nb[1], 0);
+            ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
+                cur->nb[1], n_embd * sizeof(float));
+            ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
+                cur->nb[1], 2 * n_embd * sizeof(float));

             cb(Qcur, "Qcur", il);
             cb(Kcur, "Kcur", il);
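The three ggml_view_3d calls carve Q, K and V out of the fused projection without copying: each position's features are laid out as [q | k | v], so the byte offsets 0, n_embd*sizeof(float) and 2*n_embd*sizeof(float) select the right slice. A numpy analogue of the layout (toy sizes of my choosing; note numpy's row-major axes run opposite to ggml's ne ordering):

import numpy as np

n_head, d_head, n_pos = 2, 3, 5
n_embd = n_head * d_head
qkv = np.random.rand(n_pos, 3 * n_embd).astype(np.float32)  # one row per position

# column offsets 0, n_embd, 2*n_embd mirror the byte offsets in ggml_view_3d
Q = qkv[:, 0 * n_embd:1 * n_embd].reshape(n_pos, n_head, d_head)
K = qkv[:, 1 * n_embd:2 * n_embd].reshape(n_pos, n_head, d_head)
V = qkv[:, 2 * n_embd:3 * n_embd].reshape(n_pos, n_head, d_head)

# all three are strided views into the fused buffer, not copies
assert np.shares_memory(Q, qkv) and np.shares_memory(K, qkv) and np.shares_memory(V, qkv)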
@@ -2758,10 +2759,11 @@
         model.layers.resize(hparams.n_layer);
         for (int il = 0; il < hparams.n_layer; ++il) {
             auto & layer = model.layers[il];
-            layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
-            layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
-            layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
+            layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"), false);
+            layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"), false);
+            layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"), false);
             layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
+            layer.qkv_w = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "weight"), false);
             layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
             layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
             layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false);

@@ -2773,6 +2775,7 @@
             layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
             layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
             layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
+            layer.qkv_b = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "bias"), false);
             layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false);
             layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false);
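Note the loader pattern: the separate q/k/v tensors and the fused qkv tensor are all requested with the required flag set to false, so either layout may be absent and the graph builder uses whichever is present. A hedged Python sketch of this fallback (hypothetical helper names, not the clip.cpp API):

def load_attn(get_tensor):
    layer = {}
    # every variant is optional; a checkpoint ships one of the two layouts
    layer["qkv_w"] = get_tensor("attn_qkv.weight", required=False)
    layer["q_w"]   = get_tensor("attn_q.weight",   required=False)
    layer["k_w"]   = get_tensor("attn_k.weight",   required=False)
    layer["v_w"]   = get_tensor("attn_v.weight",   required=False)
    if layer["qkv_w"] is None and layer["q_w"] is None:
        raise ValueError("model has neither fused nor split attention weights")
    return layer

# e.g. a qwen3vl-style checkpoint that only ships the fused tensor:
tensors = {"attn_qkv.weight": object()}
layer = load_attn(lambda name, required=True: tensors.get(name))
assert layer["qkv_w"] is not None and layer["q_w"] is None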
