12 changes: 11 additions & 1 deletion ggml/include/ggml.h
@@ -529,7 +529,7 @@ extern "C" {
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,

GGML_OP_SPARSEK_ATTN,
GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_ATTN_BACK,
GGML_OP_SSM_CONV,
@@ -2231,6 +2231,16 @@ extern "C" {
// n_head % ne32 == 0
// ne3 % ne33 == 0
//

GGML_API struct ggml_tensor * ggml_sparsek_attn(
struct ggml_context * ctx,
struct ggml_tensor * Q,
struct ggml_tensor * K,
struct ggml_tensor * V,
int32_t k_top,
int32_t win_local,
int32_t stride_global);

GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
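For review context, a minimal sketch (not part of this diff) of how the new public API could be driven from user code. It only builds the graph node; the shapes follow the same [D, n_tokens, n_head, batch] layout as the test case added below, and the concrete sizes and the 32/64/128 parameter values are illustrative.

#include "ggml.h"

// build a SparseK attention node on a caller-provided context (illustrative sketch)
static struct ggml_tensor * build_sparsek_example(struct ggml_context * ctx) {
    const int64_t d_qk = 64, d_v = 64, n_head = 4, n_tokens = 128, n_batch = 1;

    struct ggml_tensor * Q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_qk, n_tokens, n_head, n_batch);
    struct ggml_tensor * K = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_qk, n_tokens, n_head, n_batch);
    struct ggml_tensor * V = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_v,  n_tokens, n_head, n_batch);

    // keep the 32 strongest keys per query; the window/stride parameters are
    // stored on the op but not yet used by the CPU reference kernel
    return ggml_sparsek_attn(ctx, Q, K, V, /*k_top*/ 32, /*win_local*/ 64, /*stride_global*/ 128);
}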
5 changes: 5 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -1952,6 +1952,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_flash_attn_ext(params, tensor);
} break;
case GGML_OP_SPARSEK_ATTN:
{
ggml_compute_forward_sparsek_attn(params, tensor);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
int32_t t = ggml_get_op_params_i32(tensor, 0);
82 changes: 82 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
@@ -7907,6 +7907,88 @@ void ggml_compute_forward_argsort(
}
}

//------------------------------------------------------------------------------
// SparseK Attention (CPU)
//------------------------------------------------------------------------------

static void ggml_compute_forward_sparsek_attn_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    if (params->ith != 0) return; // single-threaded reference path

    const struct ggml_tensor * Q = dst->src[0];
    const struct ggml_tensor * K = dst->src[1];
    const struct ggml_tensor * V = dst->src[2];

    GGML_ASSERT(Q && K && V);
    GGML_ASSERT(Q->type == GGML_TYPE_F32);
    GGML_ASSERT(K->type == GGML_TYPE_F32);
    GGML_ASSERT(V->type == GGML_TYPE_F32);

    const int32_t k_top      = ggml_get_op_params_i32(dst, 0);
    const int32_t win_local  = ggml_get_op_params_i32(dst, 1);
    const int32_t stride_glb = ggml_get_op_params_i32(dst, 2);

    const int64_t DK = Q->ne[0]; // Q/K head dimension
    const int64_t DV = V->ne[0]; // V head dimension
    const int64_t T  = Q->ne[1]; // sequence length

    // NOTE: this reference path treats Q/K/V as 2D [D, T]; the head and batch
    // dimensions are not iterated yet, and win_local/stride_glb are not applied

    const float * q   = (const float *) Q->data;
    const float * k   = (const float *) K->data;
    const float * v   = (const float *) V->data;
    float       * out = (float *) dst->data;

    // scratch for one row of scores plus a copy used for top-k selection;
    // dst holds the [DV, T] output and cannot double as a [T, T] score buffer
    float * scores = new float[2*T];
    float * sorted = scores + T;

    const int64_t kk = k_top < T ? k_top : T;

    for (int64_t i = 0; i < T; ++i) {
        // scaled dot-product scores for query i
        for (int64_t j = 0; j < T; ++j) {
            float dot = 0.0f;
            for (int64_t d = 0; d < DK; ++d) {
                dot += q[i*DK + d] * k[j*DK + d];
            }
            scores[j] = dot / sqrtf((float) DK);
        }

        // keep only the kk largest scores: find the threshold with a partial
        // selection sort on a copy, then mask everything below it
        if (kk > 0 && kk < T) {
            for (int64_t j = 0; j < T; ++j) {
                sorted[j] = scores[j];
            }
            for (int64_t a = 0; a < kk; ++a) {
                int64_t best = a;
                for (int64_t j = a + 1; j < T; ++j) {
                    if (sorted[j] > sorted[best]) best = j;
                }
                const float tmp = sorted[a]; sorted[a] = sorted[best]; sorted[best] = tmp;
            }
            const float thresh = sorted[kk - 1];
            for (int64_t j = 0; j < T; ++j) {
                if (scores[j] < thresh) scores[j] = -INFINITY;
            }
        }

        // softmax over the surviving scores
        float maxv = -INFINITY;
        for (int64_t j = 0; j < T; ++j) {
            if (scores[j] > maxv) maxv = scores[j];
        }
        float sum = 0.0f;
        for (int64_t j = 0; j < T; ++j) {
            scores[j] = expf(scores[j] - maxv);
            sum += scores[j];
        }
        for (int64_t j = 0; j < T; ++j) {
            scores[j] /= sum;
        }

        // weighted sum over V -> output row i
        for (int64_t d = 0; d < DV; ++d) {
            float acc = 0.0f;
            for (int64_t j = 0; j < T; ++j) {
                acc += scores[j] * v[j*DV + d];
            }
            out[i*DV + d] = acc;
        }
    }

    delete[] scores;

    GGML_UNUSED(win_local);
    GGML_UNUSED(stride_glb);

    GGML_PRINT_DEBUG("[SPARSEK CPU] k_top=%d win_local=%d stride=%d\n",
            k_top, win_local, stride_glb);
}

void ggml_compute_forward_sparsek_attn(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
ggml_compute_forward_sparsek_attn_f32(params, dst);
}


// ggml_compute_forward_flash_attn_ext

static void ggml_compute_forward_flash_attn_ext_f16(
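In equation form, the CPU reference kernel above computes, for each query position i (head/batch dimensions and the win_local/stride_global parameters are not yet handled):

\[ \mathrm{out}_i = \operatorname{softmax}\!\left( \operatorname{topk}_{k_{\mathrm{top}}}\!\left( \frac{q_i K^{\top}}{\sqrt{D_K}} \right) \right) V \]

where topk keeps the k_top largest scores in the row and sets the rest to -inf before the softmax.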
2 changes: 2 additions & 0 deletions ggml/src/ggml-cpu/ops.h
@@ -86,6 +86,8 @@ void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params *
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sparsek_attn(const struct ggml_compute_params * params, struct ggml_tensor * dst);

void ggml_compute_forward_flash_attn_back(
const struct ggml_compute_params * params,
const bool masked,
54 changes: 50 additions & 4 deletions ggml/src/ggml.c
@@ -990,7 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"TIMESTEP_EMBEDDING",
"ARGSORT",
"LEAKY_RELU",

"SPARSEK_ATTN",
"FLASH_ATTN_EXT",
"FLASH_ATTN_BACK",
"SSM_CONV",
@@ -1019,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};

static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1094,7 +1094,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
"leaky_relu(x)",

"sparsek_attn(x)",
"flash_attn_ext(x)",
"flash_attn_back(x)",
"ssm_conv(x)",
@@ -1123,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};

static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -5063,6 +5063,52 @@ struct ggml_tensor * ggml_top_k(
return result;
}

// ggml_sparsek_attn
struct ggml_tensor * ggml_sparsek_attn(
struct ggml_context * ctx,
struct ggml_tensor * Q,
struct ggml_tensor * K,
struct ggml_tensor * V,
int32_t k_top,
int32_t win_local,
int32_t stride_global) {

GGML_ASSERT(ggml_can_mul_mat(K, Q));
GGML_ASSERT(Q->ne[3] == K->ne[3] && Q->ne[3] == V->ne[3]);

int64_t ne[4] = { V->ne[0], Q->ne[2], Q->ne[1], Q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);


int32_t params_i32[3] = { k_top, win_local, stride_global };
ggml_set_op_params(result, params_i32, sizeof(params_i32));

result->op = GGML_OP_SPARSEK_ATTN;
result->src[0] = Q;
result->src[1] = K;
result->src[2] = V;

return result;
}


void ggml_sparsek_attn_set_params(struct ggml_tensor * a,
int32_t k_top,
int32_t win_local,
int32_t stride_global) {
GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN);
ggml_set_op_params_i32(a, 0, k_top);
ggml_set_op_params_i32(a, 1, win_local);
ggml_set_op_params_i32(a, 2, stride_global);
}

int32_t ggml_sparsek_attn_get_param(const struct ggml_tensor * a, int index) {
GGML_ASSERT(a->op == GGML_OP_SPARSEK_ATTN);
return ggml_get_op_params_i32(a, index);
}



// ggml_flash_attn_ext

struct ggml_tensor * ggml_flash_attn_ext(
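A short, hypothetical usage snippet for the setter/getter defined above. It assumes the two helpers are also declared in ggml.h (this diff only exports ggml_sparsek_attn there); the node and parameter values are illustrative.

struct ggml_tensor * node = ggml_sparsek_attn(ctx, Q, K, V, /*k_top*/ 32, /*win_local*/ 64, /*stride_global*/ 128);

// retune the sparsity parameters in place and read one of them back
ggml_sparsek_attn_set_params(node, /*k_top*/ 16, /*win_local*/ 32, /*stride_global*/ 64);
const int32_t k_top = ggml_sparsek_attn_get_param(node, 0); // -> 16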
61 changes: 59 additions & 2 deletions tests/test-backend-ops.cpp
@@ -1778,6 +1778,7 @@ struct test_example : public test_case {
};



// GGML_OP_UNARY
struct test_unary : public test_case {
const ggml_unary_op op;
@@ -5362,7 +5363,46 @@ struct test_leaky_relu : public test_case {
}
};

// GGML_OP_FLASH_ATTN_EXT
// GGML_OP_SPARSEK_ATTN
struct test_sparsek_attn : public test_case {
const int64_t d_qk;
const int64_t d_v;
const int64_t n_head;
const int64_t n_tokens;
const int64_t batch;
const int32_t k_top;
const int32_t win_local;
const int32_t stride_global;

std::string vars() override {
return VARS_TO_STR8(d_qk, d_v, n_head, n_tokens, batch, k_top, win_local, stride_global);
}

test_sparsek_attn(int64_t d_qk = 128, int64_t d_v = 128, int64_t n_head = 8,
int64_t n_tokens = 256, int64_t batch = 4,
int32_t k_top = 32, int32_t win_local = 64, int32_t stride_global = 128)
: d_qk(d_qk), d_v(d_v), n_head(n_head), n_tokens(n_tokens), batch(batch),
k_top(k_top), win_local(win_local), stride_global(stride_global) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t n_q = n_tokens;
ggml_tensor * Q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_qk, n_q, n_head, batch);
ggml_set_name(Q, "Q");
ggml_tensor * K = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_qk, n_tokens, n_head, batch);
ggml_set_name(K, "K");
ggml_tensor * V = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_v, n_tokens, n_head, batch);
ggml_set_name(V, "V");

ggml_tensor * out = ggml_sparsek_attn(ctx, Q, K, V, k_top, win_local, stride_global);
ggml_set_name(out, "SPARSEK_ATTN_out");

return out;
}
};



// GGML_OP_FLASH_ATTN_EXT
struct test_flash_attn_ext : public test_case {
const int64_t hsk; // K head size
const int64_t hsv; // V head size
@@ -7095,7 +7135,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA

for (bool mask : { true, false } ) {
for (bool sinks : { true, false } ) {
for (float max_bias : { 0.0f, 8.0f }) {
Expand Down Expand Up @@ -7134,6 +7174,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
}
// ---- SPARSEK_ATTN --------------------------------------------------
for (int64_t d_qk : {64, 128}) {
for (int64_t d_v : {64, 128}) {
for (int64_t n_head : {4, 8}) {
for (int64_t kv : {113, 512}) {
for (int64_t b : {1, 4}) {
for (int32_t k_top : {16, 32}) {
for (int32_t win_local : {32, 64}) {
test_cases.emplace_back(new test_sparsek_attn(
d_qk, d_v, n_head, kv, b, k_top, win_local, /*stride_global*/128));
}
}
}
}
}
}
}

test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, { 10, 5, 4, 3}));
test_cases.emplace_back(new test_cross_entropy_loss (GGML_TYPE_F32, {30000, 1, 1, 1}));