
Commit fd96766

Integrate SparseK Attention via FlashAttention extension (CPU backend) [yael-works]
1 parent a063c64 commit fd96766

File tree

  ggml/include/ggml.h
  ggml/src/ggml-cpu/ggml-cpu.c
  ggml/src/ggml-cpu/ops.cpp
  ggml/src/ggml.c
  tests/test-backend-ops.cpp

5 files changed, 121 additions and 102 deletions

ggml/include/ggml.h

Lines changed: 15 additions & 22 deletions
@@ -219,7 +219,7 @@
 #define GGML_MAX_PARAMS     2048
 #define GGML_MAX_SRC        10
 #define GGML_MAX_N_THREADS  512
-#define GGML_MAX_OP_PARAMS  64
+#define GGML_MAX_OP_PARAMS  128

 #ifndef GGML_MAX_NAME
 #    define GGML_MAX_NAME  64

@@ -530,7 +530,6 @@ extern "C" {
     GGML_OP_TIMESTEP_EMBEDDING,
     GGML_OP_ARGSORT,
     GGML_OP_LEAKY_RELU,
-    GGML_OP_SPARSEK_ATTN,
     GGML_OP_FLASH_ATTN_EXT,
     GGML_OP_FLASH_ATTN_BACK,
     GGML_OP_SSM_CONV,

@@ -2232,26 +2231,6 @@ extern "C" {
     // n_head % ne32 == 0
     // ne3 % ne33 == 0
     //
-
-    GGML_API struct ggml_tensor * ggml_sparsek_attn(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * Q,
-            struct ggml_tensor  * K,
-            struct ggml_tensor  * V,
-            int32_t               k_top,
-            int32_t               win_local,
-            int32_t               stride_global);
-
-    GGML_API void ggml_sparsek_attn_set_params(
-            struct ggml_tensor  * a,
-            int32_t               k_top,
-            int32_t               win_local,
-            int32_t               stride_global);
-
-    GGML_API int32_t ggml_sparsek_attn_get_param(
-            const struct ggml_tensor * a,
-            int index);
-
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,

@@ -2281,6 +2260,20 @@ extern "C" {
             struct ggml_tensor  * v,
             struct ggml_tensor  * d,
             bool                  masked);
+    // Optional SparseK parameters (disabled if use_sparsek=false)
+    GGML_API void ggml_flash_attn_ext_set_sparsek(
+            struct ggml_tensor * a,
+            bool                 use_sparsek,
+            int32_t              k_top,
+            int32_t              win_local,
+            int32_t              stride_global);
+
+    GGML_API void ggml_flash_attn_ext_get_sparsek(
+            const struct ggml_tensor * a,
+            bool    * use_sparsek,
+            int32_t * k_top,
+            int32_t * win_local,
+            int32_t * stride_global);

     GGML_API struct ggml_tensor * ggml_ssm_conv(
             struct ggml_context * ctx,
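
For orientation, a minimal sketch of how a caller might use the new header API: build a regular flash-attention node, then opt in to SparseK on top of it. The helper name, tensor shapes, and the scale/max_bias/bias values and SparseK parameters below are illustrative assumptions, not part of this commit.

// Hypothetical usage sketch (not from this commit).
#include "ggml.h"

static struct ggml_tensor * build_attn_with_sparsek(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * mask) {
    // regular FlashAttention node (scale / max_bias / bias values are examples)
    struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask,
                                                   /*scale=*/1.0f, /*max_bias=*/0.0f, /*bias=*/0.0f);

    // enable SparseK: keep the 64 best-scoring keys, a +/-128 local window,
    // and every 32nd key globally (values chosen purely for illustration)
    ggml_flash_attn_ext_set_sparsek(out, /*use_sparsek=*/true,
                                    /*k_top=*/64, /*win_local=*/128, /*stride_global=*/32);
    return out;
}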

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 0 additions & 4 deletions
@@ -1947,10 +1947,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor);
             } break;
-        case GGML_OP_SPARSEK_ATTN:
-            {
-                ggml_compute_forward_sparsek_attn(params, tensor);
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 int32_t t = ggml_get_op_params_i32(tensor, 0);

ggml/src/ggml-cpu/ops.cpp

Lines changed: 56 additions & 22 deletions
@@ -5107,6 +5107,14 @@ static void ggml_compute_forward_soft_max_f32(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

+    // SparseK parameters (from op_params)
+    const bool    use_sparsek = ggml_get_op_params_i32(dst, 30) != 0;
+    const int32_t k_top       = ggml_get_op_params_i32(dst, 31);
+    const int32_t win_local   = ggml_get_op_params_i32(dst, 32);
+    const int32_t stride_glb  = ggml_get_op_params_i32(dst, 33);
+    (void)use_sparsek; (void)k_top; (void)win_local; (void)stride_glb;
+
+
     float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

@@ -8182,6 +8190,13 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

+    // -------- SparseK op_params (only reads the parameters, changes nothing else) --------
+    const bool    use_sparsek = ggml_get_op_params_i32(dst, 30) != 0;
+    const int32_t k_top       = ggml_get_op_params_i32(dst, 31);
+    const int32_t win_local   = ggml_get_op_params_i32(dst, 32);
+    const int32_t stride_glb  = ggml_get_op_params_i32(dst, 33);
+    // --------------------------------------------------------------------------------------
+
     ggml_type         const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
     ggml_from_float_t const q_to_vec_dot   = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
     ggml_vec_dot_t    const kq_vec_dot     = ggml_get_type_traits_cpu(k->type)->vec_dot;

@@ -8200,7 +8215,7 @@
        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);

        const uint32_t h = iq2; // head index
-       const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+       const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;

        float S = 0.0f;      // sum
        float M = -INFINITY; // maximum KQ value

@@ -8229,18 +8244,51 @@
        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
        q_to_vec_dot(pq, Q_q, DK);

-       // online softmax / attention
-       // loop over n_kv and n_head_kv
-       // ref: https://arxiv.org/pdf/2112.05682.pdf
-       for (int64_t ic = 0; ic < nek1; ++ic) {
+       // ------------------------ SparseK: build the candidate list ------------------------
+       std::vector<int> cand_idx;
+       cand_idx.reserve((size_t)nek1);
+
+       if (!use_sparsek) {
+           for (int64_t t = 0; t < nek1; ++t) cand_idx.push_back((int)t);
+       } else {
+           for (int64_t t = 0; t < nek1; ++t) {
+               const int  dist      = std::abs((int)iq1 - (int)t);
+               const bool in_local  = (win_local >= 0 && dist <= win_local);
+               const bool in_stride = (stride_glb > 1 && (t % stride_glb) == 0);
+               if (in_local || in_stride || t == iq1) cand_idx.push_back((int)t);
+           }
+           if (k_top > 0 && (int)cand_idx.size() > k_top) {
+               std::vector<float> vals; vals.reserve(cand_idx.size());
+               for (int idx : cand_idx) {
+                   float tmp_s;
+                   const char * k_data = (const char *) k->data + (idx*nbk1 + ik2*nbk2 + ik3*nbk3);
+                   kq_vec_dot(DK, &tmp_s, 0, k_data, 0, Q_q, 0, 1);
+                   vals.push_back(tmp_s * scale);
+               }
+               std::nth_element(vals.begin(), vals.begin() + (k_top - 1), vals.end(), std::greater<float>());
+               const float thr = vals[k_top - 1];
+
+               std::vector<int> filtered; filtered.reserve(k_top);
+               for (int idx : cand_idx) {
+                   float tmp_s;
+                   const char * k_data = (const char *) k->data + (idx*nbk1 + ik2*nbk2 + ik3*nbk3);
+                   kq_vec_dot(DK, &tmp_s, 0, k_data, 0, Q_q, 0, 1);
+                   if (tmp_s * scale >= thr) filtered.push_back(idx);
+               }
+               cand_idx.swap(filtered);
+           }
+       }
+       // ------------------------------------------------------------------------------
+
+       // ----- Flash Attention core: same code as before, only iterating over cand_idx instead of every ic -----
+       for (int ic : cand_idx) {
            const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
            if (mv == -INFINITY) {
                continue;
            }

            float s; // KQ value
-
-           const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+           const char * k_data = (const char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3);
            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);

            s = s*scale; // scale KQ value

@@ -8260,44 +8308,33 @@

            if (v->type == GGML_TYPE_F16) {
                if (s > M) {
-                   // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                    M = s;
                    ms = expf(Mold - M);
-
-                   // V = V*expf(Mold - M)
                    ggml_vec_scale_f16(DV, VKQ16, ms);
                } else {
-                   // no new maximum, ms == 1.0f, vs != 1.0f
                    vs = expf(s - M);
                }
-
-               // V += v*expf(s - M)
                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
            } else {
                if (s > M) {
-                   // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                    M = s;
                    ms = expf(Mold - M);
-
-                   // V = V*expf(Mold - M)
                    ggml_vec_scale_f32(DV, VKQ32, ms);
                } else {
-                   // no new maximum, ms == 1.0f, vs != 1.0f
                    vs = expf(s - M);
                }

-               // V += v*expf(s - M)
                if (v_to_float) {
                    v_to_float(v_data, V32, DV);
                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
                } else {
-                   // V is F32
                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
                }
            }

            S = S*ms + vs; // scale and increment sum with partial sum
        }
+       // ------------------------------------------------------------------------------

        if (v->type == GGML_TYPE_F16) {
            for (int64_t d = 0; d < DV; ++d) {

@@ -8331,9 +8368,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
        const int i2 = iq2;
        const int i3 = iq3;

-       // original
-       //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
-
        // permute(0, 2, 1, 3)
        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
    }
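
Pulled out of the kernel for readability, the selection rule above amounts to the following self-contained sketch. The function name and the precomputed-score input are assumptions made for illustration; in the kernel the scores come from kq_vec_dot on the fly.

// Illustrative sketch of the SparseK candidate rule: a key index t is a candidate
// for query position q_pos if it lies inside the local window, on the global
// stride, or equals q_pos; if more than k_top candidates remain, only those whose
// score reaches the k_top-th largest candidate score are kept (ties included).
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <vector>

static std::vector<int> sparsek_candidates(
        const std::vector<float> & scores, // precomputed scaled Q.K scores, one per key
        int64_t q_pos, int32_t k_top, int32_t win_local, int32_t stride_global) {
    const int64_t n_kv = (int64_t) scores.size();

    std::vector<int> cand;
    for (int64_t t = 0; t < n_kv; ++t) {
        const int  dist      = std::abs((int)(q_pos - t));
        const bool in_local  = win_local >= 0 && dist <= win_local;
        const bool in_stride = stride_global > 1 && (t % stride_global) == 0;
        if (in_local || in_stride || t == q_pos) {
            cand.push_back((int) t);
        }
    }

    if (k_top > 0 && (int) cand.size() > k_top) {
        std::vector<float> vals;
        for (int idx : cand) {
            vals.push_back(scores[idx]);
        }
        // threshold = k_top-th largest candidate score
        std::nth_element(vals.begin(), vals.begin() + (k_top - 1), vals.end(), std::greater<float>());
        const float thr = vals[k_top - 1];

        std::vector<int> filtered;
        for (int idx : cand) {
            if (scores[idx] >= thr) {
                filtered.push_back(idx);
            }
        }
        cand.swap(filtered);
    }
    return cand;
}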

ggml/src/ggml.c

Lines changed: 37 additions & 50 deletions
@@ -990,7 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "LEAKY_RELU",
-    "SPARSEK_ATTN",
+
     "FLASH_ATTN_EXT",
     "FLASH_ATTN_BACK",
     "SSM_CONV",

@@ -1019,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };

-static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",

@@ -1094,7 +1094,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "leaky_relu(x)",
-    "sparsek_attn(x)",
     "flash_attn_ext(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",

@@ -1123,7 +1122,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };

-static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -5063,52 +5062,6 @@ struct ggml_tensor * ggml_top_k(
     return result;
 }

-// ggml_sparsek_attn
-struct ggml_tensor * ggml_sparsek_attn(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * Q,
-        struct ggml_tensor  * K,
-        struct ggml_tensor  * V,
-        int32_t               k_top,
-        int32_t               win_local,
-        int32_t               stride_global) {
-
-    GGML_ASSERT(ggml_can_mul_mat(K, Q));
-    GGML_ASSERT(Q->ne[3] == K->ne[3] && Q->ne[3] == V->ne[3]);
-
-    int64_t ne[4] = { V->ne[0], Q->ne[2], Q->ne[1], Q->ne[3] };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-
-    int32_t params_i32[3] = { k_top, win_local, stride_global };
-    ggml_set_op_params(result, params_i32, sizeof(params_i32));
-
-    result->op     = GGML_OP_SPARSEK_ATTN;
-    result->src[0] = Q;
-    result->src[1] = K;
-    result->src[2] = V;
-
-    return result;
-}
-
-
-void ggml_sparsek_attn_set_params(struct ggml_tensor * a,
-                                  int32_t k_top,
-                                  int32_t win_local,
-                                  int32_t stride_global) {
-    GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN);
-    ggml_set_op_params_i32(a, 0, k_top);
-    ggml_set_op_params_i32(a, 1, win_local);
-    ggml_set_op_params_i32(a, 2, stride_global);
-}
-
-int32_t ggml_sparsek_attn_get_param(const struct ggml_tensor * a, int index) {
-    GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN);
-    return ggml_get_op_params_i32(a, index);
-}
-
-
-
 // ggml_flash_attn_ext

 struct ggml_tensor * ggml_flash_attn_ext(

@@ -5262,6 +5215,40 @@ struct ggml_tensor * ggml_flash_attn_back(
     return result;
 }

+#define GGML_FA_EXT_PARAM_SPARSEK_FLAG   30
+#define GGML_FA_EXT_PARAM_SPARSEK_KTOP   31
+#define GGML_FA_EXT_PARAM_SPARSEK_WIN    32
+#define GGML_FA_EXT_PARAM_SPARSEK_STRIDE 33
+
+void ggml_flash_attn_ext_set_sparsek(struct ggml_tensor * a,
+                                     bool    use_sparsek,
+                                     int32_t k_top,
+                                     int32_t win_local,
+                                     int32_t stride_global) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+    a->op_params[GGML_FA_EXT_PARAM_SPARSEK_FLAG]   = use_sparsek ? 1 : 0;
+    a->op_params[GGML_FA_EXT_PARAM_SPARSEK_KTOP]   = k_top;
+    a->op_params[GGML_FA_EXT_PARAM_SPARSEK_WIN]    = win_local;
+    a->op_params[GGML_FA_EXT_PARAM_SPARSEK_STRIDE] = stride_global;
+}
+
+void ggml_flash_attn_ext_get_sparsek(const struct ggml_tensor * a,
+                                     bool    * use_sparsek,
+                                     int32_t * k_top,
+                                     int32_t * win_local,
+                                     int32_t * stride_global) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+    if (use_sparsek)
+        *use_sparsek   = a->op_params[GGML_FA_EXT_PARAM_SPARSEK_FLAG] != 0;
+    if (k_top)
+        *k_top         = a->op_params[GGML_FA_EXT_PARAM_SPARSEK_KTOP];
+    if (win_local)
+        *win_local     = a->op_params[GGML_FA_EXT_PARAM_SPARSEK_WIN];
+    if (stride_global)
+        *stride_global = a->op_params[GGML_FA_EXT_PARAM_SPARSEK_STRIDE];
+}
+
+
 // ggml_ssm_conv

 struct ggml_tensor * ggml_ssm_conv(
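
A small readback sketch, assuming a GGML_OP_FLASH_ATTN_EXT node named attn built elsewhere; it only shows how the new getter is meant to be called, with NULL allowed for any output that is not needed. The function name is an assumption for illustration.

// Hypothetical readback sketch: query the SparseK settings stored on a
// GGML_OP_FLASH_ATTN_EXT node, e.g. from a backend or a test.
#include "ggml.h"
#include <cstdio>

static void print_sparsek_config(const struct ggml_tensor * attn) {
    bool    use_sparsek   = false;
    int32_t k_top         = 0;
    int32_t win_local     = 0;
    int32_t stride_global = 0;
    ggml_flash_attn_ext_get_sparsek(attn, &use_sparsek, &k_top, &win_local, &stride_global);

    if (use_sparsek) {
        printf("SparseK on: k_top=%d win_local=%d stride_global=%d\n", k_top, win_local, stride_global);
    } else {
        printf("SparseK off: dense flash-attention path\n");
    }
}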

tests/test-backend-ops.cpp

Lines changed: 13 additions & 4 deletions
@@ -5513,12 +5513,21 @@ struct test_sparsek_attn : public test_case {
         ggml_set_name(K, "K");
         ggml_tensor * V = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_v, n_tokens, n_head, batch);
         ggml_set_name(V, "V");
+        // ----------------------------------------------------------------------------
+        // SparseK Attention test (integrated via FlashAttention extension)
+        // ----------------------------------------------------------------------------
+        ggml_tensor * mask = NULL;
+        float scale    = 1.0f;
+        float max_bias = 0.0f;
+        float bias     = 0.0f;
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, Q, K, V, mask, scale, max_bias, bias);
+        ggml_flash_attn_ext_set_sparsek(out, true, k_top, win_local, stride_global);
+
+        ggml_set_name(out, "FLASH_ATTN_EXT_with_SPARSEK");
+        return out;

-        ggml_tensor * out = ggml_sparsek_attn(ctx, Q, K, V, k_top, win_local, stride_global);
-        ggml_set_name(out, "SPARSEK_ATTN_out");

-        return out;
-    }
+    }
 };