Skip to content

Commit c276a91

Browse files
committed
llava: n_patches for clip_image_u8
1 parent 75afa0a commit c276a91

File tree

3 files changed

+44
-41
lines changed

3 files changed

+44
-41
lines changed

examples/llava/clip.cpp

Lines changed: 36 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -2259,11 +2259,37 @@ size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
22592259
return (clip_n_patches(ctx) + extra_tokens) * clip_n_mmproj_embd(ctx) * sizeof(float);
22602260
}
22612261

2262+
static int clip_n_patches_by_img_dims(const struct clip_ctx * ctx, int x, int y) {
2263+
const auto & params = ctx->vision_model.hparams;
2264+
2265+
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
2266+
2267+
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
2268+
n_patches /= 4;
2269+
} else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
2270+
if (ctx->minicpmv_version == 2) {
2271+
n_patches = 96;
2272+
}
2273+
else if (ctx->minicpmv_version == 3) {
2274+
n_patches = 64;
2275+
}
2276+
else if (ctx->minicpmv_version == 4) {
2277+
n_patches = 64;
2278+
}
2279+
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
2280+
int patch_size = params.patch_size * 2;
2281+
int x_patch = x / patch_size + (int)(x % patch_size > 0);
2282+
int y_patch = y / patch_size + (int)(y % patch_size > 0);
2283+
n_patches = x_patch * y_patch;
2284+
} else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
2285+
n_patches = 256;
2286+
}
2287+
2288+
return n_patches;
2289+
}
2290+
22622291
// Size in bytes of the embedding for an image of img_w by img_h pixels:
// patch count * clip_n_mmproj_embd(ctx) * sizeof(float).
//
// NOTE(review): this span is a rendered diff, not compilable text. A line that
// follows an `NNNN-` marker appears to be the pre-commit version (build a
// temporary clip_image_f32 and pass it to clip_n_patches_by_img); the line
// that follows the `NNNN+` marker is its replacement, which passes the two
// dimensions straight to clip_n_patches_by_img_dims.
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
2263-
clip_image_f32 img;
2264-
img.nx = img_w;
2265-
img.ny = img_h;
2266-
return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
2292+
// width is passed as x and height as y, matching the removed nx/ny assignments
return clip_n_patches_by_img_dims(ctx, img_w, img_h) * clip_n_mmproj_embd(ctx) * sizeof(float);
22672293
}
22682294

22692295
int32_t clip_get_image_size(const struct clip_ctx * ctx) {
@@ -2294,39 +2320,15 @@ size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
22942320
}
22952321

22962322
// Patch count for an image whose width and height both equal the vision
// hparams image_size.
//
// NOTE(review): this span is a rendered diff. Lines following `NNNN-` markers
// appear to be the pre-commit body (temporary clip_image_f32 passed to
// clip_n_patches_by_img); the line following the `NNNN+` marker replaces them.
int clip_n_patches(const struct clip_ctx * ctx) {
2297-
clip_image_f32 img;
2298-
img.nx = ctx->vision_model.hparams.image_size;
2299-
img.ny = ctx->vision_model.hparams.image_size;
2300-
return clip_n_patches_by_img(ctx, &img);
2323+
return clip_n_patches_by_img_dims(ctx, ctx->vision_model.hparams.image_size, ctx->vision_model.hparams.image_size);
23012324
}
23022325

2303-
// NOTE(review): this span is a rendered diff that interleaves two revisions.
// Lines following `NNNN-` markers appear to be the pre-commit function
// clip_n_patches_by_img, whose body is the same projector-type chain as
// clip_n_patches_by_img_dims with img->nx / img->ny in place of x / y.
// Lines following `NNNN+` markers are its replacements: two wrappers, one per
// image struct type, that forward the image dimensions to
// clip_n_patches_by_img_dims. The final `}` is shared by both revisions.
int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
2304-
const auto & params = ctx->vision_model.hparams;
2305-
2306-
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
2307-
2308-
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
2309-
n_patches /= 4;
2310-
} else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
2311-
if (ctx->minicpmv_version == 2) {
2312-
n_patches = 96;
2313-
}
2314-
else if (ctx->minicpmv_version == 3) {
2315-
n_patches = 64;
2316-
}
2317-
else if (ctx->minicpmv_version == 4) {
2318-
n_patches = 64;
2319-
}
2320-
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
2321-
int patch_size = params.patch_size * 2;
2322-
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
2323-
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
2324-
n_patches = x_patch * y_patch;
2325-
} else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
2326-
n_patches = 256;
2327-
}
2326+
// Patch count for a clip_image_f32: forwards img->nx and img->ny.
int clip_n_patches_by_img_f32(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
2327+
return clip_n_patches_by_img_dims(ctx, img->nx, img->ny);
2328+
}
23282329

2329-
return n_patches;
2330+
// Patch count for a clip_image_u8: forwards img->nx and img->ny.
int clip_n_patches_by_img_u8(const struct clip_ctx * ctx, struct clip_image_u8 * img) {
2331+
return clip_n_patches_by_img_dims(ctx, img->nx, img->ny);
23302332
}
23312333

23322334
static std::vector<std::vector<std::vector<float>>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector<std::vector<float>> & pos) {

examples/llava/clip.h

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -58,9 +58,10 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
5858
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
5959
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
6060

61-
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
62-
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
63-
CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx);
61+
// Patch count computed with both image dimensions set to the model's
// configured image_size.
CLIP_API int clip_n_patches(const struct clip_ctx * ctx);
62+
// Patch count for a specific float image; uses img->nx and img->ny.
CLIP_API int clip_n_patches_by_img_f32(const struct clip_ctx * ctx, struct clip_image_f32 * img);
63+
// Patch count for a specific 8-bit image; uses img->nx and img->ny.
CLIP_API int clip_n_patches_by_img_u8(const struct clip_ctx * ctx, struct clip_image_u8 * img);
64+
// NOTE(review): definition not visible in this diff. Its result is multiplied
// by the patch count and sizeof(float) to size embedding buffers, so it is
// presumably the per-patch embedding width -- confirm against clip.cpp.
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
6465

6566
CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
6667
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
@@ -74,10 +75,10 @@ CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used
7475
// nx, ny are the output image dimensions
7576
CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
7677

77-
CLIP_API void clip_image_size_free (struct clip_image_size * img_size);
78-
CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
78+
CLIP_API void clip_image_size_free(struct clip_image_size * img_size);
79+
CLIP_API void clip_image_u8_free(struct clip_image_u8 * img);
7980
CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
80-
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
81+
CLIP_API void clip_image_u8_batch_free(struct clip_image_u8_batch * batch);
8182
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
8283

8384
// use for accessing underlay data of clip_image_f32_batch

examples/llava/llava.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -313,7 +313,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
313313
image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
314314
image_embd_v[i],
315315
clip_embd_nbytes_by_img(ctx_clip, nx, ny));
316-
n_img_pos_out += clip_n_patches_by_img(ctx_clip, img_res);
316+
n_img_pos_out += clip_n_patches_by_img_f32(ctx_clip, img_res);
317317
}
318318
*n_img_pos = n_img_pos_out;
319319
for (size_t i = 0; i < image_embd_v.size(); i++) {

0 commit comments

Comments
 (0)