Skip to content

Commit 000d1d9

Browse files
committed
fix token calculation
1 parent 4274417 commit 000d1d9

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

tools/mtmd/clip.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2371,16 +2371,16 @@ struct clip_graph {
23712371

23722372
// aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
23732373
// support dynamic resolution
2374-
ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int kernel_size) {
2375-
GGML_ASSERT(kernel_size > 1);
2374+
ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor) {
2375+
GGML_ASSERT(scale_factor > 1);
23762376

23772377
const int n_embd = cur->ne[0];
23782378
int width = img.nx / patch_size;
23792379
int height = img.ny / patch_size;
23802380

23812381
// pad width and height to factor
2382-
const int64_t pad_width = CLIP_ALIGN(width, kernel_size) - width;
2383-
const int64_t pad_height = CLIP_ALIGN(height, kernel_size) - height;
2382+
const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
2383+
const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
23842384
cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
23852385
if (pad_width || pad_height) {
23862386
cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
@@ -2389,11 +2389,11 @@ struct clip_graph {
23892389
}
23902390

23912391
// unshuffle h
2392-
cur = ggml_reshape_3d(ctx0, cur, n_embd * kernel_size, width / kernel_size, height);
2392+
cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
23932393
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
23942394

23952395
// unshuffle w
2396-
cur = ggml_cont_3d(ctx0, cur, n_embd * kernel_size * kernel_size, height / kernel_size, width / kernel_size);
2396+
cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
23972397
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
23982398

23992399
cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
@@ -4351,7 +4351,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
43514351
case PROJECTOR_TYPE_INTERNVL:
43524352
case PROJECTOR_TYPE_LLAMA4:
43534353
{
4354-
n_patches /= ctx->model.hparams.proj_scale_factor;
4354+
// both X and Y are downscaled by the scale factor
4355+
int scale_factor = ctx->model.hparams.proj_scale_factor;
4356+
n_patches /= (scale_factor * scale_factor);
43554357
} break;
43564358
case PROJECTOR_TYPE_LFM2:
43574359
case PROJECTOR_TYPE_KIMIVL:

0 commit comments

Comments
 (0)