Skip to content

Commit 88032f4

Browse files
committed
window partitioning using standard ggml ops
1 parent 89afda8 commit 88032f4

File tree

1 file changed

+46
-4
lines changed

1 file changed

+46
-4
lines changed

tools/mtmd/clip.cpp

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,8 @@ struct clip_graph {
690690
if (hparams.is_global_attn(il) == false) {
691691
// local attention layer - apply window partition
692692
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172
693-
cur = ggml_win_part(ctx0, cur, 14);
693+
//cur = ggml_win_part(ctx0, cur, 14);
694+
cur = window_partition(ctx0, cur, 14);
694695
}
695696

696697
const int64_t W = cur->ne[1];
@@ -762,7 +763,7 @@ struct clip_graph {
762763

763764
if (hparams.is_global_attn(il) == false) {
764765
// local attention layer - reverse window partition
765-
cur = ggml_win_unpart(ctx0, cur, w0, h0, 14);
766+
cur = window_unpartition(ctx0, cur, w0, h0, 14);
766767
}
767768

768769
// re-add the layer input, e.g., residual
@@ -865,9 +866,10 @@ struct clip_graph {
865866

866867
// 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim]
867868
ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 64, 64, 1); // (n_dim, w, h)
868-
t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim)
869-
ggml_tensor * nl = ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3);
869+
t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim)
870+
ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3));
870871
nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows
872+
nl = ggml_cont(ctx0, nl);
871873

872874

873875
// 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
@@ -2464,6 +2466,46 @@ struct clip_graph {
24642466
return inpL;
24652467
}
24662468

2469+
// Split a feature map into non-overlapping window x window tiles using only
// pad / reshape / permute / cont. Used in place of ggml_win_part(ctx, x, window).
//
//   ctx    - ggml context every op below is created in
//   x      - input whose ne[] is read as (channels, width, height, batch)
//   window - edge length of one square tile
//
// Returns a tensor reshaped to [channels, window, window, tiles_w * tiles_h * batch].
static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) {
    const int64_t n_ch    = x->ne[0];
    const int64_t width   = x->ne[1];
    const int64_t height  = x->ne[2];
    const int64_t n_batch = x->ne[3];

    // padding that rounds width / height up to the next multiple of `window`
    const int64_t pad_w = (window - width  % window) % window;
    const int64_t pad_h = (window - height % window) % window;

    // tile counts along each spatial axis once padded
    const int64_t tiles_w = (width  + pad_w) / window;
    const int64_t tiles_h = (height + pad_h) / window;

    ggml_tensor* out = x;

    const bool needs_padding = pad_w > 0 || pad_h > 0;
    if (needs_padding) {
        out = ggml_pad(ctx, out, 0, static_cast<int>(pad_w), static_cast<int>(pad_h), 0);
    }

    // fold (channel, column-in-tile) into dim 0 so a whole tile row moves as one unit:
    // dims become [c * window, tile column, row-in-tile, tile row * batch]
    out = ggml_reshape_4d(ctx, out, n_ch * window, tiles_w, window, tiles_h * n_batch);

    // swap dims 1 and 2 so the rows of one tile sit next to each other, then materialize
    ggml_tensor* swapped = ggml_permute(ctx, out, 0, 2, 1, 3);
    out = ggml_cont(ctx, swapped);

    // unfold dim 0 again: one tile per index of the last dimension
    out = ggml_reshape_4d(ctx, out, n_ch, window, window, tiles_w * tiles_h * n_batch);
    return out;
}
2488+
2489+
// Reverse of window_partition: stitch window x window tiles back into a single
// feature map and drop the padding that was added to reach a multiple of `window`.
// Used in place of ggml_win_unpart(m, x, w, h, window).
//
//   m      - ggml context every op below is created in
//   x      - tiles; ne[0] is the channel count and ne[3] is tiles_w * tiles_h * batch
//            (the element count must match [c * window, window, tiles_w, tiles_h * batch])
//   w, h   - spatial size of the map before it was padded
//   window - edge length of one square tile
//
// Returns a tensor of shape [c, w, h, batch].
static ggml_tensor* window_unpartition(ggml_context* m, ggml_tensor* x, int w, int h, int window) {
    const int64_t c = x->ne[0];

    // padding / tile counts, computed exactly as in window_partition
    const int64_t px  = (window - w % window) % window;
    const int64_t py  = (window - h % window) % window;
    const int64_t npw = (w + px) / window;
    const int64_t nph = (h + py) / window;

    // batch size is recovered from the total number of tiles
    const int64_t b = x->ne[3] / (npw * nph);

    // dims: [c * window, row-in-tile, tile column, tile row * batch]
    x = ggml_reshape_4d(m, x, c * window, window, npw, nph * b);
    // swap dims 1 and 2 so tile columns along one image row become adjacent
    x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
    x = ggml_reshape_4d(m, x, c, w + px, h + py, b);

    // Crop the padded border. Only needed when padding was actually applied:
    // without it the reshape above already has the final shape and is contiguous,
    // so the view + copy would be a redundant full-tensor copy.
    if (px > 0 || py > 0) {
        x = ggml_view_4d(m, x, x->ne[0], w, h, x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
        x = ggml_cont(m, x);
    }
    return x;
}
2508+
24672509
// build the input after conv2d (inp_raw --> patches)
24682510
// returns tensor with shape [n_embd, n_patches]
24692511
ggml_tensor * build_enc_inp(ggml_tensor * inp_raw,

0 commit comments

Comments
 (0)