Skip to content

Commit caa7e57

Browse files
committed
add PROJECTOR_TYPE_QWEN2_5_VL
1 parent a3cd0e5 commit caa7e57

File tree

3 files changed

+280
-4
lines changed

3 files changed

+280
-4
lines changed

examples/llava/clip-impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ enum projector_type {
107107
PROJECTOR_TYPE_GEMMA3,
108108
PROJECTOR_TYPE_IDEFICS3,
109109
PROJECTOR_TYPE_PIXTRAL,
110+
PROJECTOR_TYPE_QWEN2_5_VL,
110111
PROJECTOR_TYPE_UNKNOWN,
111112
};
112113

@@ -117,6 +118,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
117118
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
118119
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
119120
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
121+
{ PROJECTOR_TYPE_QWEN2_5_VL,"qwen2.5vl_merger"},
120122
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
121123
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
122124
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},

examples/llava/clip.cpp

Lines changed: 276 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,273 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
780780
return gf;
781781
}
782782

783+
static ggml_cgraph * clip_image_build_graph_qwen2_5_vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
784+
const auto & model = ctx->vision_model;
785+
const auto & hparams = model.hparams;
786+
787+
const int image_size = hparams.image_size;
788+
int image_size_width = image_size;
789+
int image_size_height = image_size;
790+
791+
image_size_width = imgs.entries[0]->nx;
792+
image_size_height = imgs.entries[0]->ny;
793+
794+
const int patch_size = hparams.patch_size;
795+
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
796+
const int patches_w = image_size_width / patch_size;
797+
const int patches_h = image_size_height / patch_size;
798+
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
799+
const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
800+
const int hidden_size = hparams.hidden_size;
801+
const int n_head = hparams.n_head;
802+
const int d_head = hidden_size / n_head;
803+
const float eps = hparams.eps;
804+
const bool use_window_attn = hparams.full_attn_layers.size() > 0;
805+
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
806+
807+
const int batch_size = imgs.entries.size();
808+
GGML_ASSERT(batch_size == 1);
809+
810+
struct ggml_init_params params = {
811+
/*.mem_size =*/ ctx->buf_compute_meta.size(),
812+
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
813+
/*.no_alloc =*/ true,
814+
};
815+
816+
ggml_context_ptr ctx0_ptr(ggml_init(params));
817+
auto ctx0 = ctx0_ptr.get();
818+
819+
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
820+
821+
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
822+
ggml_set_name(inp_raw, "inp_raw");
823+
ggml_set_input(inp_raw);
824+
825+
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
826+
827+
GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
828+
GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
829+
830+
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
831+
inp = ggml_add(ctx0, inp, inp_1);
832+
833+
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
834+
inp = ggml_reshape_4d(
835+
ctx0, inp,
836+
hidden_size * 2, patches_w / 2, patches_h, batch_size);
837+
inp = ggml_reshape_4d(
838+
ctx0, inp,
839+
hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
840+
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
841+
inp = ggml_reshape_3d(
842+
ctx0, inp,
843+
hidden_size, patches_w * patches_h, batch_size);
844+
845+
if (model.patch_bias) {
846+
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
847+
inp = ggml_add(ctx0, inp, model.patch_bias);
848+
}
849+
struct ggml_tensor * embeddings = inp;
850+
struct ggml_tensor * pos_embed = nullptr;
851+
struct ggml_tensor * window_mask = nullptr;
852+
struct ggml_tensor * window_idx = nullptr;
853+
struct ggml_tensor * inv_window_idx = nullptr;
854+
855+
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
856+
ggml_set_name(positions, "positions");
857+
ggml_set_input(positions);
858+
859+
// pre-layernorm
860+
if (model.pre_ln_w) {
861+
if (ctx->use_rms_norm) {
862+
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
863+
ggml_set_name(embeddings, "pre_ln");
864+
865+
embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w);
866+
} else {
867+
embeddings = ggml_norm(ctx0, embeddings, eps);
868+
ggml_set_name(embeddings, "pre_ln");
869+
870+
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
871+
}
872+
}
873+
874+
std::vector<struct ggml_tensor *> embedding_stack;
875+
const auto & vision_feature_layer = hparams.vision_feature_layer;
876+
877+
if (use_window_attn) {
878+
// handle window attention inputs
879+
inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
880+
ggml_set_name(inv_window_idx, "inv_window_idx");
881+
ggml_set_input(inv_window_idx);
882+
// mask for window attention
883+
window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions);
884+
ggml_set_name(window_mask, "window_mask");
885+
ggml_set_input(window_mask);
886+
887+
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
888+
GGML_ASSERT(batch_size == 1);
889+
embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
890+
embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
891+
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
892+
}
893+
894+
// loop over layers
895+
for (int il = 0; il < ctx->max_feature_layer; il++) {
896+
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
897+
898+
// If this is an embedding feature layer, save the output.
899+
// NOTE: 0 index here refers to the input to the encoder.
900+
if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
901+
embedding_stack.push_back(embeddings);
902+
}
903+
904+
// rmsnorm1
905+
cur = ggml_rms_norm(ctx0, cur, eps);
906+
cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w);
907+
908+
// self-attention
909+
{
910+
911+
struct ggml_tensor * Q =
912+
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
913+
914+
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
915+
Q = ggml_rope_multi(
916+
ctx0, Q, positions, nullptr,
917+
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
918+
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
919+
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
920+
921+
struct ggml_tensor * K =
922+
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
923+
924+
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
925+
K = ggml_rope_multi(
926+
ctx0, K, positions, nullptr,
927+
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
928+
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
929+
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
930+
931+
struct ggml_tensor * V =
932+
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
933+
934+
V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
935+
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
936+
V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
937+
938+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
939+
const bool inlist = std::find(hparams.full_attn_layers.begin(), hparams.full_attn_layers.end(), il) != hparams.full_attn_layers.end();
940+
const bool full_attn = use_window_attn ? inlist : true;
941+
if (full_attn) {
942+
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
943+
} else {
944+
KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f);
945+
}
946+
947+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
948+
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
949+
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
950+
951+
cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
952+
}
953+
954+
// attention output
955+
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
956+
957+
// re-add the layer input, e.g., residual
958+
cur = ggml_add(ctx0, cur, embeddings);
959+
960+
embeddings = cur; // embeddings = residual, cur = hidden_states
961+
962+
// rms norm2
963+
cur = ggml_rms_norm(ctx0, cur, eps);
964+
cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w);
965+
966+
// mlp
967+
// ffn_up
968+
auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
969+
cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b);
970+
971+
auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur);
972+
cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b);
973+
if (ctx->use_gelu) {
974+
cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
975+
} else if (ctx->use_silu) {
976+
cur_gate = ggml_silu_inplace(ctx0, cur_gate);
977+
} else {
978+
cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate);
979+
}
980+
cur = ggml_mul(ctx0, cur_gate, cur_up);
981+
982+
// ffn_down
983+
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
984+
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
985+
986+
// residual 2
987+
cur = ggml_add(ctx0, embeddings, cur);
988+
989+
embeddings = cur;
990+
}
991+
992+
// post-layernorm
993+
if (model.post_ln_w) {
994+
if (ctx->use_rms_norm) {
995+
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
996+
ggml_set_name(embeddings, "post_ln");
997+
998+
embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
999+
} else {
1000+
embeddings = ggml_norm(ctx0, embeddings, eps);
1001+
ggml_set_name(embeddings, "post_ln");
1002+
1003+
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
1004+
}
1005+
}
1006+
1007+
// final layer is a vision feature layer
1008+
if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) {
1009+
embedding_stack.push_back(embeddings);
1010+
}
1011+
1012+
// If feature layers are explicitly set, stack them (if we have multiple)
1013+
if (!embedding_stack.empty()) {
1014+
embeddings = embedding_stack[0];
1015+
for (size_t i = 1; i < embedding_stack.size(); i++) {
1016+
embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
1017+
}
1018+
}
1019+
1020+
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
1021+
1022+
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
1023+
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
1024+
1025+
// GELU activation
1026+
embeddings = ggml_gelu(ctx0, embeddings);
1027+
1028+
// Second linear layer
1029+
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
1030+
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
1031+
1032+
if (use_window_attn) {
1033+
window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
1034+
ggml_set_name(window_idx, "window_idx");
1035+
ggml_set_input(window_idx);
1036+
1037+
// embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
1038+
GGML_ASSERT(batch_size == 1);
1039+
embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
1040+
embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
1041+
embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size);
1042+
}
1043+
1044+
// build the graph
1045+
ggml_build_forward_expand(gf, embeddings);
1046+
1047+
return gf;
1048+
}
1049+
7831050
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
7841051
if (!ctx->has_vision_encoder) {
7851052
LOG_ERR("This gguf file seems to have no vision encoder\n");
@@ -1441,6 +1708,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
14411708
{
14421709
res = clip_image_build_graph_pixtral(ctx, imgs);
14431710
} break;
1711+
case PROJECTOR_TYPE_QWEN2_5_VL:
1712+
{
1713+
res = clip_image_build_graph_qwen2_5_vl(ctx, imgs);
1714+
} break;
14441715
default:
14451716
{
14461717
// TODO: we should have one build_* function per model
@@ -1699,7 +1970,7 @@ struct clip_model_loader {
16991970
// legacy naming (the in and out is reversed! don't ask me why)
17001971
layer.ff_i_w = layer.ff_down_w;
17011972
layer.ff_o_w = layer.ff_up_w;
1702-
layer.ff_g_w = layer.ff_gate_b;
1973+
layer.ff_g_w = layer.ff_gate_w;
17031974
layer.ff_i_b = layer.ff_down_b;
17041975
layer.ff_o_b = layer.ff_up_b;
17051976
layer.ff_g_b = layer.ff_gate_b;
@@ -1801,6 +2072,7 @@ struct clip_model_loader {
18012072
vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
18022073
} break;
18032074
case PROJECTOR_TYPE_MERGER:
2075+
case PROJECTOR_TYPE_QWEN2_5_VL:
18042076
{
18052077
vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
18062078
vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
@@ -2754,7 +3026,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
27543026
else if (ctx->minicpmv_version == 4) {
27553027
n_patches = 64;
27563028
}
2757-
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
3029+
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER || ctx->proj_type == PROJECTOR_TYPE_QWEN2_5_VL) {
27583030
int patch_size = params.patch_size * 2;
27593031
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
27603032
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
@@ -3165,7 +3437,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
31653437
}
31663438
}
31673439

3168-
if (hparams.attn_window_size > 0 && ctx->has_qwen2vl_merger) { // TODO: add use_window_attn?
3440+
if (hparams.attn_window_size > 0 && ctx->proj_type == PROJECTOR_TYPE_QWEN2_5_VL) {
31693441
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
31703442
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
31713443
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
@@ -3398,6 +3670,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
33983670
case PROJECTOR_TYPE_GLM_EDGE:
33993671
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
34003672
case PROJECTOR_TYPE_MERGER:
3673+
case PROJECTOR_TYPE_QWEN2_5_VL:
34013674
return ctx->vision_model.mm_1_b->ne[0];
34023675
case PROJECTOR_TYPE_GEMMA3:
34033676
return ctx->vision_model.mm_input_proj_w->ne[0];

examples/llava/qwen2_vl_surgery.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def main(args):
140140
fout.add_bool("clip.has_text_encoder", False)
141141
fout.add_bool("clip.has_vision_encoder", True)
142142
fout.add_bool("clip.has_qwen2vl_merger", True)
143-
fout.add_string("clip.projector_type", "qwen2vl_merger")
144143

145144
print(cfg.vision_config)
146145
if 'silu' in cfg.vision_config.hidden_act.lower():
@@ -159,7 +158,9 @@ def main(args):
159158
fout.add_uint32("clip.vision.window_size", vcfg.window_size)
160159
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
161160
fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
161+
fout.add_string("clip.projector_type", "qwen2.5vl_merger")
162162
else:
163+
fout.add_string("clip.projector_type", "qwen2vl_merger")
163164
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
164165
fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
165166

0 commit comments

Comments
 (0)