Skip to content

Commit 1ea378d

Browse files
committed
gemma 3: implement _pan & scan_
1 parent 49aabdb commit 1ea378d

File tree

4 files changed

+198
-38
lines changed

4 files changed

+198
-38
lines changed

docs/models.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,17 @@ Please use `--format completion` for these models.
300300
* [x] v3: [Instruct-4B](https://huggingface.co/google/gemma-3-4b-it/tree/dbd91bbaf64a0e591f4340ce8b66fd1dba9ab6bd), [Instruct-12B](https://huggingface.co/google/gemma-3-12b-it/tree/7553b6f39c33dc229bfbfe3831f7bcdbb6b738c7), [Instruct-27B](https://huggingface.co/google/gemma-3-27b-it/tree/dfb98f29ff907e391ceed2be3834ca071ea260f1)
301301
* [x] MedGemma: [Instruct-4B](https://huggingface.co/google/medgemma-4b-it/commit/698f7911b8e0569ff4ebac5d5552f02a9553063c)
302302

303-
Note: Only download `tokenizer.model` and DO NOT download `tokenizer.json` when converting.
303+
Note: Only download `tokenizer.model` and DO NOT download `tokenizer.json` when converting. Use `--set do-pan-and-scan 1` to enable _Pan and Scan_.
304+
304305

305306
* Kimi (`KimiVLForConditionalGeneration`)
306307
* [x] VL: [A3B-Instruct](https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/tree/7a3c132a7b0f1f1677f5a72f258bd3afded7d357), [A3B-Thinking](https://huggingface.co/moonshotai/Kimi-VL-A3B-Thinking/commit/16681d8ac24e505088698e4e34ea494dd6e24400)
307308

308309
* SmolVLM2 (`SmolVLMForConditionalGeneration`)
309310
* [x] [2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct/tree/482adb537c021c86670beed01cd58990d01e72e4)
310311

312+
Note: Use `--set do-split 1` to enable _Split_.
313+
311314
## RAG Models
312315

313316
* Text Embedding (`XLMRobertaModel`)

models/gemma.cpp

Lines changed: 105 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,10 @@ struct Config
362362
float image_mean[3];
363363
float image_std[3];
364364
bool vision_use_head;
365+
366+
int max_num_crops;
367+
int min_crop_size;
368+
float min_ratio_to_activate;
365369
};
366370

367371
class PatchEmbedding : public Block
@@ -377,7 +381,7 @@ class PatchEmbedding : public Block
377381
ggml::tensor *forward(ComputeContext *ctx, ggml::tensor *input) override
378382
{
379383
auto embedding = patch_embedding.forward(ctx, input);
380-
embedding = ggml::reshape_2d(ctx, embedding, ggml::get_dim(embedding, 0) * ggml::get_dim(embedding, 1), ggml::get_dim(embedding, 2));
384+
embedding = ggml::reshape_3d(ctx, embedding, ggml::get_dim(embedding, 0) * ggml::get_dim(embedding, 1), ggml::get_dim(embedding, 2), ggml::get_dim(embedding, 3));
381385
embedding = ggml::transpose(ctx, embedding);
382386
embedding = ggml::cont(ctx, embedding);
383387
embedding = ggml::add(ctx, embedding, position_embedding);
@@ -589,6 +593,11 @@ class VisualEmbeddingGeneration
589593
vis_config.image_std[i] = 0.5f;
590594
}
591595

596+
// ref: https://github.com/huggingface/transformers/blob/9487765f07ef4e5500d6ec21cad99aed4a037a3d/src/transformers/models/gemma3/processing_gemma3.py#L36
597+
vis_config.min_crop_size = 256;
598+
vis_config.max_num_crops = 4;
599+
vis_config.min_ratio_to_activate = 1.2f;
600+
592601
const size_t tensor_ovhd = ggml_tensor_overhead();
593602
const size_t num_tensors = 7 + vis_config.num_hidden_layers * 17;
594603
const size_t ctx_size = num_tensors * tensor_ovhd;
@@ -608,18 +617,21 @@ class VisualEmbeddingGeneration
608617
if ((vis_model.get() == nullptr) || (tok->media_emb.size() < 1)) return;
609618
if (!vis_model->is_loaded()) return;
610619

611-
run_model(gen_config, tok, dtype, buf);
620+
for (auto &image : tok->media_emb)
621+
{
622+
run_model(gen_config, tok, dtype, image, buf);
623+
}
612624
}
613625

614626
protected:
615-
bool run_model(const GenerationConfig &gen_config, BaseTokenizer *tok, ggml::type dtype, std::vector<uint8_t> &buf)
627+
bool run_model(const GenerationConfig &gen_config, BaseTokenizer *tok, ggml::type dtype, const BaseTokenizer::MediaAsEmbeddingVector &image, std::vector<uint8_t> &buf)
616628
{
617629
ForwardContext ctx(&backend_context);
618630
ctx.gctx = GGMLContext({.mem_size = backend_context.buf_compute_meta.size(), .mem_buffer = backend_context.buf_compute_meta.data(), .no_alloc = true});
619631
ctx.gf = ggml::new_graph_custom(&ctx, GRAPH_SIZE, false);
620632

621633
ctx.move_to_layer(LayerAllocatorManager::MiscLayer::Prolog);
622-
ggml::tensor *media_emb = ggml::new_tensor_4d(&ctx, ggml::type::GGML_TYPE_F32, vis_config.image_size, vis_config.image_size, 3, tok->media_emb.size());
634+
ggml::tensor *media_emb = ggml::new_tensor_3d(&ctx, ggml::type::GGML_TYPE_F32, vis_config.image_size, vis_config.image_size, 3);
623635

624636
dbg_ctx = &ctx;
625637

@@ -642,18 +654,14 @@ class VisualEmbeddingGeneration
642654
exit(-1);
643655
}
644656

645-
size_t offset = 0;
646-
for (auto &image : tok->media_emb)
647-
{
648-
size_t size = image.data.size() * sizeof(image.data[0]);
649-
Backend::write_tensor_data(media_emb, image.data.data(), offset, size);
650-
offset += size;
651-
}
657+
Backend::write_tensor_data(media_emb, image.data.data(), 0, image.data.size() * sizeof(image.data[0]));
652658

653659
ctx.compute();
654660

655-
buf.resize(ggml::nbytes(r));
656-
Backend::read_tensor_data(r, buf.data());
661+
size_t offset = buf.size();
662+
buf.resize(offset + ggml::nbytes(r));
663+
Backend::read_tensor_data(r, buf.data() + offset);
664+
657665
ctx.reset();
658666

659667
return true;
@@ -691,6 +699,8 @@ class ChatHistoryEncoder : public v1::ChatHistoryEncoder
691699
{
692700
public:
693701
void append_user(int round_idx, const Content &user, std::vector<int> &ids) const override;
702+
protected:
703+
bool append_image(const vision::image_pixels_t pixels, const int w, const int h, std::vector<int> &ids) const;
694704
public:
695705
const siglip::Config *vis_config = nullptr;
696706
int MAX_PATCH_NUM = 0;
@@ -717,6 +727,7 @@ class Tokenizer : public v1::Tokenizer
717727
public:
718728
int boi_token_id;
719729
int eoi_token_id;
730+
bool do_pan_and_scan = false;
720731
};
721732

722733
template <int sliding_window_len> class Gemma3SWASelfAttention : public QKNormedAttention<RMSNorm, SlidingWindowAttentionImpl<sliding_window_len>>
@@ -860,6 +871,17 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
860871
return r;
861872
}
862873

