Commit 56cf4ca

support kimi vision model

1 parent a7dcaa9 commit 56cf4ca

3 files changed, +127 −5 lines changed

convert_hf_to_gguf.py
Lines changed: 7 additions & 2 deletions
@@ -8498,6 +8498,10 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIVL)
         self.gguf_writer.add_vision_use_gelu(True)
+        self.gguf_writer.add_vision_projector_scale_factor(2)
+        # eps is the same as pytorch's default value
+        assert self.hparams_vision is not None
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5))
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
@@ -8506,8 +8510,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if is_vision_tensor:
             if "pos_emb.weight" in name:
                 data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2])
-            elif "wqkv.weight" in name or "wqkv.bias" in name:
-                wq, wk, wv = data_torch.chunk(3, dim=-1)
+            elif "wqkv" in name:
+                split_dim = 0 if "weight" in name else -1
+                wq, wk, wv = data_torch.chunk(3, dim=split_dim)
                 return [
                     (self.map_tensor_name(name.replace("wqkv", "wq")), wq),
                     (self.map_tensor_name(name.replace("wqkv", "wk")), wk),
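
The reworked wqkv branch splits along a rank-dependent axis. A minimal PyTorch sketch (toy shapes, not read from an actual checkpoint) of why dim 0 is right for the 2-D weight while -1 serves the 1-D bias:

```python
import torch

n_embd = 8
# nn.Linear stores weight as (out_features, in_features), so the fused
# Q/K/V projections are stacked along dim 0; the bias is 1-D.
wqkv_weight = torch.randn(3 * n_embd, n_embd)
wqkv_bias   = torch.randn(3 * n_embd)

wq, wk, wv = wqkv_weight.chunk(3, dim=0)   # split the stacked output rows
bq, bk, bv = wqkv_bias.chunk(3, dim=-1)    # 1-D tensor: -1 is its only dim

assert wq.shape == (n_embd, n_embd) and bq.shape == (n_embd,)
# chunking the weight along dim=-1 (the old code path) would slice the
# shared input dimension instead, yielding wrong projections
```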

tools/mtmd/clip-impl.h
Lines changed: 2 additions & 0 deletions
@@ -135,6 +135,7 @@ enum projector_type {
     PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
     PROJECTOR_TYPE_VOXTRAL,
     PROJECTOR_TYPE_LFM2,
+    PROJECTOR_TYPE_KIMIVL,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -156,6 +157,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
     { PROJECTOR_TYPE_VOXTRAL, "voxtral"},
     { PROJECTOR_TYPE_LFM2,    "lfm2"},
+    { PROJECTOR_TYPE_KIMIVL,  "kimivl"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {

tools/mtmd/clip.cpp
Lines changed: 118 additions & 3 deletions
@@ -1086,7 +1086,7 @@ struct clip_graph {
                 n_patches_x / scale_factor,
                 n_patches_y / scale_factor,
                 bsz);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            //cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
             // flatten to 2D
             cur = ggml_cont_2d(ctx0, cur,
                 n_embd * scale_factor * scale_factor,
@@ -1113,6 +1113,92 @@
         return gf;
     }
 
+    ggml_cgraph * build_kimivl() {
+        // 2D input positions
+        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_h, "pos_h");
+        ggml_set_input(pos_h);
+
+        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_w, "pos_w");
+        ggml_set_input(pos_w);
+
+        ggml_tensor * learned_pos_embd = resize_position_embeddings();
+
+        // build ViT with 2D position embeddings
+        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+            // first half is X axis and second half is Y axis
+            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+        };
+
+        ggml_tensor * inp = build_inp();
+        ggml_tensor * cur = build_vit(
+                                inp, n_patches,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                learned_pos_embd,
+                                add_pos);
+
+        cb(cur, "vit_out", -1);
+
+        {
+            // pixel unshuffle block
+            const int scale_factor = model.hparams.proj_scale_factor;
+            GGML_ASSERT(scale_factor > 1);
+
+            const int n_embd = cur->ne[0];
+            int width  = img.nx / patch_size;
+            int height = img.ny / patch_size;
+
+            // pad width and height to factor
+            const int64_t pad_width  = CLIP_ALIGN(width,  scale_factor) - width;
+            const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
+            cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
+            if (pad_width || pad_height) {
+                cur     = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
+                width  += pad_width;
+                height += pad_height;
+            }
+
+            // unshuffle h
+            cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+
+            // unshuffle w
+            cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+
+            cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+            cb(cur, "pixel_unshuffle", -1);
+
+            // projection norm
+            int proj_inp_dim = cur->ne[0];
+            cur = ggml_view_2d(ctx0, cur,
+                n_embd, cur->ne[1] * scale_factor * scale_factor,
+                ggml_row_size(cur->type, n_embd), 0);
+            cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
+            cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
+            cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
+            cur = ggml_view_2d(ctx0, cur,
+                proj_inp_dim, cur->ne[1] / scale_factor / scale_factor,
+                ggml_row_size(cur->type, proj_inp_dim), 0);
+            cb(cur, "proj_inp_normed", -1);
+
+            // projection mlp
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_2_b);
+            cb(cur, "proj_out", -1);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // this graph is used by llava, granite and glm
     // due to having embedding_stack (used by granite), we cannot reuse build_vit
     ggml_cgraph * build_llava() {
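
The pixel-unshuffle block in build_kimivl folds each scale_factor x scale_factor neighbourhood of patch embeddings into the channel dimension, cutting the token count by the square of scale_factor before the projection MLP. A standalone PyTorch sketch of the same space-to-depth transform (toy sizes; the ggml code above works on transposed layouts):

```python
import torch

s = 2                                   # scale_factor
h, w, c = 4, 6, 8                       # patch grid height/width, embedding dim

x = torch.randn(h, w, c)                # one embedding per patch
x = x.reshape(h // s, s, w // s, s, c)  # carve the grid into s x s blocks
x = x.permute(0, 2, 1, 3, 4)            # gather the s*s patches of each block
x = x.reshape((h // s) * (w // s), s * s * c)  # one token per block

assert x.shape == (6, 32)               # (4/2)*(6/2) tokens, 2*2*8 channels
```
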
@@ -2063,6 +2149,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = graph.build_whisper_enc();
             } break;
+        case PROJECTOR_TYPE_KIMIVL:
+            {
+                res = graph.build_kimivl();
+            } break;
         default:
             {
                 res = graph.build_llava();
@@ -2311,6 +2401,12 @@
                     hparams.image_size = 1024;
                     get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
                 } break;
+            case PROJECTOR_TYPE_KIMIVL:
+                {
+                    hparams.rope_theta = 10000.0f;
+                    hparams.warmup_image_size = hparams.patch_size * 8;
+                    get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                } break;
             case PROJECTOR_TYPE_GEMMA3:
                 {
                     // default value (used by all model sizes in gemma 3 family)
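
The loader pins rope_theta to 10000, which build_kimivl passes to build_rope_2d with the comment "first half is X axis and second half is Y axis". A reference sketch of that split, assuming standard interleaved-pair RoPE (the actual ggml layout may differ):

```python
import torch

def rope_1d(x: torch.Tensor, pos: torch.Tensor, theta: float) -> torch.Tensor:
    # standard RoPE over the last dim, rotating pairs (x[2i], x[2i+1])
    d = x.shape[-1]
    freqs = theta ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
    ang = pos[:, None].float() * freqs[None, :]        # (n_tokens, d/2)
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_2d(x, pos_w, pos_h, theta=10000.0):
    # first half of the head dim rotates with the patch's x position,
    # the second half with its y position
    d = x.shape[-1] // 2
    return torch.cat([rope_1d(x[..., :d], pos_w, theta),
                      rope_1d(x[..., d:], pos_h, theta)], dim=-1)

x = torch.randn(12, 16)                 # 12 patch tokens, head_dim 16
pos_w = torch.arange(12) % 4            # toy 4-wide patch grid
pos_h = torch.arange(12) // 4
print(rope_2d(x, pos_w, pos_h).shape)   # torch.Size([12, 16])
```
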
@@ -2475,7 +2571,17 @@
 
             // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
             // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
-            if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
+            bool is_ffn_swapped = (
+                    model.proj_type == PROJECTOR_TYPE_MLP
+                    || model.proj_type == PROJECTOR_TYPE_MLP_NORM
+                    || model.proj_type == PROJECTOR_TYPE_LDP
+                    || model.proj_type == PROJECTOR_TYPE_LDPV2
+                    || model.proj_type == PROJECTOR_TYPE_QWEN2VL
+                    || model.proj_type == PROJECTOR_TYPE_QWEN25VL
+                    || model.proj_type == PROJECTOR_TYPE_GLM_EDGE
+                    || model.proj_type == PROJECTOR_TYPE_GEMMA3
+                ) && layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd;
+            if (is_ffn_swapped) {
                 // swap up and down weights
                 ggml_tensor * tmp = layer.ff_up_w;
                 layer.ff_up_w = layer.ff_down_w;
 
@@ -2484,6 +2590,9 @@
                 tmp = layer.ff_up_b;
                 layer.ff_up_b = layer.ff_down_b;
                 layer.ff_down_b = tmp;
+                if (il == 0) {
+                    LOG_WRN("%s: ffn up/down are swapped\n", __func__);
+                }
             }
         }
 
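The swap is now gated to the listed projector types and logs a one-time warning. The heuristic itself is unchanged and shape-based: in ggml, ne[0] is a weight's input dimension, and a genuine down-projection maps n_ff to n_embd, so a "down" tensor whose input dimension equals n_embd must really be the up-projection. A toy illustration (hypothetical shapes):

```python
n_embd, n_ff = 8, 32

# legacy export: the two tensors were written under each other's names;
# shapes follow the (out_features, in_features) convention
ff_up_w_shape   = (n_embd, n_ff)    # actually the down-projection
ff_down_w_shape = (n_ff, n_embd)    # actually the up-projection

def ne0(shape):
    # ggml's ne[0] is the innermost axis, i.e. the input dimension
    return shape[-1]

if ne0(ff_down_w_shape) == n_embd:  # a true down proj would have ne[0] == n_ff
    ff_up_w_shape, ff_down_w_shape = ff_down_w_shape, ff_up_w_shape
    print("ffn up/down are swapped")

assert ne0(ff_down_w_shape) == n_ff
```
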
@@ -2602,6 +2711,7 @@
                     model.projection = get_tensor(TN_MM_PROJECTOR);
                 } break;
             case PROJECTOR_TYPE_LFM2:
+            case PROJECTOR_TYPE_KIMIVL:
                 {
                     model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
                     model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@@ -3481,7 +3591,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
 
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL
+            || ctx->proj_type() == PROJECTOR_TYPE_KIMIVL
+            ) {
         clip_image_u8 resized_image;
         auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
         image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
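
KIMIVL now takes Pixtral's preprocessing path: a single aspect-ratio-preserving bilinear resize. calc_size_preserved_ratio itself is not part of this diff; a hedged sketch of what such a helper plausibly does (fit the longest side within image_size, round each side up to a patch_size multiple), not the actual implementation:

```python
import math

def calc_size_preserved_ratio(w: int, h: int, patch_size: int, max_size: int):
    # hypothetical stand-in for image_manipulation::calc_size_preserved_ratio
    scale = min(1.0, max_size / max(w, h))       # only ever downscale
    new_w = max(patch_size, math.ceil(w * scale / patch_size) * patch_size)
    new_h = max(patch_size, math.ceil(h * scale / patch_size) * patch_size)
    return new_w, new_h

print(calc_size_preserved_ratio(1920, 1080, 14, 1024))  # (1036, 588)
```
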
@@ -3704,6 +3816,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_INTERNVL:
         case PROJECTOR_TYPE_LLAMA4:
         case PROJECTOR_TYPE_LFM2:
+        case PROJECTOR_TYPE_KIMIVL:
             {
                 // both W and H are divided by proj_scale_factor
                 int scale_factor = ctx->model.hparams.proj_scale_factor;
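
Per the comment, both grid dimensions shrink by proj_scale_factor, so the token count drops by its square. Toy arithmetic with illustrative numbers (not taken from the Kimi-VL config):

```python
patch_size   = 14
scale_factor = 2
img_w, img_h = 448, 448

n_patches_x = img_w // patch_size        # 32
n_patches_y = img_h // patch_size        # 32
n_tokens = (n_patches_x // scale_factor) * (n_patches_y // scale_factor)
print(n_tokens)                          # 16 * 16 = 256
```
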
@@ -4091,6 +4204,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 set_input_i32("positions", positions);
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_KIMIVL:
             {
                 // set the 2D positions
                 int n_patches_per_col = image_size_width / patch_size;
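
Here the host fills the pos_w/pos_h graph inputs declared in build_kimivl. A minimal sketch, assuming row-major patch order (patch i sits at column i % n_cols, row i // n_cols):

```python
n_cols, n_rows = 4, 3   # toy patch grid

pos_w = [i % n_cols  for i in range(n_cols * n_rows)]  # x coordinate per patch
pos_h = [i // n_cols for i in range(n_cols * n_rows)]  # y coordinate per patch

print(pos_w)  # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
print(pos_h)  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
```
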
@@ -4245,6 +4359,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_QWEN2A:
             return ctx->model.mm_fc_w->ne[1];
         case PROJECTOR_TYPE_LFM2:
+        case PROJECTOR_TYPE_KIMIVL:
             return ctx->model.mm_2_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");