diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index d948b00cc..c47c5404c 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -529,7 +529,7 @@ extern "C" { GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, - + GGML_OP_SPARSEK_ATTN, GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_ATTN_BACK, GGML_OP_SSM_CONV, @@ -2231,6 +2231,16 @@ extern "C" { // n_head % ne32 == 0 // ne3 % ne33 == 0 // + + GGML_API struct ggml_tensor * ggml_sparsek_attn( + struct ggml_context * ctx, + struct ggml_tensor * Q, + struct ggml_tensor * K, + struct ggml_tensor * V, + int32_t k_top, + int32_t win_local, + int32_t stride_global); + GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 9ec485cfa..b43a2b437 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1952,6 +1952,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_flash_attn_ext(params, tensor); } break; + case GGML_OP_SPARSEK_ATTN: + { + ggml_compute_forward_sparsek_attn(params, tensor); + break; + } case GGML_OP_FLASH_ATTN_BACK: { int32_t t = ggml_get_op_params_i32(tensor, 0); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3156bd601..5bc0cb3e2 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7907,6 +7907,88 @@ void ggml_compute_forward_argsort( } } +//------------------------------------------------------------------------------ +// SparseK Attention (CPU) +//------------------------------------------------------------------------------ + +static void ggml_compute_forward_sparsek_attn_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + if (params->ith != 0) return; // main thread only + + const struct ggml_tensor * Q = dst->src[0]; + const struct ggml_tensor * K = dst->src[1]; + const struct ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(Q && K && V); + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F32); + GGML_ASSERT(V->type == GGML_TYPE_F32); + + const int32_t k_top = ggml_get_op_params_i32(dst, 0); + const int32_t win_local = ggml_get_op_params_i32(dst, 1); + const int32_t stride_glb = ggml_get_op_params_i32(dst, 2); + + const int64_t D = Q->ne[0]; // embedding dim + const int64_t T = Q->ne[1]; // sequence length + + const float * q = (const float *) Q->data; + const float * k = (const float *) K->data; + const float * v = (const float *) V->data; + float * out = (float *) dst->data; + + + for (int64_t i = 0; i < T; ++i) { + for (int64_t j = 0; j < T; ++j) { + float dot = 0.0f; + for (int64_t d = 0; d < D; ++d) + dot += q[i*D + d] * k[j*D + d]; + out[i*T + j] = dot / sqrtf((float) D); + } + } + + for (int64_t i = 0; i < T; ++i) { + float * row = &out[i*T]; + for (int64_t j = 0; j < T; ++j) + if (row[j] < row[k_top]) row[j] = -INFINITY; + } + + for (int64_t i = 0; i < T; ++i) { + float maxv = -INFINITY; + for (int64_t j = 0; j < T; ++j) + if (out[i*T + j] > maxv) maxv = out[i*T + j]; + float sum = 0.0f; + for (int64_t j = 0; j < T; ++j) { + out[i*T + j] = expf(out[i*T + j] - maxv); + sum += out[i*T + j]; + } + for (int64_t j = 0; j < T; ++j) + out[i*T + j] /= sum; + } + + + float * result = (float *) dst->data; + for (int64_t i = 0; i < T; ++i) { + for (int64_t d = 0; d < D; ++d) { + float sum = 0.0f; + for (int64_t j = 0; j < T; ++j) + sum += out[i*T + j] * v[j*D + d]; + result[i*D + d] = sum; + } + } + + GGML_PRINT_DEBUG("[SPARSEK CPU] k_top=%d win_local=%d stride=%d\n", + k_top, win_local, stride_glb); +} + +void ggml_compute_forward_sparsek_attn( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + ggml_compute_forward_sparsek_attn_f32(params, dst); +} + + // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 9824a03b4..e43b23a55 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -86,6 +86,8 @@ void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_sparsek_attn(const struct ggml_compute_params * params, struct ggml_tensor * dst); + void ggml_compute_forward_flash_attn_back( const struct ggml_compute_params * params, const bool masked, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9be35c1be..9ad055c99 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -990,7 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TIMESTEP_EMBEDDING", "ARGSORT", "LEAKY_RELU", - + "SPARSEK_ATTN", "FLASH_ATTN_EXT", "FLASH_ATTN_BACK", "SSM_CONV", @@ -1019,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1094,7 +1094,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", "leaky_relu(x)", - + "sparsek_attn(x)", "flash_attn_ext(x)", "flash_attn_back(x)", "ssm_conv(x)", @@ -1123,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); +static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5063,6 +5063,52 @@ struct ggml_tensor * ggml_top_k( return result; } +// ggml_sparsek_attn +struct ggml_tensor * ggml_sparsek_attn( + struct ggml_context * ctx, + struct ggml_tensor * Q, + struct ggml_tensor * K, + struct ggml_tensor * V, + int32_t k_top, + int32_t win_local, + int32_t stride_global) { + + GGML_ASSERT(ggml_can_mul_mat(K, Q)); + GGML_ASSERT(Q->ne[3] == K->ne[3] && Q->ne[3] == V->ne[3]); + + int64_t ne[4] = { V->ne[0], Q->ne[2], Q->ne[1], Q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + + int32_t params_i32[3] = { k_top, win_local, stride_global }; + ggml_set_op_params(result, params_i32, sizeof(params_i32)); + + result->op = GGML_OP_SPARSEK_ATTN; + result->src[0] = Q; + result->src[1] = K; + result->src[2] = V; + + return result; +} + + +void ggml_sparsek_attn_set_params(struct ggml_tensor * a, + int32_t k_top, + int32_t win_local, + int32_t stride_global) { + GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN); + ggml_set_op_params_i32(a, 0, k_top); + ggml_set_op_params_i32(a, 1, win_local); + ggml_set_op_params_i32(a, 2, stride_global); +} + +int32_t ggml_sparsek_attn_get_param(const struct ggml_tensor * a, int index) { + GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN); + return ggml_get_op_params_i32(a, index); +} + + + // ggml_flash_attn_ext struct ggml_tensor * ggml_flash_attn_ext( diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index aee173013..e899bb8c5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1778,6 +1778,7 @@ struct test_example : public test_case { }; + // GGML_OP_UNARY struct test_unary : public test_case { const ggml_unary_op op; @@ -5362,7 +5363,46 @@ struct test_leaky_relu : public test_case { } }; -// GGML_OP_FLASH_ATTN_EXT +// GGML_OP_SPARSEK_ATTN +struct test_sparsek_attn : public test_case { + const int64_t d_qk; + const int64_t d_v; + const int64_t n_head; + const int64_t n_tokens; + const int64_t batch; + const int32_t k_top; + const int32_t win_local; + const int32_t stride_global; + + std::string vars() override { + return VARS_TO_STR9(d_qk, d_v, n_head, n_tokens, batch, k_top, win_local, stride_global, 0); + } + + test_sparsek_attn(int64_t d_qk = 128, int64_t d_v = 128, int64_t n_head = 8, + int64_t n_tokens = 256, int64_t batch = 4, + int32_t k_top = 32, int32_t win_local = 64, int32_t stride_global = 128) + : d_qk(d_qk), d_v(d_v), n_head(n_head), n_tokens(n_tokens), batch(batch), + k_top(k_top), win_local(win_local), stride_global(stride_global) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + const int64_t n_q = n_tokens; + ggml_tensor * Q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_qk, n_q, n_head, batch); + ggml_set_name(Q, "Q"); + ggml_tensor * K = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_qk, n_tokens, n_head, batch); + ggml_set_name(K, "K"); + ggml_tensor * V = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_v, n_tokens, n_head, batch); + ggml_set_name(V, "V"); + + ggml_tensor * out = ggml_sparsek_attn(ctx, Q, K, V, k_top, win_local, stride_global); + ggml_set_name(out, "SPARSEK_ATTN_out"); + + return out; + } +}; + + + +// GGML_OP_FLAsH_ATTN_EXT struct test_flash_attn_ext : public test_case { const int64_t hsk; // K head size const int64_t hsv; // V head size @@ -7095,7 +7135,7 @@ static std::vector> make_test_cases_eval() { if (hsk != 192 && hsk != 576 && hsk != hsv) continue; if (hsk == 192 && (hsv != 128 && hsv != 192)) continue; if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA - + for (bool mask : { true, false } ) { for (bool sinks : { true, false } ) { for (float max_bias : { 0.0f, 8.0f }) { @@ -7134,6 +7174,23 @@ static std::vector> make_test_cases_eval() { } } } + // ---- SPARSEK_ATTN -------------------------------------------------- + for (int64_t d_qk : {64, 128}) { + for (int64_t d_v : {64, 128}) { + for (int64_t n_head : {4, 8}) { + for (int64_t kv : {113, 512}) { + for (int64_t b : {1, 4}) { + for (int32_t k_top : {16, 32}) { + for (int32_t win_local : {32, 64}) { + test_cases.emplace_back(new test_sparsek_attn( + d_qk, d_v, n_head, kv, b, k_top, win_local, /*stride_global*/128)); + } + } + } + } + } + } + } test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10, 5, 4, 3})); test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000, 1, 1, 1}));