Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/ops/CANN.csv
Original file line number Diff line number Diff line change
Expand Up @@ -3292,10 +3292,10 @@
"CANN0","RWKV_WKV7","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","0","no","CANN"
"CANN0","RWKV_WKV7","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","0","no","CANN"
"CANN0","RWKV_WKV7","type=f32,head_count=32,head_size=64,n_seq_tokens=128,n_seqs=4","support","0","no","CANN"
"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=1,n_seqs=1","support","0","no","CANN"
"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","0","no","CANN"
"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","0","no","CANN"
"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=128,n_seqs=4","support","0","no","CANN"
"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=1,n_seqs=1","support","1","yes","CANN"
"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","1","yes","CANN"
"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","1","yes","CANN"
"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=128,n_seqs=4","support","1","yes","CANN"
"CANN0","MUL_MAT","type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0","support","1","yes","CANN"
"CANN0","MUL_MAT","type_a=f32,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0","support","1","yes","CANN"
"CANN0","MUL_MAT","type_a=f32,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0","support","1","yes","CANN"
Expand Down
110 changes: 110 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <aclnnop/aclnn_matmul.h>
#include <aclnnop/aclnn_max_pool.h>
#include <aclnnop/aclnn_mm.h>
#include <aclnnop/aclnn_mv.h>
#include <aclnnop/aclnn_permute.h>
#include <aclnnop/aclnn_pow_tensor_tensor.h>
#include <aclnnop/aclnn_reduce_sum.h>
Expand Down Expand Up @@ -439,6 +440,115 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_cann_release_resources(ctx, norm, acl_src, acl_dst);
}

/**
 * Gated linear attention (GGML_OP_GATED_LINEAR_ATTN) on the CANN backend.
 *
 * Operands, taken from dst->src:
 *   src[0] k : per-token keys
 *   src[1] v : per-token values
 *   src[2] q : per-token queries
 *   src[3] g : per-token gates
 *   src[4] s : initial recurrent state, one [D,D] matrix per (batch, head)
 *
 * For each batch b and head h the state S is updated over the L time steps:
 *     S   = S * g_t + outer(k_t, v_t)
 *     o_t = scale * (S^T q_t)
 * The per-token outputs occupy the first B*L*H*D elements of dst; the final
 * state of every (batch, head) pair is written directly after that region.
 */
