@@ -23,7 +23,7 @@ namespace vit
     public:
         Learnable2DInterpPosEmb() {}
         Learnable2DInterpPosEmb(InitContext *ctx, int height, int width, int dim)
-            : pos_emb(ggml::new_tensor_3d(ctx, ggml::type::GGML_TYPE_F32, dim, width, height))
+            : pos_emb(ggml::new_tensor_3d(ctx, ggml::type::GGML_TYPE_F32, dim, height, width))
         {
         }
 
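Note on the shape swap above: ggml stores the fastest-varying dimension first, so the new pos_emb shape {dim, height, width} makes height the faster spatial axis, matching the ggml[dim, h, w] layout the forward pass below expects. A minimal sketch of the contiguous element offset (a hypothetical helper, not part of the diff):

// Hypothetical helper, not from the repo: element offset (in floats) of
// pos_emb(d, y, x) for a contiguous ggml tensor with ne = {dim, height, width}.
// ne0 (dim) varies fastest, then height, then width.
#include <cstddef>

inline size_t pos_emb_offset(int d, int y, int x, int dim, int height)
{
    return (size_t)d + (size_t)y * dim + (size_t)x * (size_t)dim * height;
}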
@@ -36,7 +36,22 @@ namespace vit
         {
             CHATLLM_CHECK(ggml::get_dim(input, 3) == 1);
 
-            auto output = ggml::add(ctx, input, pos_emb);
+            // input: ggml[dim, h, w]
+
+            ggml::tensor *interp = nullptr;
+            if ((ggml::get_dim(input, 1) == ggml::get_dim(pos_emb, 1)) && (ggml::get_dim(input, 2) == ggml::get_dim(pos_emb, 2)))
+            {
+                interp = pos_emb;
+            }
+            else
+            {
+                auto permuted = ggml::permute(ctx, pos_emb, 2, 0, 1, 3);
+                interp = ggml::interpolate(ctx, permuted, ggml::InterpolateMode::Bicubic,
+                    ggml::get_dim(input, 1), ggml::get_dim(input, 2), ggml::get_dim(permuted, 2), ggml::get_dim(permuted, 3));
+                interp = ggml::permute(ctx, interp, 1, 2, 0, 3);
+            }
+
+            auto output = ggml::add(ctx, input, interp);
 
             return output;
         }
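The hunk above resizes the learned position-embedding grid at run time whenever the input grid differs from the trained one (permute to [h, w, dim], interpolate, permute back), which is what lets the fixed 896x896 Resize be dropped later in this commit. A standalone sketch of the idea, using bilinear interpolation instead of the bicubic mode the diff actually requests, and assuming a plain [H][W][D] float array rather than a ggml tensor:

// Illustrative sketch only (not the chatllm.cpp implementation): resize a
// learned position-embedding grid [H][W][D] to [h][w][D] with bilinear
// interpolation. The diff itself uses ggml::interpolate with Bicubic mode.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

static std::vector<float> resize_pos_emb(const std::vector<float> &src,
                                         int H, int W, int D, int h, int w)
{
    std::vector<float> dst((size_t)h * w * D);
    for (int y = 0; y < h; y++)
    {
        // map output row y back into the source grid (align-corners style)
        float fy = (h > 1) ? (float)y * (H - 1) / (h - 1) : 0.0f;
        int   y0 = (int)fy, y1 = std::min(y0 + 1, H - 1);
        float ty = fy - y0;
        for (int x = 0; x < w; x++)
        {
            float fx = (w > 1) ? (float)x * (W - 1) / (w - 1) : 0.0f;
            int   x0 = (int)fx, x1 = std::min(x0 + 1, W - 1);
            float tx = fx - x0;
            for (int d = 0; d < D; d++)
            {
                auto at = [&](int yy, int xx) { return src[((size_t)yy * W + xx) * D + d]; };
                float top = at(y0, x0) * (1 - tx) + at(y0, x1) * tx;
                float bot = at(y1, x0) * (1 - tx) + at(y1, x1) * tx;
                dst[((size_t)y * w + x) * D + d] = top * (1 - ty) + bot * ty;
            }
        }
    }
    return dst;
}

int main()
{
    // 2x2 grid of 1-D embeddings {0, 1, 2, 3} resized to 3x3.
    std::vector<float> src = {0, 1, 2, 3};
    auto dst = resize_pos_emb(src, 2, 2, 1, 3, 3);
    for (float v : dst)
        std::printf("%.2f ", v);   // 0.00 0.50 1.00 1.00 1.50 2.00 2.00 2.50 3.00
    std::printf("\n");
    return 0;
}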
@@ -106,28 +121,26 @@ namespace vit
         void before_forward(ComputeContext *ctx, const int n_past, const int qlen) override
         {
             const int len = grid_h * grid_w;
+            std::vector<int> v_pos_h;
 
             CHATLLM_CHECK(len <= max_length);
 
             ggml::set_dim(pos, 0, len);
             ggml::set_dim(pos_h, 0, len);
 
+            v_pos_h.resize(len);
 
-            for (int i = 0; i < grid_h; i++)
+            for (int i = 0; i < grid_w; i++)
             {
-                for (int j = 0; j < grid_w; j++)
-                    v_pos[i * grid_w + j] = j;
-            }
-
-            Backend::write_tensor_data(pos, v_pos.data(), 0, len * sizeof(v_pos[0]));
-
-            for (int i = 0; i < grid_h; i++)
-            {
-                for (int j = 0; j < grid_w; j++)
-                    v_pos[i * grid_w + j] = i;
+                for (int j = 0; j < grid_h; j++)
+                {
+                    v_pos  [i * grid_h + j] = j;
+                    v_pos_h[i * grid_h + j] = i;
+                }
             }
 
-            Backend::write_tensor_data(pos_h, v_pos.data(), 0, len * sizeof(v_pos[0]));
+            Backend::write_tensor_data(pos, v_pos.data(), 0, len * sizeof(v_pos[0]));
+            Backend::write_tensor_data(pos_h, v_pos_h.data(), 0, len * sizeof(v_pos[0]));
         }
 
         ggml::tensor *apply_2d_rope(ComputeContext *ctx, ggml::tensor *hidden, int hidden_size, ggml::tensor *pos_w, ggml::tensor *pos_h) const
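The rewritten before_forward() fills both index buffers in a single pass and flattens patches with the outer loop over grid_w and the inner loop over grid_h (index i * grid_h + j), replacing the old two-pass row-major fill. A tiny standalone check (the grid values are illustrative, not from the repo) of what ends up in pos and pos_h:

// Standalone illustration: per-patch indices produced by the updated loop
// for grid_w = 3, grid_h = 2, following the diff's naming.
#include <cstdio>
#include <vector>

int main()
{
    const int grid_w = 3, grid_h = 2, len = grid_w * grid_h;
    std::vector<int> v_pos(len), v_pos_h(len);

    for (int i = 0; i < grid_w; i++)
        for (int j = 0; j < grid_h; j++)
        {
            v_pos  [i * grid_h + j] = j;  // inner-axis index, written to `pos`
            v_pos_h[i * grid_h + j] = i;  // outer-axis index, written to `pos_h`
        }

    for (int k = 0; k < len; k++)
        std::printf("patch %d: pos = %d, pos_h = %d\n", k, v_pos[k], v_pos_h[k]);
    // patch 0: pos = 0, pos_h = 0
    // patch 1: pos = 1, pos_h = 0
    // patch 2: pos = 0, pos_h = 1
    // ...
    return 0;
}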
@@ -185,11 +198,13 @@ namespace vit
             linear_1(ctx, hidden_size, hidden_size),
             linear_2(ctx, hidden_size, lm_hidden_size)
         {
+            memcpy(merge_param.merge_kernel_size, config.merge_kernel_size, sizeof(merge_param.merge_kernel_size));
         }
 
         ggml::tensor *forward(ComputeContext *ctx, ggml::tensor *image_features) override
         {
-            auto output = pre_norm.forward(ctx, image_features);
+            auto output = merge_patch(ctx, image_features);
+            output = pre_norm.forward(ctx, output);
             output = ggml::reshape_2d(ctx, output,
                 hidden_size, ggml::get_dim(output, 1) / (hidden_size / ggml::get_dim(output, 0)) * ggml::get_dim(output, 2) * ggml::get_dim(output, 3));
             output = linear_1.forward(ctx, output);
@@ -214,11 +229,18 @@ namespace vit
             linear_2.load(path + "linear_2.", loader);
         }
 
+    protected:
+        ggml::tensor *merge_patch(ComputeContext *ctx, ggml::tensor *x)
+        {
+            auto reshaped_seq = ggml::merge_patch(ctx, x, &merge_param);
+            return reshaped_seq;
+        }
     public:
         const int hidden_size;
         LayerNorm pre_norm;
         Linear linear_1;
         Linear linear_2;
+        ggml::merge_patch_param merge_param;
     };
 
     class VisionTransformer : public Block
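merge_patch() now lives in MultiModalProjector and runs before pre_norm, with grid_h/grid_w pushed into merge_param by VisionTransformer::forward() further down. The diff does not show what ggml::merge_patch does internally; as a rough mental model only (an assumption, not the kernel's contract), a k x k merge typically concatenates the feature vectors of each k x k block of neighbouring patches, shrinking the sequence length by k*k while widening each token. A plain-array sketch of that index arithmetic:

// Hypothetical illustration only: one plausible 2x2 patch-merge layout
// (row-major patch grid, features of each 2x2 block concatenated).
// The real ggml::merge_patch kernel may order things differently.
#include <cstddef>
#include <vector>

std::vector<float> merge_2x2(const std::vector<float> &patches,
                             int grid_h, int grid_w, int dim)
{
    const int kh = 2, kw = 2;
    std::vector<float> merged((size_t)(grid_h / kh) * (grid_w / kw) * kh * kw * dim);
    size_t out = 0;
    for (int y = 0; y + kh <= grid_h; y += kh)
        for (int x = 0; x + kw <= grid_w; x += kw)
            for (int dy = 0; dy < kh; dy++)
                for (int dx = 0; dx < kw; dx++)
                    for (int d = 0; d < dim; d++)
                        merged[out++] = patches[((size_t)(y + dy) * grid_w + (x + dx)) * dim + d];
    return merged;
}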
@@ -240,7 +262,6 @@ namespace vit
                 layer->set_id(layer_id);
                 layers.emplace_back(layer);
             }
-            memcpy(merge_param.merge_kernel_size, config.merge_kernel_size, sizeof(merge_param.merge_kernel_size));
         }
 
         int64_t get_param_num(bool effective_only) const override
@@ -269,17 +290,11 @@ namespace vit
             loaded = true;
         }
 
-        ggml::tensor *merge_patch(ComputeContext *ctx, ggml::tensor *x, int grid_h, int grid_w)
-        {
-            merge_param.grid_h = grid_h;
-            merge_param.grid_w = grid_w;
-
-            auto reshaped_seq = ggml::merge_patch(ctx, x, &merge_param);
-            return reshaped_seq;
-        }
-
         ggml::tensor *forward(ComputeContext *ctx, ggml::tensor *input, int grid_h, int grid_w)
         {
+            multi_modal_projector.merge_param.grid_h = grid_h;
+            multi_modal_projector.merge_param.grid_w = grid_w;
+
             auto output = embeddings.forward(ctx, input, grid_h, grid_w);
 
             for (size_t i = 0; i < layers.size(); i++)
@@ -289,7 +304,6 @@ namespace vit
                 output = layers[i]->forward(ctx, output, 0);
             }
             output = post_layernorm.forward(ctx, output);
-            output = merge_patch(ctx, output, grid_h, grid_w);
             output = multi_modal_projector.forward(ctx, output);
             return output;
         }
@@ -305,7 +319,6 @@ namespace vit
         MultiModalProjector multi_modal_projector;
     protected:
         bool loaded;
-        ggml::merge_patch_param merge_param;
     };
 
     class VisualEmbeddingGeneration
@@ -703,9 +716,6 @@ namespace vl
         vision::MaxGridHeight param3(512);
         vision::MaxGridWidth param4(512);
 
-        // TODO: cubic interpolation not ready yet. image size fixed.
-        vision::Resize resize(896, 896);
-
         vision::image_load(piece.content.c_str(), pixels, w, h, patch_size, vision::PaddingMode::Black);
 
         std::vector<float> scaled;