Skip to content

Commit 84b35d2

Browse files
committed
use ggml_soft_max_ext
1 parent 376f80a commit 84b35d2

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

examples/llava/clip.cpp

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -465,8 +465,7 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
465 465
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
466 466

467 467
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
468-
KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head));
469-
KQ = ggml_soft_max_inplace(ctx0, KQ);
468+
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
470 469

471 470
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
472 471
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
@@ -721,7 +720,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
721 720
ctx0, Q, positions, nullptr,
722 721
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
723 722
}
724-
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
725 723
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
726 724
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
727 725

@@ -745,7 +743,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
745 743
V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
746 744

747 745
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
748-
KQ = ggml_soft_max_inplace(ctx0, KQ);
746+
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
749 747
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
750 748
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
751 749
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -1033,7 +1031,6 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
1033 1031
}
1034 1032

1035 1033
struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
1036-
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
1037 1034
struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
1038 1035
struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
1039 1036
// permute
@@ -1047,7 +1044,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
1047 1044
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
1048 1045
V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
1049 1046
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1050-
KQ = ggml_soft_max_inplace(ctx0, KQ);
1047+
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
1051 1048
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
1052 1049
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
1053 1050
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

0 commit comments

Comments
 (0)