Skip to content

Commit 7b5d630

Browse files
committed
improve
1 parent c770cf4 commit 7b5d630

File tree

1 file changed

+5 -4 lines changed

1 file changed

+5 -4 lines changed

tools/mtmd/clip.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -761,11 +761,12 @@ struct clip_graph {
 ggml_set_name(window_mask, "window_mask");
 ggml_set_input(window_mask);
 
-// if flash attn is used, we need to pad the mask
+// if flash attn is used, we need to pad the mask and cast to f16
 if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
-    int padded_nrow = GGML_PAD(window_mask->ne[1], GGML_KQ_MASK_PAD);
-    window_mask = ggml_pad(ctx0, window_mask,
-        0, padded_nrow - window_mask->ne[0], 0, 0);
+    int n_pad = GGML_PAD(window_mask->ne[1], GGML_KQ_MASK_PAD) - window_mask->ne[1];
+    if (n_pad > 0) {
+        window_mask = ggml_pad(ctx0, window_mask, 0, n_pad, 0, 0);
+    }
     window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
 }
 

0 commit comments

Comments
 (0)