Skip to content

Commit 004f090

Browse files
author
赵禹昇
committed
support gated linear attn
1 parent 74f52f7 commit 004f090

File tree

3 files changed

+108
-0
lines changed

3 files changed

+108
-0
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include <aclnnop/aclnn_matmul.h>
3939
#include <aclnnop/aclnn_max_pool.h>
4040
#include <aclnnop/aclnn_mm.h>
41+
#include <aclnnop/aclnn_mv.h>
4142
#include <aclnnop/aclnn_permute.h>
4243
#include <aclnnop/aclnn_pow_tensor_tensor.h>
4344
#include <aclnnop/aclnn_reduce_sum.h>
@@ -439,6 +440,93 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
439440
ggml_cann_release_resources(ctx, norm, acl_src, acl_dst);
440441
}
441442

443+
/**
 * @brief Computes gated linear attention (GGML_OP_GATED_LINEAR_ATTN) on the
 *        CANN backend.
 *
 * For each (batch, head) pair a D x D recurrent state is updated token by
 * token as
 *     state = state * g_t + k_t^T * v_t
 *     out_t = (state^T * q_t) * scale
 * where k_t, g_t, q_t are length-D vectors and v_t is a length-D row vector.
 *
 * Source layout (per the ggml GLA op contract):
 *   src[0] = k, src[1] = v, src[2] = q, src[3] = g, src[4] = initial state.
 * dst receives the B*L*H*D output elements followed by the final states.
 *
 * @param ctx The CANN context used for operations.
 * @param dst The destination tensor receiving the attention output and the
 *            updated recurrent state.
 */
void ggml_cann_gated_linear_attn(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor * k = dst->src[0];
    ggml_tensor * v = dst->src[1];
    ggml_tensor * q = dst->src[2];
    ggml_tensor * g = dst->src[3];
    ggml_tensor * s = dst->src[4];

    int64_t B = s->ne[1];    // batch size
    int64_t T = k->ne[2];    // total tokens across the whole batch
    int64_t H = k->ne[1];    // number of heads
    int64_t C = dst->ne[0];  // channels = H * D
    int64_t D = C / H;       // head dimension
    int64_t L = T / B;       // tokens per batch entry

    // View shapes/strides for per-token slices: k/g as D x 1 columns,
    // v as a 1 x D row, the state as a D x D matrix, q/out as flat vectors.
    int64_t ne_qkg[2] = {1, D};
    int64_t ne_s[2]   = {D, D};
    int64_t ne_vo[2]  = {D, 1};
    int64_t ne_q[1]   = {D};
    size_t nb_base   = ggml_type_size(k->type);
    size_t nb_qkg[2] = {nb_base, nb_base};
    size_t nb_s[2]   = {nb_base, D * nb_base};
    size_t nb_vo[2]  = {nb_base, D * nb_base};
    size_t nb_q[1]   = {nb_base};

    float scale;
    memcpy(&scale, dst->op_params, sizeof(float));

    // Loop-invariant scratch, hoisted out of the per-token loop: two D x D
    // buffers reused by every iteration, and the repeat shapes that broadcast
    // k/g across columns and v across rows. The buffers are sized with
    // nb_base (not a hard-coded sizeof(float)) so they match k's element type.
    size_t buf_size = D * D * nb_base;
    ggml_cann_pool_alloc state_buf1(ctx.pool(), buf_size);
    ggml_cann_pool_alloc state_buf2(ctx.pool(), buf_size);
    void* buf1_ptr = state_buf1.get();
    void* buf2_ptr = state_buf2.get();

    int64_t k_rep[2] = {1, D};
    int64_t v_rep[2] = {D, 1};
    aclIntArray* acl_k_rep = aclCreateIntArray(k_rep, 2);
    aclIntArray* acl_v_rep = aclCreateIntArray(v_rep, 2);

    for (int64_t b = 0; b < B; b++) {
        for (int64_t h = 0; h < H; h++) {
            size_t s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;
            // D x D state views: the provided initial state, and the running
            // state stored in dst after the B*L*H*D output elements.
            aclTensor* acl_s = ggml_cann_create_tensor(s, ne_s, nb_s, 2, ACL_FORMAT_ND, s_offset);
            aclTensor* acl_s_new = ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
            cann_copy(ctx, acl_s, acl_s_new);
            for (int64_t l = 0; l < L; l++) {
                size_t qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base;
                // D x 1 column views of k and g.
                aclTensor* acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
                aclTensor* acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
                // Flat length-D view of q.
                aclTensor* acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
                // 1 x D row view of v.
                aclTensor* acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
                // Flat length-D output view.
                aclTensor* acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);

                // Broadcast k across columns and v across rows so that the
                // outer product k^T * v becomes an element-wise multiply.
                aclTensor* acl_buf_k = ggml_cann_create_tensor(buf1_ptr, ggml_cann_type_mapping(k->type), ggml_type_size(k->type), ne_s, nb_s, 2);
                aclTensor* acl_buf_v = ggml_cann_create_tensor(buf2_ptr, ggml_cann_type_mapping(k->type), ggml_type_size(k->type), ne_s, nb_s, 2);
                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_k, acl_k_rep, acl_buf_k);
                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_v, acl_v_rep, acl_buf_v);
                // kv outer product, computed in place into acl_buf_k.
                aclnn_mul(ctx, acl_buf_k, acl_buf_v, nullptr);

                // Decay the state by g (acl_buf_v is reused to hold the
                // broadcast gate), then accumulate the kv contribution.
                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_g, acl_k_rep, acl_buf_v);
                aclnn_mul(ctx, acl_s_new, acl_buf_v, nullptr);
                aclnn_add(ctx, acl_s_new, acl_buf_k, nullptr);

                // out = scale * (state^T * q); the transpose is materialized
                // into acl_buf_k so Mv sees the layout it expects.
                int64_t newdim[2] = {1, 0};
                aclnn_permute(ctx, acl_s_new, acl_buf_k, newdim, 2);
                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_buf_k, acl_q, acl_o, 1);
                aclnn_muls(ctx, acl_o, scale, nullptr, true);

                ggml_cann_release_resources(ctx, acl_q, acl_k, acl_v, acl_o, acl_g, acl_buf_k, acl_buf_v);
            }
            ggml_cann_release_resources(ctx, acl_s, acl_s_new);
        }
    }
    ggml_cann_release_resources(ctx, acl_k_rep, acl_v_rep);
}
529+
442530
void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
443531
ggml_tensor* src = dst->src[0];
444532

