diff --git a/docs/ops/CANN.csv b/docs/ops/CANN.csv
index 0ac1078304a..8170c9755a7 100644
--- a/docs/ops/CANN.csv
+++ b/docs/ops/CANN.csv
@@ -3292,10 +3292,10 @@
 "CANN0","RWKV_WKV7","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","0","no","CANN"
 "CANN0","RWKV_WKV7","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","0","no","CANN"
 "CANN0","RWKV_WKV7","type=f32,head_count=32,head_size=64,n_seq_tokens=128,n_seqs=4","support","0","no","CANN"
-"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=1,n_seqs=1","support","0","no","CANN"
-"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","0","no","CANN"
-"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","0","no","CANN"
-"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=128,n_seqs=4","support","0","no","CANN"
+"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=1,n_seqs=1","support","1","yes","CANN"
+"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","1","yes","CANN"
+"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","1","yes","CANN"
+"CANN0","GATED_LINEAR_ATTN","type=f32,head_count=32,head_size=64,n_seq_tokens=128,n_seqs=4","support","1","yes","CANN"
 "CANN0","MUL_MAT","type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0","support","1","yes","CANN"
 "CANN0","MUL_MAT","type_a=f32,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0","support","1","yes","CANN"
 "CANN0","MUL_MAT","type_a=f32,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0","support","1","yes","CANN"
diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index bc33b99d96e..a33d6507a4b 100755
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -38,6 +38,7 @@
 #include
 #include
 #include
+#include <aclnnop/aclnn_mv.h>
 #include
 #include
 #include
@@ -439,6 +440,115 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_cann_release_resources(ctx, norm, acl_src, acl_dst);
 }
 
+void ggml_cann_gated_linear_attn(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    // Fetch the input tensors.
+    ggml_tensor * k = dst->src[0];
+    ggml_tensor * v = dst->src[1];
+    ggml_tensor * q = dst->src[2];
+    ggml_tensor * g = dst->src[3];
+    ggml_tensor * s = dst->src[4];
+
+    // Derive the dimension parameters.
+    int64_t B = dst->src[4]->ne[1]; // batch size
+    int64_t T = dst->src[0]->ne[2]; // total sequence length
+    int64_t H = dst->src[0]->ne[1]; // number of heads
+    int64_t C = dst->ne[0];         // total channels
+    int64_t D = C / H;              // dimensionality per head
+    int64_t L = T / B;              // sequence length per batch
+
+    // Set up tensor shapes (ggml order: ne[0] is the innermost dimension).
+    int64_t ne_qkg[2] = {1, D}; // shape of k/g: [1, D]
+    int64_t ne_s[2] = {D, D};   // shape of the state tensor: [D, D]
+    int64_t ne_vo[2] = {D, 1};  // shape of v: [D, 1]
+    int64_t ne_q[1] = {D};      // shape of q/o: [D]
+
+    // Compute strides (memory layout).
+    size_t nb_base = ggml_type_size(k->type);
+    size_t nb_qkg[2] = {nb_base, nb_base};
+    size_t nb_s[2] = {nb_base, D * nb_base};
+    size_t nb_vo[2] = {nb_base, D * nb_base};
+    size_t nb_q[1] = {nb_base};
+
+    // Read the scale factor from the op parameters.
+    float scale;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    // Pre-allocate scratch buffers to avoid repeated allocation inside the loop (performance optimization 1).
+    size_t buf_size = D * D * sizeof(float);
+    ggml_cann_pool_alloc state_buf1(ctx.pool(), buf_size);
+    void* buf1_ptr = state_buf1.get();
+    ggml_cann_pool_alloc state_buf2(ctx.pool(), buf_size);
+    void* buf2_ptr = state_buf2.get();
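+    // The two scratch buffers are recycled across all batches, heads, and
+    // timesteps: buf1 holds the k*v outer product and later the transposed
+    // state, while buf2 holds the broadcast copies of v and g.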
+
+    // Create reusable buffer tensors (performance optimization 2).
+    aclTensor* acl_buf_k = ggml_cann_create_tensor(buf1_ptr, ggml_cann_type_mapping(k->type),
+                                                   ggml_type_size(k->type), ne_s, nb_s, 2);
+    aclTensor* acl_buf_v = ggml_cann_create_tensor(buf2_ptr, ggml_cann_type_mapping(k->type),
+                                                   ggml_type_size(k->type), ne_s, nb_s, 2);
+
+    // Pre-create the repeat-pattern arrays (performance optimization 3).
+    int64_t k_rep[2] = {1, D}; // broadcast pattern for k/g: [1, D] -> [D, D]
+    int64_t v_rep[2] = {D, 1}; // broadcast pattern for v:   [D, 1] -> [D, D]
+    aclIntArray* acl_k_rep = aclCreateIntArray(k_rep, 2);
+    aclIntArray* acl_v_rep = aclCreateIntArray(v_rep, 2);
+
+    // Permutation for transposing the state.
+    int64_t newdim[2] = {1, 0}; // [D, D] -> [D, D] (transposed)
+
+    // Iterate over batches, heads, and timesteps.
+    for (int64_t b = 0; b < B; b++) {
+        for (int64_t h = 0; h < H; h++) {
+            // Compute the offset into the state tensor.
+            size_t s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;
+
+            // Create the state tensors.
+            aclTensor* acl_s = ggml_cann_create_tensor(s, ne_s, nb_s, 2, ACL_FORMAT_ND, s_offset);
+            aclTensor* acl_s_new = ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND,
+                                                           (B * L * H * D) * nb_base + s_offset);
+
+            // Copy the initial state into dst; it is updated in place below.
+            cann_copy(ctx, acl_s, acl_s_new);
+
+            // Walk the timesteps, updating the state and computing the output.
+            for (int64_t l = 0; l < L; l++) {
+                // Compute the q/k/v/g/o offset for the current timestep.
+                size_t qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base;
+
+                // Create the tensor views needed for this timestep.
+                aclTensor* acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
+                aclTensor* acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
+                aclTensor* acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
+                aclTensor* acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
+                aclTensor* acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
+
+                // 1. Compute the outer product k*v.
+                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_k, acl_k_rep, acl_buf_k); // broadcast k to [D, D]
+                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_v, acl_v_rep, acl_buf_v); // broadcast v to [D, D]
+                aclnn_mul(ctx, acl_buf_k, acl_buf_v, nullptr); // element-wise multiply: k*v
+
+                // 2. Apply the gate and update the state.
+                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_g, acl_k_rep, acl_buf_v); // broadcast g to [D, D]
+                aclnn_mul(ctx, acl_s_new, acl_buf_v, nullptr); // gate:   s = s * g
+                aclnn_add(ctx, acl_s_new, acl_buf_k, nullptr); // update: s = s + k*v
+
+                // 3. Compute the output.
+                aclnn_permute(ctx, acl_s_new, acl_buf_k, newdim, 2); // transpose the state matrix
+                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_buf_k, acl_q, acl_o, 1); // matrix-vector product: o = s^T * q
+                aclnn_muls(ctx, acl_o, scale, nullptr, true); // apply the scale factor
+
+                // Release the per-timestep temporaries.
+                ggml_cann_release_resources(ctx, acl_q, acl_k, acl_v, acl_o, acl_g);
+            }
+
+            // Release the state tensors.
+            ggml_cann_release_resources(ctx, acl_s, acl_s_new);
+        }
+    }
+
+    // Release the pre-allocated resources.
+    ggml_cann_release_resources(ctx, acl_buf_k, acl_buf_v, acl_k_rep, acl_v_rep);
+}
+
 void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index 5c510cc9932..6cc830b0eea 100755
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -187,6 +187,18 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief Computes Gated Linear Attention for a ggml tensor using the CANN
+ * backend.
+ *
+ * @details For each batch and head, the recurrent state is gated
+ * element-wise, accumulated with the outer product of the current key and
+ * value, and then queried to produce the scaled output for each timestep.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor receiving the attention outputs followed
+ * by the updated states.
+ * @attention dst->src must hold k, v, q, g and the initial state s, in that
+ * order.
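+ *
+ * @note Summarizing the loop in aclnn_ops.cpp: per timestep,
+ * S = g (*) S + outer(k, v), then o = scale * S^T q, where (*) is the
+ * element-wise product.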
+ */
+void ggml_cann_gated_linear_attn(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /**
  * @brief Computes the Group Normalization for a ggml tensor using the CANN
  * backend.
@@ -605,6 +617,10 @@ void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
 void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                aclTensor* acl_dst);
 
+static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst);
+static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                          aclTensor* acl_dst, int64_t* new_dim, uint64_t dims);
+
 /**
  * @brief Prepares broadcast-compatible ACL tensors for two input tensors and one
  * output tensor.
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index cb8af42ebf9..016fbe44ca1 100755
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1881,6 +1881,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_FLASH_ATTN_EXT:
             ggml_cann_flash_attn_ext(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cann_gated_linear_attn(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2493,6 +2496,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_MEAN:
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
+        case GGML_OP_GATED_LINEAR_ATTN:
            return true;
         case GGML_OP_SCALE:
            float bias;
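For reviewers, a minimal scalar sketch of the recurrence that ggml_cann_gated_linear_attn implements, assuming the flattened [B, L, H, D] layout implied by the kernel's offset arithmetic; gla_reference and its signature are illustrative, not part of the patch:

#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical scalar reference (not part of the patch): computes the same
// result as the CANN kernel on plain float arrays. k, v, q, g are flattened
// [B, L, H, D]; s holds B*H initial D x D states; dst receives the B*L*H*D
// outputs followed by the B*H*D*D updated states.
static void gla_reference(const float * k, const float * v, const float * q,
                          const float * g, const float * s, float * dst,
                          float scale, int64_t B, int64_t L, int64_t H, int64_t D) {
    const int64_t out_size = B * L * H * D;
    for (int64_t b = 0; b < B; b++) {
        for (int64_t h = 0; h < H; h++) {
            // working copy of the per-(batch, head) state, indexed S[i][j]
            // with i the key/gate dimension and j the value dimension
            const float * s0 = s + (b * H + h) * D * D;
            std::vector<float> S(s0, s0 + D * D);
            for (int64_t l = 0; l < L; l++) {
                const int64_t off = ((b * L + l) * H + h) * D;
                // state update: S = g (*) S + outer(k, v)
                for (int64_t i = 0; i < D; i++) {
                    for (int64_t j = 0; j < D; j++) {
                        S[i * D + j] = g[off + i] * S[i * D + j] + k[off + i] * v[off + j];
                    }
                }
                // output: o = scale * S^T q
                for (int64_t j = 0; j < D; j++) {
                    float acc = 0.0f;
                    for (int64_t i = 0; i < D; i++) {
                        acc += q[off + i] * S[i * D + j];
                    }
                    dst[off + j] = scale * acc;
                }
            }
            // the updated state is appended after the outputs, as in the kernel
            std::memcpy(dst + out_size + (b * H + h) * D * D, S.data(), D * D * sizeof(float));
        }
    }
}

Each output depends on the state after its own update, which is why the kernel updates acl_s_new before the Mv call; the final states land after the B*L*H*D outputs in dst, matching the (B * L * H * D) * nb_base offset used above.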