Commit affdb0d

Merge pull request #28 from JJJYmmm/add_qwen3vl
Add last updates from @JJJYmmm - Adding Support for Qwen3-VL Series
2 parents 955a367 + 3271877, commit affdb0d

5 files changed: +81 additions, −63 deletions
convert_hf_to_gguf.py

Lines changed: 8 additions & 4 deletions

@@ -4059,7 +4059,9 @@ def __init__(self, *args, **kwargs):
         self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
         self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")

-        self.deepstack_layers: list[int] = list(self.hparams_vision.get("deepstack_visual_indexes", []))
+        self.is_deepstack_layers = [False] * int(self.hparams_vision["num_hidden_layers"] or 0)
+        for idx in self.hparams_vision.get("deepstack_visual_indexes", []):
+            self.is_deepstack_layers[idx] = True

     def set_gguf_parameters(self):
         super().set_gguf_parameters()

@@ -4075,10 +4077,11 @@ def set_gguf_parameters(self):
         rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
         self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)

-        if self.deepstack_layers:
-            self.gguf_writer.add_vision_deepstack_layers(self.deepstack_layers)
+        if self.is_deepstack_layers:
+            self.gguf_writer.add_vision_is_deepstack_layers(self.is_deepstack_layers)

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        assert self.hparams_vision is not None
         # Skip text model tensors - they go in the text model file
         if name.startswith("model.language_model.") or name.startswith("lm_head."):
             return []

@@ -4088,7 +4091,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

         if name.startswith("visual.deepstack_merger_list."):
             prefix, rest = name.split(".", maxsplit=3)[2:]
-            idx = int(prefix)
+            # prefix is the layer index, convert to absolute clip layer index!
+            idx = self.hparams_vision.get("deepstack_visual_indexes", [])[int(prefix)]
             target = rest

             tensor_type: gguf.MODEL_TENSOR
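
For context, here is a minimal standalone sketch of the two conversions the hunks above perform: turning the deepstack_visual_indexes list from the HF vision config into a per-layer boolean mask, and remapping a visual.deepstack_merger_list.<k> prefix to the absolute vision-layer index. The helper names and the example indexes are hypothetical; only the logic mirrors the diff.

# Hypothetical helpers illustrating the conversion logic above; not part of
# convert_hf_to_gguf.py itself.

def build_deepstack_mask(num_hidden_layers: int, deepstack_visual_indexes: list[int]) -> list[bool]:
    # one flag per vision layer, True where a deepstack merger is attached
    mask = [False] * num_hidden_layers
    for idx in deepstack_visual_indexes:
        mask[idx] = True
    return mask

def merger_prefix_to_layer(prefix: int, deepstack_visual_indexes: list[int]) -> int:
    # "visual.deepstack_merger_list.<k>. ..." numbers mergers 0..n-1; map k to
    # the absolute clip layer index so the tensors can be stored per layer
    return deepstack_visual_indexes[prefix]

# e.g. made-up indexes [5, 11, 17] in a 24-layer ViT
mask = build_deepstack_mask(24, [5, 11, 17])
assert [i for i, flag in enumerate(mask) if flag] == [5, 11, 17]
assert merger_prefix_to_layer(1, [5, 11, 17]) == 11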

gguf-py/gguf/constants.py

Lines changed: 1 addition & 1 deletion

@@ -278,7 +278,7 @@ class ClipVision:
         USE_GELU     = "clip.use_gelu"
         USE_SILU     = "clip.use_silu"
         N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
-        DEEPSTACK_LAYERS = "clip.vision.deepstack_layers"
+        IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers"

        class Attention:
            HEAD_COUNT = "clip.vision.attention.head_count"

gguf-py/gguf/gguf_writer.py

Lines changed: 2 additions & 2 deletions

@@ -1074,8 +1074,8 @@ def add_vision_projector_scale_factor(self, value: int) -> None:
     def add_vision_n_wa_pattern(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)

-    def add_vision_deepstack_layers(self, layers: Sequence[int]) -> None:
-        self.add_array(Keys.ClipVision.DEEPSTACK_LAYERS, layers)
+    def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
+        self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)

     # audio models
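
To make the representational change concrete, a small sketch of what the old and new GGUF keys store, with plain dicts standing in for the GGUF key/value section and made-up layer indexes:

# Sketch only: old key held the deepstack layer indexes as an int array,
# the new key holds one boolean flag per vision layer.
old_kv = {"clip.vision.deepstack_layers": [5, 11, 17]}                                 # removed key
new_kv = {"clip.vision.is_deepstack_layers": [i in (5, 11, 17) for i in range(24)]}    # new bool array

# both encodings describe the same set of layers
assert [i for i, flag in enumerate(new_kv["clip.vision.is_deepstack_layers"]) if flag] \
    == old_kv["clip.vision.deepstack_layers"]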

tools/mtmd/clip-impl.h

Lines changed: 4 additions & 1 deletion

@@ -39,7 +39,7 @@
 #define KEY_FEATURE_LAYER        "clip.vision.feature_layer"
 #define KEY_PROJ_SCALE_FACTOR    "clip.vision.projector.scale_factor"
 #define KEY_SPATIAL_MERGE_SIZE   "clip.vision.spatial_merge_size"
-#define KEY_DEEPSTACK_LAYERS     "clip.vision.deepstack_layers"
+#define KEY_IS_DEEPSTACK_LAYERS  "clip.vision.is_deepstack_layers"

 #define KEY_MM_PATCH_MERGE_TYPE  "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"

@@ -94,6 +94,9 @@
 #define TN_TOK_IMG_BREAK   "v.token_embd.img_break" // pixtral
 #define TN_TOK_GLM_BOI     "adapter.boi" // glm-edge (these embeddings are not in text model)
 #define TN_TOK_GLM_EOI     "adapter.eoi" // glm-edge (these embeddings are not in text model)
+#define TN_DEEPSTACK_NORM  "v.deepstack.%d.norm.%s" // qwen3vl deepstack
+#define TN_DEEPSTACK_FC1   "v.deepstack.%d.fc1.%s"  // qwen3vl deepstack
+#define TN_DEEPSTACK_FC2   "v.deepstack.%d.fc2.%s"  // qwen3vl deepstack

 // mimicpmv
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
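
A quick Python sketch of the tensor names these %d/%s patterns produce. The %d slot now takes the absolute vision-layer index (layer 11 is an arbitrary example), matching the converter change above, rather than the merger's position in the deepstack list.

# Illustration only; the patterns are copied from the macros above.
TN_DEEPSTACK_NORM = "v.deepstack.%d.norm.%s"
TN_DEEPSTACK_FC1  = "v.deepstack.%d.fc1.%s"
TN_DEEPSTACK_FC2  = "v.deepstack.%d.fc2.%s"

for pattern in (TN_DEEPSTACK_NORM, TN_DEEPSTACK_FC1, TN_DEEPSTACK_FC2):
    print(pattern % (11, "weight"))   # v.deepstack.11.norm.weight, v.deepstack.11.fc1.weight, ...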

tools/mtmd/clip.cpp

Lines changed: 66 additions & 55 deletions
@@ -196,7 +196,7 @@ struct clip_hparams {
     int32_t n_wa_pattern = 0;
     int32_t spatial_merge_size = 0;

-    std::vector<int32_t> deepstack_layers; // qwen3vl deepstack layers
+    std::vector<bool> is_deepstack_layers; // qwen3vl: whether the layer is a deepstack layer

     // audio
     int32_t n_mel_bins = 0; // whisper preprocessor

@@ -241,6 +241,14 @@ struct clip_layer {
     // layer scale (no bias)
     ggml_tensor * ls_1_w = nullptr;
     ggml_tensor * ls_2_w = nullptr;
+
+    // qwen3vl deepstack merger
+    ggml_tensor * deepstack_norm_w = nullptr;
+    ggml_tensor * deepstack_norm_b = nullptr;
+    ggml_tensor * deepstack_fc1_w  = nullptr;
+    ggml_tensor * deepstack_fc1_b  = nullptr;
+    ggml_tensor * deepstack_fc2_w  = nullptr;
+    ggml_tensor * deepstack_fc2_b  = nullptr;
 };

 struct clip_model {
@@ -361,17 +369,6 @@
     ggml_tensor * mm_norm_pre_w = nullptr;
     ggml_tensor * mm_norm_mid_w = nullptr;

-    // qwen3vl deepstack
-    struct deepstack_merger {
-        ggml_tensor * norm_w = nullptr;
-        ggml_tensor * norm_b = nullptr;
-        ggml_tensor * fc1_w = nullptr;
-        ggml_tensor * fc1_b = nullptr;
-        ggml_tensor * fc2_w = nullptr;
-        ggml_tensor * fc2_b = nullptr;
-    };
-    std::vector<deepstack_merger> deepstack_mergers;
-
     bool audio_has_avgpool() const {
         return proj_type == PROJECTOR_TYPE_QWEN2A
             || proj_type == PROJECTOR_TYPE_VOXTRAL;

@@ -849,7 +846,6 @@ struct clip_graph {
         GGML_ASSERT(model.patch_bias != nullptr);
         GGML_ASSERT(model.position_embeddings != nullptr);
         GGML_ASSERT(model.class_embedding == nullptr);
-        GGML_ASSERT(!hparams.deepstack_layers.empty());

         const int batch_size = 1;
         const int n_pos = n_patches;
@@ -986,20 +982,14 @@
            cur = ggml_add(ctx0, inpL, cur);
            cb(cur, "layer_out", il);

-            if (std::find(hparams.deepstack_layers.begin(), hparams.deepstack_layers.end(), il) != hparams.deepstack_layers.end()) {
-                const int deepstack_idx = std::find(hparams.deepstack_layers.begin(), hparams.deepstack_layers.end(), il) - hparams.deepstack_layers.begin();
-                auto & merger = model.deepstack_mergers[deepstack_idx];
-                ggml_tensor * feat = ggml_dup(ctx0, cur);
-                feat = ggml_reshape_3d(ctx0, feat, n_embd * merge_factor, n_pos / merge_factor, batch_size);
-
-                feat = build_norm(feat, merger.norm_w, merger.norm_b, norm_t, eps, il);
-                feat = ggml_mul_mat(ctx0, merger.fc1_w, feat);
-                feat = ggml_add(ctx0, feat, merger.fc1_b);
-
-                feat = ggml_gelu(ctx0, feat);
-
-                feat = ggml_mul_mat(ctx0, merger.fc2_w, feat);
-                feat = ggml_add(ctx0, feat, merger.fc2_b);
+            if (hparams.is_deepstack_layers[il]) {
+                ggml_tensor * feat = ggml_reshape_3d(ctx0, cur, n_embd * merge_factor, n_pos / merge_factor, batch_size);
+                feat = build_norm(feat, layer.deepstack_norm_w, layer.deepstack_norm_b, norm_t, eps, il);
+                feat = build_ffn(feat,
+                    layer.deepstack_fc1_w, layer.deepstack_fc1_b,
+                    nullptr, nullptr,
+                    layer.deepstack_fc2_w, layer.deepstack_fc2_b,
+                    ffn_op_type::FFN_GELU, il);

                if(!deepstack_features) {
                    deepstack_features = feat;

@@ -1021,15 +1011,11 @@
        ggml_tensor * embeddings = inpL;
        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);

-        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
-
-        // GELU activation
-        embeddings = ggml_gelu(ctx0, embeddings);
-
-        // Second linear layer
-        embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+        embeddings = build_ffn(embeddings,
+            model.mm_0_w, model.mm_0_b,
+            nullptr, nullptr,
+            model.mm_1_w, model.mm_1_b,
+            ffn_op_type::FFN_GELU, -1);

        embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension
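
A shape-only sketch of the path above, with NumPy standing in for ggml and all sizes made up: each flagged layer contributes one extra projected feature block, which is concatenated with the main projector output along the feature dimension.

# Not the ggml graph, just the shape bookkeeping; n_pos, n_embd, merge_factor
# and proj_dim are hypothetical.
import numpy as np

n_pos, n_embd, merge_factor, proj_dim = 64, 1152, 4, 2048
tokens = n_pos // merge_factor

def deepstack_merger(x):
    # stands in for norm -> fc1 -> GELU -> fc2 of one deepstack layer
    return np.zeros((x.shape[0], proj_dim))

main  = np.zeros((tokens, proj_dim))                                   # output of the mm_0/mm_1 FFN
feats = [deepstack_merger(np.zeros((tokens, n_embd * merge_factor)))
         for _ in range(3)]                                            # e.g. 3 deepstack layers
out   = np.concatenate([main, *feats], axis=-1)                        # concat along feature dim
print(out.shape)                                                       # (16, 8192) = (tokens, proj_dim * (1 + 3))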

@@ -2578,6 +2564,9 @@ struct clip_model_loader {
            hparams.vision_feature_layer.insert(layer);
        }

+        // set default deepstack layers to false
+        hparams.is_deepstack_layers.resize(hparams.n_layer, false);
+
        // model-specific params
        switch (model.proj_type) {
            case PROJECTOR_TYPE_MINICPMV:

@@ -2640,7 +2629,7 @@
                    hparams.image_size = 1024; // still need this?
                    hparams.warmup_image_size = hparams.patch_size * 8;
                    get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
-                    get_arr_int(KEY_DEEPSTACK_LAYERS, hparams.deepstack_layers, false);
+                    get_arr_bool(KEY_IS_DEEPSTACK_LAYERS, hparams.is_deepstack_layers, false);
                } break;
            case PROJECTOR_TYPE_LLAMA4:
                {
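
The loader behaviour in miniature, as a Python stand-in for the two hunks above: the flags default to all-False for every layer, and the optional GGUF key (read with required=false) only overrides them when present. The dict and helper name here are illustrative, not part of the codebase.

def load_deepstack_flags(metadata: dict, n_layer: int) -> list[bool]:
    flags = [False] * n_layer                      # default: no deepstack anywhere
    key = "clip.vision.is_deepstack_layers"
    if key in metadata:                            # key is optional; absent in non-qwen3vl mmproj files
        flags = [bool(v) for v in metadata[key]]
    return flags

print(load_deepstack_flags({}, 4))                                                       # [False, False, False, False]
print(load_deepstack_flags({"clip.vision.is_deepstack_layers": [False, True, False, True]}, 4))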
@@ -2683,10 +2672,19 @@
            if (hparams.spatial_merge_size > 0) {
                LOG_INF("%s: spatial_merge_size: %d\n", __func__, hparams.spatial_merge_size);
            }
-            if (!hparams.deepstack_layers.empty()) {
-                LOG_INF("%s: deepstack_layers: ", __func__);
-                for (size_t i = 0; i < hparams.deepstack_layers.size(); i++) {
-                    LOG_CNT("%d%s", hparams.deepstack_layers[i], i < hparams.deepstack_layers.size() - 1 ? ", " : "\n");
+            if (!hparams.is_deepstack_layers.empty()) {
+                LOG_INF("%s: deepstack enabled layers: ", __func__);
+                bool first = true;
+                for (size_t i = 0; i < hparams.is_deepstack_layers.size(); ++i) {
+                    if (hparams.is_deepstack_layers[i]) {
+                        LOG_CNT("%s%zu", first ? "" : ", ", i);
+                        first = false;
+                    }
+                }
+                if (first) {
+                    LOG_CNT("none\n");
+                } else {
+                    LOG_CNT("\n");
                }
            }
        } else if (is_audio) {
@@ -2786,6 +2784,17 @@ struct clip_model_loader {
                layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
                layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false);

+
+                // qwen3vl deepstack layer
+                if (hparams.is_deepstack_layers[il]) {
+                    layer.deepstack_norm_w = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "weight"), false);
+                    layer.deepstack_norm_b = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "bias"), false);
+                    layer.deepstack_fc1_w  = get_tensor(string_format(TN_DEEPSTACK_FC1, il, "weight"), false);
+                    layer.deepstack_fc1_b  = get_tensor(string_format(TN_DEEPSTACK_FC1, il, "bias"), false);
+                    layer.deepstack_fc2_w  = get_tensor(string_format(TN_DEEPSTACK_FC2, il, "weight"), false);
+                    layer.deepstack_fc2_b  = get_tensor(string_format(TN_DEEPSTACK_FC2, il, "bias"), false);
+                }
+
                // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
                // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
                bool is_ffn_swapped = (

@@ -2927,19 +2936,6 @@ struct clip_model_loader {
                    model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
-
-                    if (!hparams.deepstack_layers.empty()) {
-                        model.deepstack_mergers.resize(hparams.deepstack_layers.size());
-                        for (size_t i = 0; i < hparams.deepstack_layers.size(); i++) {
-                            auto & merger = model.deepstack_mergers[i];
-                            merger.norm_w = get_tensor(string_format("v.deepstack.%d.norm.weight", (int)i), false);
-                            merger.norm_b = get_tensor(string_format("v.deepstack.%d.norm.bias", (int)i), false);
-                            merger.fc1_w = get_tensor(string_format("v.deepstack.%d.fc1.weight", (int)i), false);
-                            merger.fc1_b = get_tensor(string_format("v.deepstack.%d.fc1.bias", (int)i), false);
-                            merger.fc2_w = get_tensor(string_format("v.deepstack.%d.fc2.weight", (int)i), false);
-                            merger.fc2_b = get_tensor(string_format("v.deepstack.%d.fc2.bias", (int)i), false);
-                        }
-                    }
                } break;
            case PROJECTOR_TYPE_GEMMA3:
                {
@@ -3156,6 +3152,21 @@
            }
        }

+    void get_arr_bool(const std::string & key, std::vector<bool> & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+
+        const int n = gguf_get_arr_n(ctx_gguf.get(), i);
+        output.resize(n);
+        const bool * values = (const bool *)gguf_get_arr_data(ctx_gguf.get(), i);
+        for (int i = 0; i < n; ++i) {
+            output[i] = values[i];
+        }
+    }
+
    void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
        auto & hparams = model.hparams;
        for (int x = 1; x <= max_patches_per_side; x++) {
@@ -4662,7 +4673,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
        case PROJECTOR_TYPE_QWEN25VL:
            return ctx->model.mm_1_b->ne[0];
        case PROJECTOR_TYPE_QWEN3VL:
-            return ctx->model.mm_1_b->ne[0] * ((int)ctx->model.hparams.deepstack_layers.size() + 1); // main path + deepstack paths
+            return ctx->model.mm_1_b->ne[0] * (1 + std::count(ctx->model.hparams.is_deepstack_layers.begin(), ctx->model.hparams.is_deepstack_layers.end(), true)); // main path + deepstack paths
        case PROJECTOR_TYPE_GEMMA3:
            return ctx->model.mm_input_proj_w->ne[0];
        case PROJECTOR_TYPE_IDEFICS3:
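
The reported embedding width scales with how many layers are flagged as deepstack; worked with made-up numbers:

# Illustrative arithmetic only; both values are hypothetical.
proj_width  = 2048   # stand-in for ne[0] of mm_1_b
n_deepstack = 3      # number of True entries in is_deepstack_layers
print(proj_width * (1 + n_deepstack))   # 8192 -> width returned by clip_n_mmproj_embd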
