We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent dd52868 commit c770cf4Copy full SHA for c770cf4
tools/mtmd/clip.cpp
@@ -761,6 +761,14 @@ struct clip_graph {
761
ggml_set_name(window_mask, "window_mask");
762
ggml_set_input(window_mask);
763
764
+ // if flash attn is used, we need to pad the mask
765
+ if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
766
+ int padded_nrow = GGML_PAD(window_mask->ne[1], GGML_KQ_MASK_PAD);
767
+ window_mask = ggml_pad(ctx0, window_mask,
768
+ 0, padded_nrow - window_mask->ne[0], 0, 0);
769
+ window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
770
+ }
771
+
772
// inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
773
GGML_ASSERT(batch_size == 1);
774
inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
0 commit comments