Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 35 additions & 33 deletions examples/llava/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2259,11 +2259,37 @@ size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
return (clip_n_patches(ctx) + extra_tokens) * clip_n_mmproj_embd(ctx) * sizeof(float);
}

// Number of output patches the projector produces for an image of x by y pixels.
//
// The starting point is the square grid (image_size / patch_size)^2 taken from
// the vision hparams; it is then adjusted according to the projector type.
// Only the MERGER projector looks at x and y, every other type ignores them.
static int clip_n_patches_by_img_dims(const struct clip_ctx * ctx, int x, int y) {
    const auto & hp = ctx->vision_model.hparams;

    // default: full square patch grid of the model's native input size
    const int grid_patches = (hp.image_size / hp.patch_size) * (hp.image_size / hp.patch_size);

    const auto proj = ctx->proj_type;

    if (proj == PROJECTOR_TYPE_LDP || proj == PROJECTOR_TYPE_LDPV2 || proj == PROJECTOR_TYPE_GLM_EDGE) {
        // these projectors keep a quarter of the grid
        return grid_patches / 4;
    }

    if (proj == PROJECTOR_TYPE_RESAMPLER) {
        // fixed output length per MiniCPM-V version; unknown versions keep the grid default
        switch (ctx->minicpmv_version) {
            case 2:
                return 96;
            case 3:
            case 4:
                return 64;
            default:
                return grid_patches;
        }
    }

    if (proj == PROJECTOR_TYPE_MERGER) {
        // cells are twice the patch size per side; a partially covered cell still counts
        const int cell  = hp.patch_size * 2;
        const int n_col = x / cell + (x % cell > 0 ? 1 : 0);
        const int n_row = y / cell + (y % cell > 0 ? 1 : 0);
        return n_col * n_row;
    }

    if (proj == PROJECTOR_TYPE_GEMMA3) {
        return 256;
    }

    return grid_patches;
}

// Size in bytes of the float embedding output for an image of img_w x img_h pixels:
// patch count for those dimensions * clip_n_mmproj_embd(ctx) * sizeof(float).
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
    // parameter order here is (height, width); the helper expects (width, height)
    return clip_n_patches_by_img_dims(ctx, img_w, img_h) * clip_n_mmproj_embd(ctx) * sizeof(float);
}

int32_t clip_get_image_size(const struct clip_ctx * ctx) {
Expand Down Expand Up @@ -2294,39 +2320,15 @@ size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
}

// Patch count for the model's native square input, i.e. an image whose width
// and height both equal hparams.image_size.
int clip_n_patches(const struct clip_ctx * ctx) {
    const int side = ctx->vision_model.hparams.image_size;
    return clip_n_patches_by_img_dims(ctx, side, side);
}

int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
const auto & params = ctx->vision_model.hparams;

int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);

if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
n_patches /= 4;
} else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
if (ctx->minicpmv_version == 2) {
n_patches = 96;
}
else if (ctx->minicpmv_version == 3) {
n_patches = 64;
}
else if (ctx->minicpmv_version == 4) {
n_patches = 64;
}
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
int patch_size = params.patch_size * 2;
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
n_patches = x_patch * y_patch;
} else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
n_patches = 256;
}
return clip_n_patches_by_img_dims(ctx, img->nx, img->ny);
}

// Patch count for a u8 image. Only the image's nx and ny fields are read;
// the per-projector logic lives in clip_n_patches_by_img_dims().
int clip_n_patches_by_img_u8(const struct clip_ctx * ctx, struct clip_image_u8 * img) {
    return clip_n_patches_by_img_dims(ctx, img->nx, img->ny);
}

static std::vector<std::vector<std::vector<float>>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector<std::vector<float>> & pos) {
Expand Down
7 changes: 4 additions & 3 deletions examples/llava/clip.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);

// number of patches for the model's native image_size x image_size input
CLIP_API int clip_n_patches          (const struct clip_ctx * ctx);
// number of patches for a specific image, derived from its nx / ny dimensions
CLIP_API int clip_n_patches_by_img   (const struct clip_ctx * ctx, struct clip_image_f32 * img);
CLIP_API int clip_n_patches_by_img_u8(const struct clip_ctx * ctx, struct clip_image_u8  * img);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these two API calls should be regrouped under the prefix `clip_img_*_get_n_output_tokens`:

Suggested change
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
CLIP_API int clip_n_patches_by_img_u8 (const struct clip_ctx * ctx, struct clip_image_u8 * img);
CLIP_API int clip_img_f32_get_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
CLIP_API int clip_img_u8_get_n_output_tokens (const struct clip_ctx * ctx, struct clip_image_u8 * img);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old `clip_n_patches_by_img` can be marked as deprecated (a simple comment is enough for now; we will add a proper `__attribute__` in the future).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, okay. I deleted it because the refactor already seemed to introduce a lot of breaking changes. But I can certainly re-add it and mark it as deprecated if you'd prefer.

// multiplied by the patch count and sizeof(float) when sizing embedding buffers in clip.cpp
// NOTE(review): presumably the per-patch embedding width of the projector output — confirm against the definition
CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx);

CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
Expand Down
Loading