
Commit 108c585

kimi-vl: support arbitrary image size

1 parent c3597d4
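
This commit lets Kimi-VL accept images of arbitrary size instead of a fixed 896×896 input: the learnable 2D position embedding is bicubically interpolated to the actual patch grid, the 2D RoPE position indices are generated for any grid shape, and patch merging moves from VisionTransformer into MultiModalProjector so the grid dimensions can be set per forward pass.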

File tree: 4 files changed, +440 -51 lines changed

models/kimi.cpp

Lines changed: 40 additions & 30 deletions
```diff
@@ -23,7 +23,7 @@ namespace vit
     public:
         Learnable2DInterpPosEmb() {}
         Learnable2DInterpPosEmb(InitContext *ctx, int height, int width, int dim)
-            : pos_emb(ggml::new_tensor_3d(ctx, ggml::type::GGML_TYPE_F32, dim, width, height))
+            : pos_emb(ggml::new_tensor_3d(ctx, ggml::type::GGML_TYPE_F32, dim, height, width))
         {
         }
```
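
Note the fix in the constructor: ggml lists tensor extents fastest-varying first, so an embedding laid out as [dim, h, w] (see the comment in the next hunk) must be allocated as new_tensor_3d(ctx, ..., dim, height, width); the old call had height and width transposed.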

```diff
@@ -36,7 +36,22 @@ namespace vit
         {
             CHATLLM_CHECK(ggml::get_dim(input, 3) == 1);

-            auto output = ggml::add(ctx, input, pos_emb);
+            // input: ggml [dim, h, w]
+
+            ggml::tensor *interp = nullptr;
+            if ((ggml::get_dim(input, 1) == ggml::get_dim(pos_emb, 1)) && (ggml::get_dim(input, 2) == ggml::get_dim(pos_emb, 2)))
+            {
+                interp = pos_emb;
+            }
+            else
+            {
+                auto permuted = ggml::permute(ctx, pos_emb, 2, 0, 1, 3);
+                interp = ggml::interpolate(ctx, permuted, ggml::InterpolateMode::Bicubic,
+                    ggml::get_dim(input, 1), ggml::get_dim(input, 2), ggml::get_dim(permuted, 2), ggml::get_dim(permuted, 3));
+                interp = ggml::permute(ctx, interp, 1, 2, 0, 3);
+            }
+
+            auto output = ggml::add(ctx, input, interp);

             return output;
         }
```
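
Read in isolation, the interpolation branch amounts to the following helper. This is a minimal sketch for clarity, using only the ggml wrappers that appear in the hunk; the function name resize_pos_emb is hypothetical.

```cpp
// Minimal sketch (hypothetical helper name; ggml wrappers as used above).
// ggml orders extents fastest-first, so pos_emb is [dim, h, w]. The bicubic
// resize must act on the spatial axes, hence the permute before and after.
static ggml::tensor *resize_pos_emb(ComputeContext *ctx, ggml::tensor *pos_emb,
                                    int64_t new_h, int64_t new_w)
{
    // [dim, h, w] -> [h, w, dim]: put the spatial extents in dims 0 and 1
    auto permuted = ggml::permute(ctx, pos_emb, 2, 0, 1, 3);
    // resize the h x w grid; channel and batch extents stay unchanged
    auto interp = ggml::interpolate(ctx, permuted, ggml::InterpolateMode::Bicubic,
                                    new_h, new_w,
                                    ggml::get_dim(permuted, 2), ggml::get_dim(permuted, 3));
    // [new_h, new_w, dim] -> [dim, new_h, new_w]
    return ggml::permute(ctx, interp, 1, 2, 0, 3);
}
```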
```diff
@@ -106,28 +121,26 @@ namespace vit
         void before_forward(ComputeContext *ctx, const int n_past, const int qlen) override
         {
             const int len = grid_h * grid_w;
+            std::vector<int> v_pos_h;

             CHATLLM_CHECK(len <= max_length);

             ggml::set_dim(pos, 0, len);
             ggml::set_dim(pos_h, 0, len);

+            v_pos_h.resize(len);

-            for (int i = 0; i < grid_h; i++)
+            for (int i = 0; i < grid_w; i++)
             {
-                for (int j = 0; j < grid_w; j++)
-                    v_pos[i * grid_w + j] = j;
-            }
-
-            Backend::write_tensor_data(pos, v_pos.data(), 0, len * sizeof(v_pos[0]));
-
-            for (int i = 0; i < grid_h; i++)
-            {
-                for (int j = 0; j < grid_w; j++)
-                    v_pos[i * grid_w + j] = i;
+                for (int j = 0; j < grid_h; j++)
+                {
+                    v_pos  [i * grid_h + j] = j;
+                    v_pos_h[i * grid_h + j] = i;
+                }
             }

-            Backend::write_tensor_data(pos_h, v_pos.data(), 0, len * sizeof(v_pos[0]));
+            Backend::write_tensor_data(pos, v_pos.data(), 0, len * sizeof(v_pos[0]));
+            Backend::write_tensor_data(pos_h, v_pos_h.data(), 0, len * sizeof(v_pos[0]));
         }

         ggml::tensor *apply_2d_rope(ComputeContext *ctx, ggml::tensor *hidden, int hidden_size, ggml::tensor *pos_w, ggml::tensor *pos_h) const
```
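
The rewritten loop fills both coordinate arrays in a single pass, traversing the grid with grid_w outer and grid_h inner (apparently to match the [dim, h, w] embedding layout). A tiny standalone program shows the resulting index pattern:

```cpp
#include <cstdio>
#include <vector>

// Reproduces the index layout of before_forward() above for a 2x3 grid.
int main()
{
    const int grid_h = 2, grid_w = 3, len = grid_h * grid_w;
    std::vector<int> v_pos(len), v_pos_h(len);
    for (int i = 0; i < grid_w; i++)
        for (int j = 0; j < grid_h; j++)
        {
            v_pos  [i * grid_h + j] = j;  // index along grid_h
            v_pos_h[i * grid_h + j] = i;  // index along grid_w
        }
    for (int k = 0; k < len; k++)
        printf("token %d: pos=%d pos_h=%d\n", k, v_pos[k], v_pos_h[k]);
    // pos   = 0 1 0 1 0 1
    // pos_h = 0 0 1 1 2 2
    return 0;
}
```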
```diff
@@ -185,11 +198,13 @@ namespace vit
             linear_1(ctx, hidden_size, hidden_size),
             linear_2(ctx, hidden_size, lm_hidden_size)
         {
+            memcpy(merge_param.merge_kernel_size, config.merge_kernel_size, sizeof(merge_param.merge_kernel_size));
         }

         ggml::tensor *forward(ComputeContext *ctx, ggml::tensor *image_features) override
         {
-            auto output = pre_norm.forward(ctx, image_features);
+            auto output = merge_patch(ctx, image_features);
+            output = pre_norm.forward(ctx, output);
             output = ggml::reshape_2d(ctx, output,
                 hidden_size, ggml::get_dim(output, 1) / (hidden_size / ggml::get_dim(output, 0)) * ggml::get_dim(output, 2) * ggml::get_dim(output, 3));
             output = linear_1.forward(ctx, output);
```
```diff
@@ -214,11 +229,18 @@ namespace vit
             linear_2.load(path + "linear_2.", loader);
         }

+    protected:
+        ggml::tensor *merge_patch(ComputeContext *ctx, ggml::tensor *x)
+        {
+            auto reshaped_seq = ggml::merge_patch(ctx, x, &merge_param);
+            return reshaped_seq;
+        }
     public:
         const int hidden_size;
         LayerNorm pre_norm;
         Linear linear_1;
         Linear linear_2;
+        ggml::merge_patch_param merge_param;
     };

     class VisionTransformer : public Block
```
```diff
@@ -240,7 +262,6 @@ namespace vit
                 layer->set_id(layer_id);
                 layers.emplace_back(layer);
             }
-            memcpy(merge_param.merge_kernel_size, config.merge_kernel_size, sizeof(merge_param.merge_kernel_size));
         }

         int64_t get_param_num(bool effective_only) const override
@@ -269,17 +290,11 @@ namespace vit
             loaded = true;
         }

-        ggml::tensor *merge_patch(ComputeContext *ctx, ggml::tensor *x, int grid_h, int grid_w)
-        {
-            merge_param.grid_h = grid_h;
-            merge_param.grid_w = grid_w;
-
-            auto reshaped_seq = ggml::merge_patch(ctx, x, &merge_param);
-            return reshaped_seq;
-        }
-
         ggml::tensor *forward(ComputeContext *ctx, ggml::tensor *input, int grid_h, int grid_w)
         {
+            multi_modal_projector.merge_param.grid_h = grid_h;
+            multi_modal_projector.merge_param.grid_w = grid_w;
+
             auto output = embeddings.forward(ctx, input, grid_h, grid_w);

             for (size_t i = 0; i < layers.size(); i++)
@@ -289,7 +304,6 @@ namespace vit
                 output = layers[i]->forward(ctx, output, 0);
             }
             output = post_layernorm.forward(ctx, output);
-            output = merge_patch(ctx, output, grid_h, grid_w);
             output = multi_modal_projector.forward(ctx, output);
             return output;
         }
@@ -305,7 +319,6 @@ namespace vit
         MultiModalProjector multi_modal_projector;
     protected:
         bool loaded;
-        ggml::merge_patch_param merge_param;
     };

     class VisualEmbeddingGeneration
```
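
Taken together, these hunks move patch merging out of VisionTransformer and into MultiModalProjector, which now owns merge_param; the transformer only records the per-image grid size before running the projector. A condensed sketch of the resulting forward path (all names from the diff above):

```cpp
// VisionTransformer::forward after this commit, condensed for readability.
ggml::tensor *forward(ComputeContext *ctx, ggml::tensor *input, int grid_h, int grid_w)
{
    multi_modal_projector.merge_param.grid_h = grid_h;  // grid is dynamic now,
    multi_modal_projector.merge_param.grid_w = grid_w;  // set per forward pass

    auto output = embeddings.forward(ctx, input, grid_h, grid_w);
    for (size_t i = 0; i < layers.size(); i++)
        output = layers[i]->forward(ctx, output, 0);
    output = post_layernorm.forward(ctx, output);
    // merge_patch() now runs inside the projector, ahead of pre_norm
    return multi_modal_projector.forward(ctx, output);
}
```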
```diff
@@ -703,9 +716,6 @@ namespace vl
             vision::MaxGridHeight param3(512);
             vision::MaxGridWidth param4(512);

-            // TODO: cubic interpolation not ready yet. image size fixed.
-            vision::Resize resize(896, 896);
-
             vision::image_load(piece.content.c_str(), pixels, w, h, patch_size, vision::PaddingMode::Black);

             std::vector<float> scaled;
```
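
With the position embedding interpolated to the actual grid, the hard-coded 896×896 resize (kept only while cubic interpolation was, per the removed TODO, "not ready yet") can go: images are now loaded at their native size, padded to a patch multiple, and bounded only by the 512×512 grid limits above.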

src/backend.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -18,6 +18,9 @@ namespace chatllm
                      ggml::type type,
                      int n_dims,
                      const int64_t *ne);
+    tensor *init_tensor(ggml::tensor *tensor,
+                        ggml::type type,
+                        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3);
     void change_type(ggml::tensor *tensor, ggml::type type);

     size_t element_size(const ggml::tensor *tensor);
```
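
The new overload mirrors the existing array-based init_tensor but takes the four extents directly, which is convenient for call sites that know their dimensionality statically. A hypothetical usage, assuming it behaves like the array variant:

```cpp
// Hypothetical usage; equivalent to int64_t ne[4] = {ne0, ne1, ne2, ne3}
// followed by the array-based init_tensor. The extent values are made up.
ggml::tensor t;
init_tensor(&t, ggml::type::GGML_TYPE_F32, /*ne0*/ 1024, /*ne1*/ 32, /*ne2*/ 32, /*ne3*/ 1);
```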