void ggml_cann_gated_linear_attn(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    // Fetch the input tensors.
    ggml_tensor * k = dst->src[0];
    ggml_tensor * v = dst->src[1];
    ggml_tensor * q = dst->src[2];
    ggml_tensor * g = dst->src[3];
    ggml_tensor * s = dst->src[4];

    // Derive the dimension parameters.
    int64_t B = dst->src[4]->ne[1]; // batch size
    int64_t T = dst->src[0]->ne[2]; // total number of tokens across all batches
    int64_t H = dst->src[0]->ne[1]; // number of heads
    int64_t C = dst->ne[0];         // total channels (H * D)
    int64_t D = C / H;              // channels per head
    int64_t L = T / B;              // sequence length per batch

    // Per-slice shapes used to view the flat operands.
    int64_t ne_qkg[2] = {1, D}; // k/g slice: [1,D] row vector
    int64_t ne_s[2] = {D, D};   // state matrix: [D,D]
    int64_t ne_vo[2] = {D, 1};  // v slice: [D,1] column vector
    int64_t ne_q[1] = {D};      // q/o slice: [D] vector

    // Matching strides for a contiguous layout.
    // NOTE(review): every offset below uses k's element size; this assumes
    // k, v, q, g, s and dst all share the same (f32) element size — confirm.
    size_t nb_base = ggml_type_size(k->type);
    size_t nb_qkg[2] = {nb_base, nb_base};
    size_t nb_s[2] = {nb_base, D * nb_base};
    size_t nb_vo[2] = {nb_base, D * nb_base};
    size_t nb_q[1] = {nb_base};

    // Scale factor applied to each output vector, stored in op_params.
    float scale;
    memcpy(&scale, dst->op_params, sizeof(float));

    // Pre-allocate two [D,D] scratch buffers once instead of re-allocating
    // inside the loops (performance optimization 1).
    size_t buf_size = D * D * sizeof(float);
    ggml_cann_pool_alloc state_buf1(ctx.pool(), buf_size);
    void* buf1_ptr = state_buf1.get();
    ggml_cann_pool_alloc state_buf2(ctx.pool(), buf_size);
    void* buf2_ptr = state_buf2.get();

    // Reusable ACL views over the scratch buffers (performance optimization 2).
    aclTensor* acl_buf_k = ggml_cann_create_tensor(buf1_ptr, ggml_cann_type_mapping(k->type),
                                                   ggml_type_size(k->type), ne_s, nb_s, 2);
    aclTensor* acl_buf_v = ggml_cann_create_tensor(buf2_ptr, ggml_cann_type_mapping(k->type),
                                                   ggml_type_size(k->type), ne_s, nb_s, 2);

    // Broadcast patterns, created once and reused (performance optimization 3).
    int64_t k_rep[2] = {1, D}; // k/g: [1,D] -> [D,D] (repeat rows)
    int64_t v_rep[2] = {D, 1}; // v:   [D,1] -> [D,D] (repeat columns)
    aclIntArray* acl_k_rep = aclCreateIntArray(k_rep, 2);
    aclIntArray* acl_v_rep = aclCreateIntArray(v_rep, 2);

    // Permutation used to transpose the [D,D] state matrix.
    int64_t newdim[2] = {1, 0};

    // Iterate over batches, heads and time steps.
    for (int64_t b = 0; b < B; b++) {
        for (int64_t h = 0; h < H; h++) {
            // Byte offset of the [D,D] state matrix for this (batch, head).
            size_t s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;

            // acl_s views the initial state in src[4]; acl_s_new views the
            // running state, stored in dst right after the B*L*H*D outputs.
            aclTensor* acl_s = ggml_cann_create_tensor(s, ne_s, nb_s, 2, ACL_FORMAT_ND, s_offset);
            aclTensor* acl_s_new = ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND,
                                                           (B * L * H * D) * nb_base + s_offset);

            // Seed the running state with the initial state.
            cann_copy(ctx, acl_s, acl_s_new);

            // Walk the sequence, updating the state and emitting one output
            // vector per token.
            for (int64_t l = 0; l < L; l++) {
                // Byte offset of this token's q/k/v/g/o vectors.
                size_t qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base;

                // Per-token views over the flat operand buffers.
                aclTensor* acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
                aclTensor* acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
                aclTensor* acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
                aclTensor* acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
                aclTensor* acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);

                // 1. Outer product outer(k_t, v_t) via broadcast + elementwise mul.
                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_k, acl_k_rep, acl_buf_k); // broadcast k to [D,D]
                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_v, acl_v_rep, acl_buf_v); // broadcast v to [D,D]
                aclnn_mul(ctx, acl_buf_k, acl_buf_v, nullptr); // acl_buf_k <- k * v elementwise

                // 2. Gate the state, then add the rank-1 update.
                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_g, acl_k_rep, acl_buf_v); // broadcast g to [D,D]
                aclnn_mul(ctx, acl_s_new, acl_buf_v, nullptr); // S <- S * g
                aclnn_add(ctx, acl_s_new, acl_buf_k, nullptr); // S <- S + k*v

                // 3. Output: o_t = scale * (S^T q_t).
                aclnn_permute(ctx, acl_s_new, acl_buf_k, newdim, 2); // acl_buf_k <- S^T
                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_buf_k, acl_q, acl_o, 1); // o <- S^T * q
                aclnn_muls(ctx, acl_o, scale, nullptr, true); // apply scale in place

                // Drop this token's views.
                ggml_cann_release_resources(ctx, acl_q, acl_k, acl_v, acl_o, acl_g);
            }

            // Drop the state views for this (batch, head).
            ggml_cann_release_resources(ctx, acl_s, acl_s_new);
        }
    }

    // Drop the pre-allocated scratch tensors and repeat patterns.
    ggml_cann_release_resources(ctx, acl_buf_k, acl_buf_v, acl_k_rep, acl_v_rep);
}

void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src = dst->src[0];

Expand Down
16 changes: 16 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,18 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
*/
void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);

/**
 * @brief Computes Gated Linear Attention (GLA) for a ggml tensor using the
 * CANN backend.
 *
 * @details For every batch and head, iterates over the sequence applying the
 * recurrence S = S * g_t + outer(k_t, v_t) followed by o_t = scale * S^T q_t.
 * The operands k, v, q, g and the initial state are taken from dst->src[0..4]
 * and the scale factor from dst->op_params; both the per-token outputs and
 * the final state are written into dst.
 *
 * @param ctx The CANN context used for operations.
 * @param dst The destination tensor where the attention outputs and the final
 *            recurrent state will be stored.
 * @attention Processes one (batch, head, time-step) slice at a time with
 *            aclnn primitives, so long sequences incur many small kernel
 *            launches.
 */
void ggml_cann_gated_linear_attn(ggml_backend_cann_context& ctx, ggml_tensor* dst);

/**
* @brief Computes the Group Normalization for a ggml tensor using the CANN
* backend.
Expand Down Expand Up @@ -605,6 +617,10 @@ void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst);

static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst);
static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst, int64_t* new_dim, uint64_t dims);

/**
* @brief Prepares broadcast-compatible ACL tensors for two input tensors and one
* output tensor.
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1881,6 +1881,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
case GGML_OP_FLASH_ATTN_EXT:
ggml_cann_flash_attn_ext(ctx, dst);
break;
case GGML_OP_GATED_LINEAR_ATTN:
ggml_cann_gated_linear_attn(ctx, dst);
break;
default:
return false;
}
Expand Down Expand Up @@ -2493,6 +2496,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_MEAN:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_GATED_LINEAR_ATTN:
return true;
case GGML_OP_SCALE:
float bias;
Expand Down