874+
void set_additional_args(const std::map<std::string, std::string> &args) override
875+
{
876+
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
877+
auto it = args.find("do_pan_and_scan");
878+
if (it == args.end()) it = args.find("do-pan-and-scan");
879+
if (it != args.end())
880+
{
881+
tok->do_pan_and_scan = it->second != "0";
882+
}
883+
}
884+
863885
void before_generate(const GenerationConfig &gen_config) override
864886
{
865887
std::vector<uint8_t> buf;
@@ -885,6 +907,41 @@ class ConditionalGeneration : public BaseModelForConditionalGeneration
885907
siglip::VisualEmbeddingGeneration visual;
886908
};
887909

910+
// Rescale/normalize one image (already sized to the vision input resolution),
// store it as a media embedding on the tokenizer, and emit the surrounding
// token sequence: "\n\n<boi> <image-placeholder ids...> <eoi>\n\n".
// `pixels` is raw RGB data for a w x h image; returns true on success.
// NOTE(review): `pixels` is taken by value — a const& would avoid copying the
// pixel vector; confirm against the in-class declaration before changing.
bool ChatHistoryEncoder::append_image(const vision::image_pixels_t pixels, const int w, const int h, std::vector<int> &ids) const
{
    const int patch_size = vis_config->patch_size;
    Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
    std::vector<float> scaled;

    // uint8 RGB -> float, then per-channel mean/std normalization
    vision::image_rescale(pixels, scaled);
    vision::image_normalize(scaled, vis_config->image_mean, vis_config->image_std);

    tok->media_emb.push_back({.grid_width = w / patch_size, .grid_height = h / patch_size, .patch_size = patch_size, .data = {}});

    auto &image = tok->media_emb.back();

    // lay pixels out patch-by-patch in the format the vision tower expects
    vision::image_arrange(scaled, w, patch_size, image.data, vision::PatchesFormat::ChannelsRGB_PixelsLeftRightDown);

    // Gemma 3 maps every image to a fixed number of embedding vectors
    image.emb_vec_number = vis_config->mm_tokens_per_image;

    const int total_patches = tok->get_image_total_emb_vectors();
    CHATLLM_CHECK(total_patches <= MAX_PATCH_NUM) << "too many image patches!";

    ids.push_back(tok->nl_token_id);
    ids.push_back(tok->nl_token_id);
    ids.push_back(tok->boi_token_id);
    // placeholder ids live past the text vocab; offset by the vectors of all
    // previously appended images so each image gets a distinct id range
    int id = total_patches - image.emb_vec_number + tok->vocab_size;
    for (int j = 0; j < image.emb_vec_number; j++)
    {
        ids.push_back(id++);
    }
    ids.push_back(tok->eoi_token_id);
    ids.push_back(tok->nl_token_id);
    ids.push_back(tok->nl_token_id);

    return true;
}
888945

