Commit 2bfa7a4

InternVL3-1B working

1 parent e309f16 commit 2bfa7a4

4 files changed: +132 additions, -18 deletions

convert_hf_to_gguf.py

Lines changed: 6 additions & 1 deletion
@@ -2724,13 +2724,18 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
         self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL)
+        self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
         # hidden_act
         if hparams["hidden_act"] == "silu":
             self.gguf_writer.add_vision_use_silu(True)
         elif hparams["hidden_act"] == "gelu":
             self.gguf_writer.add_vision_use_gelu(True)
         else:
             raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}")
+        # downsample_ratio
+        downsample_ratio = self.global_config.get("downsample_ratio")
+        assert downsample_ratio is not None
+        self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
 
     def tensor_force_quant(self, name, new_name, bid, n_dims):
         del bid, name, n_dims # unused
@@ -2747,7 +2752,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         # correct name
         if name.startswith("vision_model"):
             name = "vision_tower." + name
-        if ".ls" in name and not name.endswith(".weight"):
+        if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"):
             name += ".weight"
         # split QKV tensors if needed
         if ".qkv." in name:

tools/mtmd/clip-impl.h

Lines changed: 6 additions & 5 deletions
@@ -33,9 +33,6 @@
 #define KEY_PROJ_TYPE "clip.projector_type"
 #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
 
-#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl
-#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl
-
 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
 #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
@@ -60,8 +57,10 @@
 #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
 #define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
 #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
-#define TN_LN_1 "%s.blk.%d.ln1.%s"
-#define TN_LN_2 "%s.blk.%d.ln2.%s"
+#define TN_LN_1 "%s.blk.%d.ln1.%s" // layer norm
+#define TN_LN_2 "%s.blk.%d.ln2.%s" // layer norm
+#define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale
+#define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale
 #define TN_LN_PRE "%s.pre_ln.%s"
 #define TN_LN_POST "%s.post_ln.%s"
 #define TN_LLAVA_PROJ "mm.%d.%s"
@@ -105,6 +104,7 @@ enum projector_type {
     PROJECTOR_TYPE_IDEFICS3,
     PROJECTOR_TYPE_PIXTRAL,
     PROJECTOR_TYPE_QWEN25VL,
+    PROJECTOR_TYPE_INTERNVL,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -119,6 +119,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_GEMMA3,   "gemma3"},
     { PROJECTOR_TYPE_IDEFICS3, "idefics3"},
     { PROJECTOR_TYPE_PIXTRAL,  "pixtral"},
+    { PROJECTOR_TYPE_INTERNVL, "internvl"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
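The new TN_LS_1 / TN_LS_2 templates follow the same printf-style naming scheme as the other per-layer tensors; clip.cpp expands them with the "v" (vision) prefix and the block index via string_format. A quick sketch of the resulting GGUF tensor names (layer indices are illustrative):

    # Sketch: expanding the layer-scale name templates the way
    # string_format(TN_LS_1, "v", il, "weight") does in clip.cpp.
    TN_LS_1 = "%s.blk.%d.ls1.%s"
    TN_LS_2 = "%s.blk.%d.ls2.%s"
    for il in (0, 1):
        print(TN_LS_1 % ("v", il, "weight"))  # v.blk.0.ls1.weight, v.blk.1.ls1.weight
        print(TN_LS_2 % ("v", il, "weight"))  # v.blk.0.ls2.weight, v.blk.1.ls2.weight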

tools/mtmd/clip.cpp

Lines changed: 113 additions & 12 deletions
@@ -215,6 +215,10 @@ struct clip_layer {
     // layernorm 2
     ggml_tensor * ln_2_w = nullptr;
     ggml_tensor * ln_2_b = nullptr;
+
+    // layer scale (no bias)
+    ggml_tensor * ls_1_w = nullptr;
+    ggml_tensor * ls_2_w = nullptr;
 };
 
 struct clip_vision_model {
@@ -589,6 +593,9 @@ struct clip_graph {
 
     // Qwen2VL and Qwen2.5VL use M-RoPE
     ggml_cgraph * build_qwen2vl() {
+        GGML_ASSERT(model.patch_bias == nullptr);
+        GGML_ASSERT(model.class_embedding == nullptr);
+
         const int batch_size = 1;
         const bool use_window_attn = hparams.n_wa_pattern > 0;
         const int n_wa_pattern = hparams.n_wa_pattern;
@@ -625,10 +632,6 @@
                 n_embd, n_patches_x * n_patches_y, batch_size);
         }
 
-        if (model.patch_bias) {
-            inp = ggml_add(ctx0, inp, model.patch_bias);
-        }
-
         ggml_tensor * inpL = inp;
         ggml_tensor * window_mask = nullptr;
         ggml_tensor * window_idx = nullptr;
@@ -859,6 +862,65 @@
         return gf;
     }
 
+    ggml_cgraph * build_internvl() {
+        GGML_ASSERT(model.class_embedding != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+
+        const int n_pos = n_patches + 1;
+        ggml_tensor * inp = build_inp();
+
+        // add CLS token
+        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+
+        ggml_tensor * cur = build_vit(
+                                inp, n_pos,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                nullptr);
+
+        // remove CLS token
+        cur = ggml_view_2d(ctx0, cur,
+            n_embd, n_patches,
+            ggml_row_size(cur->type, n_embd), 0);
+
+        // pixel shuffle
+        {
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int bsz = 1; // batch size, always 1 for now since we don't support batching
+            const int height = n_patches_y;
+            const int width = n_patches_x;
+            GGML_ASSERT(scale_factor > 0);
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                height / scale_factor,
+                width / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            // flatten to 2D
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                cur->ne[1] * cur->ne[2]);
+        }
+
+        // projector (always using GELU activation)
+        {
+            cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_3_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_3_b);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // this graph is used by llava, granite and glm
     // due to having embedding_stack (used by granite), we cannot reuse build_vit
     ggml_cgraph * build_llava() {
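For reference, the pixel-shuffle block in build_internvl() above merges each scale_factor x scale_factor group of neighbouring patch embeddings into a single, wider token. A NumPy sketch of the same reduction (ggml performs it with reshape/permute in its own reversed dimension order; the sizes here are toy values):

    # Sketch of the pixel-shuffle reduction, not the actual ggml graph.
    import numpy as np

    n_embd, h, w, s = 8, 4, 4, 2                        # toy sizes; s = proj_scale_factor
    x = np.arange(h * w * n_embd, dtype=np.float32).reshape(h, w, n_embd)

    y = x.reshape(h // s, s, w // s, s, n_embd)         # split the patch grid into s x s blocks
    y = y.transpose(0, 2, 1, 3, 4)                      # group each block's patches together
    y = y.reshape((h // s) * (w // s), s * s * n_embd)  # one token per block, s*s times wider

    print(x.shape, "->", y.shape)                       # (4, 4, 8) -> (4, 32): 16 patches become 4 tokens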
@@ -1260,11 +1322,6 @@
             ggml_tensor * learned_pos_embd,
             std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos
         ) {
-        if (model.patch_bias) {
-            inp = ggml_add(ctx0, inp, model.patch_bias);
-            cb(inp, "patch_bias", -1);
-        }
-
         if (learned_pos_embd) {
             inp = ggml_add(ctx0, inp, learned_pos_embd);
             cb(inp, "pos_embed", -1);
@@ -1324,6 +1381,11 @@
                 cb(cur, "attn_out", il);
             }
 
+            if (layer.ls_1_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_1_w);
+                cb(cur, "attn_out_scaled", il);
+            }
+
             // re-add the layer input, e.g., residual
             cur = ggml_add(ctx0, cur, inpL);
 
@@ -1344,6 +1406,11 @@
 
             cb(cur, "ffn_out", il);
 
+            if (layer.ls_2_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_2_w);
+                cb(cur, "ffn_out_scaled", il);
+            }
+
             // residual 2
             cur = ggml_add(ctx0, inpL, cur);
             cb(cur, "layer_out", il);
@@ -1365,6 +1432,10 @@
         ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
         inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
         inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+        if (model.patch_bias) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+            cb(inp, "patch_bias", -1);
+        }
         return inp;
     }
 
@@ -1627,6 +1698,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = graph.build_minicpmv();
             } break;
+        case PROJECTOR_TYPE_INTERNVL:
+            {
+                res = graph.build_internvl();
+            } break;
         default:
             {
                 res = graph.build_llava();
@@ -1790,6 +1865,7 @@ struct clip_model_loader {
                     }
                 } break;
             case PROJECTOR_TYPE_IDEFICS3:
+            case PROJECTOR_TYPE_INTERNVL:
                 {
                     get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
                 } break;
@@ -1897,14 +1973,17 @@
             layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
             layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
             layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
+            layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias
+            layer.ls_2_w = get_tensor(string_format(TN_LS_2, "v", il, "weight"), false); // no bias
+
             layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
             layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
             layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
             layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
             layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
             layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
 
-            // new naming
+            // ffn
             layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
             layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
             layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
@@ -2052,6 +2131,15 @@
                     vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
                     vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
+            case PROJECTOR_TYPE_INTERNVL:
+                {
+                    vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+                    vision_model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
+                    vision_model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    vision_model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
+                    vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+                    vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
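The mm.0 / mm.1 / mm.3 indices loaded above appear to mirror InternVL's HF-side projector, a Sequential of LayerNorm, Linear, GELU, Linear, where index 2 (the GELU) carries no parameters and therefore produces no tensor. A hedged sketch of that assumed layout; the dimensions are placeholders, not values read from the model:

    # Sketch of the assumed HF projector layout (placeholder dims, illustrative only).
    import torch.nn as nn

    vit_embd, llm_embd, s = 1024, 896, 2
    mlp1 = nn.Sequential(
        nn.LayerNorm(vit_embd * s * s),         # -> mm.0.weight / mm.0.bias
        nn.Linear(vit_embd * s * s, llm_embd),  # -> mm.1.weight / mm.1.bias
        nn.GELU(),                              # index 2: no parameters, hence no mm.2 tensor
        nn.Linear(llm_embd, llm_embd),          # -> mm.3.weight / mm.3.bias
    )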
@@ -2838,7 +2926,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
     }
     else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
             || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
-            || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+            || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3
+            || ctx->proj_type == PROJECTOR_TYPE_INTERNVL) {
         clip_image_u8 resized_image;
         int sz = params.image_size;
         image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
@@ -3013,7 +3102,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         int n_per_side = params.image_size / params.patch_size;
        int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
         n_patches = n_per_side_2d_pool * n_per_side_2d_pool;
-    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_INTERNVL) {
+        // both W and H are divided by proj_scale_factor
         n_patches /= (params.proj_scale_factor * params.proj_scale_factor);
     } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
         int n_merge = params.spatial_merge_size;
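A worked example of the token count produced by the new IDEFICS3/INTERNVL branch, assuming typical InternVL vision settings (448-pixel input, 14-pixel patches, scale factor 2; these values are assumptions, not taken from the diff):

    # Sketch of clip_n_output_tokens for INTERNVL with assumed parameters.
    image_size, patch_size, proj_scale_factor = 448, 14, 2
    n_patches = (image_size // patch_size) ** 2          # 32 * 32 = 1024 patches
    n_patches //= proj_scale_factor * proj_scale_factor  # both W and H shrink -> 256 tokens
    print(n_patches)                                     # 256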
@@ -3408,6 +3498,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             } break;
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_INTERNVL:
             {
                 // do nothing
             } break;
@@ -3434,6 +3525,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     // the last node is the embedding tensor
     ggml_tensor * embeddings = ggml_graph_node(gf, -1);
 
+    // sanity check (only support batch size of 1 for now)
+    const int n_tokens_out = embeddings->ne[1];
+    const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get());
+    if (n_tokens_out != expected_n_tokens_out) {
+        LOG_ERR("%s: expected %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
+        GGML_ABORT("Invalid number of output tokens");
+    }
+
     // copy the embeddings to the location passed by the user
     ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
 
@@ -3604,6 +3703,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->vision_model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
             return ctx->vision_model.projection->ne[1];
+        case PROJECTOR_TYPE_INTERNVL:
+            return ctx->vision_model.mm_3_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }

tools/mtmd/mtmd.cpp

Lines changed: 7 additions & 0 deletions
@@ -252,6 +252,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
 
     }
 
+    else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
+        // <img> ... (image embeddings) ... </img>
+        marker_modified = "<img>" + ctx->image_marker + "</img>";
+        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+
+    }
+
     // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
     // for glm-edge, BOI and EOI token's embeddings are not present in the text model
 
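The InternVL branch above only rewrites the prompt text: the image marker is wrapped in <img> ... </img> so the image embeddings end up between the two tags after tokenization. A sketch of the string rewrite (the marker literal here is a placeholder, not necessarily mtmd's default):

    # Sketch of the prompt rewrite; marker string is illustrative.
    image_marker = "<__image__>"
    prompt = "Describe this image: <__image__>"
    marker_modified = "<img>" + image_marker + "</img>"
    prompt_modified = prompt.replace(image_marker, marker_modified)
    print(prompt_modified)   # Describe this image: <img><__image__></img>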
