 
 #include <float.h>
 #include <algorithm>
+#include <vector>
 
 // ggml_compute_forward_dup
 
@@ -7915,7 +7916,8 @@ static void ggml_compute_forward_sparsek_attn_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    if (params->ith != 0) return; // main thread only
+    // Single-threaded baseline version (expand later for parallelism)
+    if (params->ith != 0) return;
 
     const struct ggml_tensor * Q = dst->src[0];
     const struct ggml_tensor * K = dst->src[1];
@@ -7925,56 +7927,87 @@ static void ggml_compute_forward_sparsek_attn_f32(
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(K->type == GGML_TYPE_F32);
     GGML_ASSERT(V->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     const int32_t k_top      = ggml_get_op_params_i32(dst, 0);
     const int32_t win_local  = ggml_get_op_params_i32(dst, 1);
     const int32_t stride_glb = ggml_get_op_params_i32(dst, 2);
+    GGML_UNUSED(win_local);
+    GGML_UNUSED(stride_glb);
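+    // NOTE: win_local and stride_glb are read but intentionally unused in this
+    // dense baseline; only the top-k mask below restricts attention.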
 
-    const int64_t D = Q->ne[0]; // embedding dim
-    const int64_t T = Q->ne[1]; // sequence length
+    // Tensor dimensions per GGML layout: ne[0] = head dim, ne[1] = sequence, ne[2] = head, ne[3] = batch
+    const int64_t D = Q->ne[0];
+    const int64_t T = Q->ne[1];
+    const int64_t H = Q->ne[2];
+    const int64_t B = Q->ne[3];
 
-    const float * q = (const float *) Q->data;
-    const float * k = (const float *) K->data;
-    const float * v = (const float *) V->data;
-    float * out = (float *) dst->data;
+    // Temporary buffer for the attention scores of one query row
+    std::vector<float> attn_row(T, 0.0f);
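+    // One score per key; the buffer is reused for every (batch, head, query)
+    // row processed below.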
 
-
-    for (int64_t i = 0; i < T; ++i) {
-        for (int64_t j = 0; j < T; ++j) {
-            float dot = 0.0f;
-            for (int64_t d = 0; d < D; ++d)
-                dot += q[i*D + d] * k[j*D + d];
-            out[i*T + j] = dot / sqrtf((float) D);
-        }
-    }
+    const float scale = 1.0f / sqrtf((float) D);
 
-    for (int64_t i = 0; i < T; ++i) {
-        float * row = &out[i*T];
-        for (int64_t j = 0; j < T; ++j)
-            if (row[j] < row[k_top]) row[j] = -INFINITY;
-    }
+    // Loop over batch, head, and query token
+    for (int64_t b = 0; b < B; ++b) {
+        for (int64_t h = 0; h < H; ++h) {
+            for (int64_t iq = 0; iq < T; ++iq) {
 
-    for (int64_t i = 0; i < T; ++i) {
-        float maxv = -INFINITY;
-        for (int64_t j = 0; j < T; ++j)
-            if (out[i*T + j] > maxv) maxv = out[i*T + j];
-        float sum = 0.0f;
-        for (int64_t j = 0; j < T; ++j) {
-            out[i*T + j] = expf(out[i*T + j] - maxv);
-            sum += out[i*T + j];
-        }
-        for (int64_t j = 0; j < T; ++j)
-            out[i*T + j] /= sum;
-    }
+                // (1) Compute dot products Q·K within the same (b,h)
+                const char * qbase = (const char *) Q->data + b*Q->nb[3] + h*Q->nb[2] + iq*Q->nb[1];
+                const float * qv = (const float *) qbase;
 
+                for (int64_t j = 0; j < T; ++j) {
+                    const char * kbase = (const char *) K->data + b*K->nb[3] + h*K->nb[2] + j*K->nb[1];
+                    const float * kv = (const float *) kbase;
 
-    float * result = (float *) dst->data;
-    for (int64_t i = 0; i < T; ++i) {
-        for (int64_t d = 0; d < D; ++d) {
-            float sum = 0.0f;
-            for (int64_t j = 0; j < T; ++j)
-                sum += out[i*T + j] * v[j*D + d];
-            result[i*D + d] = sum;
+                    float dot = 0.0f;
+                    for (int64_t d = 0; d < D; ++d) {
+                        dot += qv[d] * kv[d];
+                    }
+                    attn_row[j] = dot * scale;
+                }
+
+                // (2) Select the top-k threshold using nth_element
+                const int kk = std::max<int>(1, std::min<int>((int) T, k_top));
+                std::vector<float> tmp(attn_row.begin(), attn_row.end());
+                std::nth_element(tmp.begin(), tmp.begin() + (kk - 1), tmp.end(), std::greater<float>());
+                const float thr = tmp[kk - 1];
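+                // Scores equal to thr are kept by the mask below, so ties at the
+                // threshold can let more than k_top entries survive.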
+
+                for (int64_t j = 0; j < T; ++j) {
+                    if (attn_row[j] < thr) attn_row[j] = -INFINITY;
+                }
+
+                // (3) Numerically stable softmax on the masked row
+                float maxv = -INFINITY;
+                for (int64_t j = 0; j < T; ++j) {
+                    maxv = std::max(maxv, attn_row[j]);
+                }
+                float sum = 0.0f;
+                for (int64_t j = 0; j < T; ++j) {
+                    float v = attn_row[j] - maxv;
+                    float e = expf(v);
+                    attn_row[j] = e;
+                    sum += e;
+                }
+                const float inv_sum = sum > 0.0f ? 1.0f / sum : 0.0f;
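+                // Defensive guard: at least one score survives the top-k mask, so
+                // sum >= 1 here; the check only protects against division by zero.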
+                for (int64_t j = 0; j < T; ++j) {
+                    attn_row[j] *= inv_sum;
+                }
+
+                // (4) Compute output = A·V (weighted sum)
+                float * y = (float *) ((char *) dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
+
+                for (int64_t d = 0; d < D; ++d) {
+                    float acc = 0.0f;
+                    for (int64_t j = 0; j < T; ++j) {
+                        const float aij = attn_row[j];
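+                        // expf(-INFINITY) == 0.0f after the softmax, so masked
+                        // columns are exactly zero and can be skipped here.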
+                        if (aij == 0.0f) continue; // skip masked entries
+                        const char * vbase = (const char *) V->data + b*V->nb[3] + h*V->nb[2] + j*V->nb[1];
+                        const float * vv = (const float *) vbase;
+                        acc += aij * vv[d];
+                    }
+                    y[d] = acc;
+                }
+            }
         }
     }
 
@@ -7985,7 +8018,13 @@ static void ggml_compute_forward_sparsek_attn_f32(
 void ggml_compute_forward_sparsek_attn(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    ggml_compute_forward_sparsek_attn_f32(params, dst);
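+    // Dispatch on destination type; only F32 is implemented so far.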
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            ggml_compute_forward_sparsek_attn_f32(params, dst);
+            break;
+        default:
+            GGML_ASSERT(false && "sparsek_attn: unsupported dst type");
+    }
 }