889946
void ChatHistoryEncoder::append_user(int round_idx, const Content &user, std::vector<int> &ids) const
890947
{
@@ -902,40 +959,51 @@ void ChatHistoryEncoder::append_user(int round_idx, const Content &user, std::ve
902959
{
903960
CHATLLM_CHECK(vit_loaded) << "Vision model not loaded";
904961

905-
vision::Resize resize(vis_config->image_size, vis_config->image_size);
906-
907-
int w, h;
908-
std::vector<uint8_t> pixels;
909-
const int patch_size = vis_config->patch_size;
910-
vision::image_load(piece.content.c_str(), pixels, w, h, patch_size);
911-
912-
if (w <= 0) continue;
962+
if (tok->do_pan_and_scan)
963+
{
964+
std::vector<vision::image_pixels_t> crops;
913965

914-
std::vector<float> scaled;
915-
vision::image_rescale(pixels, scaled);
966+
int splits_cols_num = 0;
967+
vision::PanScanDir dir = vision::PanScanDir::Horizontal;
916968

917-
vision::image_normalize(scaled, vis_config->image_mean, vis_config->image_std);
969+
vision::image_load_pan_and_scan(piece.content.c_str(),
970+
crops, tok->do_pan_and_scan,
971+
vis_config->min_crop_size, vis_config->max_num_crops, vis_config->min_ratio_to_activate,
972+
vis_config->image_size, vis_config->image_size,
973+
dir);
918974

919-
tok->media_emb.push_back({.grid_width = w / patch_size, .grid_height = h / patch_size, .patch_size = patch_size, .data = {}});
975+
printf("crops: %d\n", (int)crops.size());
976+
if (crops.size() < 1) continue;
920977

921-
auto &image = tok->media_emb.back();
978+
if (crops.size() == 1)
979+
{
980+
append_image(crops[0], vis_config->image_size, vis_config->image_size, ids);
981+
continue;
982+
}
922983

923-
vision::image_arrange(scaled, w, patch_size, image.data, vision::PatchesFormat::ChannelsRGB_PixelsLeftRightDown);
984+
tok->encode("Here is the original image ", ids, false, false);
985+
append_image(crops[0], vis_config->image_size, vis_config->image_size, ids);
986+
tok->encode(" and here are some crops to help you see better", ids, false, false);
924987

925-
image.emb_vec_number = vis_config->mm_tokens_per_image;
988+
for (size_t i = 1; i < crops.size(); i++)
989+
{
990+
tok->encode(" ", ids, false, false);
991+
append_image(crops[i], vis_config->image_size, vis_config->image_size, ids);
992+
}
993+
}
994+
else
995+
{
996+
vision::Resize resize(vis_config->image_size, vis_config->image_size);
926997

927-
CHATLLM_CHECK(image.emb_vec_number) << "too many image patches!";
998+
int w, h;
999+
std::vector<uint8_t> pixels;
1000+
const int patch_size = vis_config->patch_size;
1001+
vision::image_load(piece.content.c_str(), pixels, w, h, patch_size);
9281002

929-
const int total_patches = tok->get_image_total_emb_vectors();
930-
CHATLLM_CHECK(total_patches <= MAX_PATCH_NUM) << "too many image patches!";
1003+
if (w <= 0) continue;
9311004

932-
ids.push_back(tok->boi_token_id);
933-
int id = total_patches - image.emb_vec_number + tok->vocab_size;
934-
for (int j = 0; j < image.emb_vec_number; j++)
935-
{
936-
ids.push_back(id++);
1005+
append_image(pixels, w, h, ids);
9371006
}
938-
ids.push_back(tok->eoi_token_id);
9391007
}
9401008
else
9411009
{

src/vision_process.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,84 @@ namespace vision
247247
run_cmd(oss, image);
248248
}
249249

250+
void image_load_pan_and_scan(const char *fn, std::vector<image_pixels_t> &crops, bool do_pas,
251+
const int min_crop_size, const int max_num_crops, float min_ratio_to_activate,
252+
const int crop_width, const int crop_height,
253+
PanScanDir &dir)
254+
{
255+
crops.clear();
256+
257+
int width = -1;
258+
int height = -1;
259+
image_dimension(fn, width, height);
260+
if (width <= 0) return;
261+
262+
// whole image
263+
{
264+
std::ostringstream oss;
265+
oss << "magick -depth 8 \"" << std::string(fn) << "\"";
266+
oss << " -resize " << crop_width << "x" << crop_height << "!";
267+
crops.emplace_back(image_pixels_t());
268+
auto &image = crops.back();
269+
run_cmd(oss, image);
270+
}
271+
272+
int num_crops_w = 1;
273+
int num_crops_h = 1;
274+
275+
if (width >= height)
276+
{
277+
auto ratio = (float)width / height;
278+
if (ratio < min_ratio_to_activate)
279+
return;
280+
dir = PanScanDir::Horizontal;
281+
282+
num_crops_w = int(width / height + 0.5);
283+
num_crops_w = std::min((width / min_crop_size), num_crops_w);
284+
285+
num_crops_w = std::max(2, num_crops_w);
286+
num_crops_w = std::min(max_num_crops, num_crops_w);
287+
}
288+
else
289+
{
290+
auto ratio = (float)height / width;
291+
if (ratio < min_ratio_to_activate)
292+
return;
293+
dir = PanScanDir::Vertical;
294+
295+
num_crops_h = int(height / width + 0.5);
296+
num_crops_h = std::min((height / min_crop_size), num_crops_h);
297+
298+
num_crops_h = std::max(2, num_crops_h);
299+
num_crops_h = std::min(max_num_crops, num_crops_h);
300+
}
301+
302+
const int crop_size_w = (width + (num_crops_w - 1) / num_crops_w);
303+
const int crop_size_h = (height + (num_crops_h - 1) / num_crops_h);
304+
305+
if (std::min(crop_size_w, crop_size_h) < min_crop_size)
306+
return;
307+
308+
for (int r = 0; r < num_crops_h; r++)
309+
{
310+
const int start_y = r * crop_size_h;
311+
312+
for (int c = 0; c < num_crops_w; c++)
313+
{
314+
const int start_x = c * crop_size_w;
315+
316+
std::ostringstream oss;
317+
oss << "magick -depth 8 \"" << std::string(fn) << "\"";
318+
oss << " -crop " << crop_size_w << "x" << crop_size_h << "+" << start_x << "+" << start_y;
319+
oss << " -resize " << crop_width << "x" << crop_height << "!";
320+
321+
crops.emplace_back(image_pixels_t());
322+
auto &image = crops.back();
323+
run_cmd(oss, image);
324+
}
325+
}
326+
}
327+
250328
void image_load(const char *fn, std::vector<uint8_t> &rgb_pixels, int &width, int &height, int patch_size, PaddingMode pad)
251329
{
252330
// magick -depth 8 demo.jpeg -resize 100x100 rgb:"aaa.raw"

src/vision_process.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,22 @@ namespace vision
9393
static bool PreScale(int &width, int &height);
9494
};
9595

96+
    // Direction along which pan-and-scan crops are laid out:
    // Horizontal when the image is wider than tall, Vertical otherwise.
    enum PanScanDir
    {
        Horizontal,
        Vertical,
    };
101+
96102
typedef std::vector<uint8_t> image_pixels_t; // natural sequence of RGB pixels
97103

98104
void image_dimension(const char *fn, int &width, int &height);
99105
void image_load(const char *fn, std::vector<uint8_t> &rgb_pixels, int &width, int &height, int patch_size, PaddingMode pad = PaddingMode::No);
100106
void image_load_split(const char *fn, std::vector<image_pixels_t> &splits, bool do_split, const int split_width, const int split_height, int &splits_cols_num, int &splits_rows_num); // splits are in natural order
107+
void image_load_pan_and_scan(const char *fn, std::vector<image_pixels_t> &crops, bool do_pas,
108+
const int min_crop_size, const int max_num_crops, float min_ratio_to_activate,
109+
const int crop_width, const int crop_height,
110+
PanScanDir &dir);
111+
101112
void image_rescale(const std::vector<uint8_t> &rgb_pixels, std::vector<float> &scaled_rgb_pixels, float scale_factor = 1/255.0f);
102113
void image_normalize(std::vector<float> &rgb_pixels, const float *mean, const float *std_d);
103114

0 commit comments

Comments
 (0)