ggml/src/ggml-cann/aclnn_ops.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,18 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
187187
*/
188188
void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
189189

190+
/**
 * @brief Computes Gated Linear Attention for a ggml tensor using the CANN
 *        backend.
 *
 * @details For each (batch, head) pair a D x D recurrent state is updated
 *          token by token with gated key/value outer products, producing one
 *          output vector per token; the final states are appended to the
 *          destination tensor after the output elements.
 *
 * @param ctx The CANN context used for operations.
 * @param dst The destination tensor where the attention output and the
 *            updated recurrent state will be stored. Operands are taken from
 *            dst->src: k, v, q, g, and the initial state.
 */
200+
void ggml_cann_gated_linear_attn(ggml_backend_cann_context& ctx, ggml_tensor* dst);
201+
190202
/**
191203
* @brief Computes the Group Normalization for a ggml tensor using the CANN
192204
* backend.
@@ -605,6 +617,10 @@ void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
605617
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
606618
aclTensor* acl_dst);
607619

620+
static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst);
621+
static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
622+
aclTensor* acl_dst, int64_t* new_dim, uint64_t dims);
623+
608624
/**
609625
* @brief Prepares broadcast-compatible ACL tensors for two input tensors and one
610626
* output tensor.

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1881,6 +1881,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
18811881
case GGML_OP_FLASH_ATTN_EXT:
18821882
ggml_cann_flash_attn_ext(ctx, dst);
18831883
break;
1884+
case GGML_OP_GATED_LINEAR_ATTN:
1885+
ggml_cann_gated_linear_attn(ctx, dst);
1886+
break;
18841887
default:
18851888
return false;
18861889
}
@@ -2493,6 +2496,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
24932496
case GGML_OP_MEAN:
24942497
case GGML_OP_PAD_REFLECT_1D:
24952498
case GGML_OP_COUNT_EQUAL:
2499+
case GGML_OP_GATED_LINEAR_ATTN:
24962500
return true;
24972501
case GGML_OP_SCALE:
24982502
float bias;

0 commit comments

Comments
 (0)