 #include <aclnnop/aclnn_matmul.h>
 #include <aclnnop/aclnn_max_pool.h>
 #include <aclnnop/aclnn_mm.h>
+#include <aclnnop/aclnn_mv.h>
 #include <aclnnop/aclnn_permute.h>
 #include <aclnnop/aclnn_pow_tensor_tensor.h>
 #include <aclnnop/aclnn_reduce_sum.h>
@@ -439,6 +440,93 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { |
     ggml_cann_release_resources(ctx, norm, acl_src, acl_dst);
 }
 
+void ggml_cann_gated_linear_attn(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor * k = dst->src[0];
+    ggml_tensor * v = dst->src[1];
+    ggml_tensor * q = dst->src[2];
+    ggml_tensor * g = dst->src[3];
+    ggml_tensor * s = dst->src[4];
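+    // k/v/q/g hold the per-token keys, values, queries and gates;
+    // s is the recurrent D x D state per (batch, head)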
+
+    int64_t B = s->ne[1];
+    int64_t T = k->ne[2];
+    int64_t H = k->ne[1];
+    int64_t C = dst->ne[0];
+    int64_t D = C / H;
+    int64_t L = T / B;
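+    // B: batch size, T: tokens across the whole batch, H: attention heads,
+    // C: output channels, D: head size, L: tokens per sequence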
+
+    // shapes and element strides for the 2-D/1-D views created below
+    int64_t ne_qkg[2] = {1, D};
+    int64_t ne_s[2] = {D, D};
+    int64_t ne_vo[2] = {D, 1};
+    int64_t ne_q[1] = {D};
+    size_t nb_base = ggml_type_size(k->type);
+    size_t nb_qkg[2] = {nb_base, nb_base};
+    size_t nb_s[2] = {nb_base, D * nb_base};
+    size_t nb_vo[2] = {nb_base, D * nb_base};
+    size_t nb_q[1] = {nb_base};
+
+    // output scale factor, packed into op_params
+    float scale;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
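+    // process each (batch, head) pair independently, carrying its own state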
+    for (int64_t b = 0; b < B; b++) {
+        for (int64_t h = 0; h < H; h++) {
+            size_t s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;
+            // D x D state; the updated copy lives in dst, after the
+            // B * L * H * D output elements
+            aclTensor* acl_s = ggml_cann_create_tensor(s, ne_s, nb_s, 2, ACL_FORMAT_ND, s_offset);
+            aclTensor* acl_s_new = ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
+            cann_copy(ctx, acl_s, acl_s_new);
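+            // per-token recurrence: s_new = g ⊙ s_new + k vᵀ (g broadcast
+            // the same way as k), followed by o = scale * (s_newᵀ q)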
+            for (int64_t l = 0; l < L; l++) {
+                size_t qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base;
+                // k and g as D x 1 views
+                aclTensor* acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
+                aclTensor* acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
+                // q as a plain D-element vector
+                aclTensor* acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
+                // v as a 1 x D view
+                aclTensor* acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
+                // output o as a D-element vector
+                aclTensor* acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
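+                // k, v and g are per-token vectors; broadcast them to D x D
+                // so they can be combined elementwise with the state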
+                // scratch buffer for the broadcast k
+                size_t buf_size = D * D * ggml_type_size(k->type);
+                ggml_cann_pool_alloc state_buf1(ctx.pool(), buf_size);
+                void* buf1_ptr = state_buf1.get();
+                aclTensor* acl_buf_k = ggml_cann_create_tensor(buf1_ptr, ggml_cann_type_mapping(k->type), ggml_type_size(k->type), ne_s, nb_s, 2);
+                // scratch buffer for the broadcast v
+                ggml_cann_pool_alloc state_buf2(ctx.pool(), buf_size);
+                void* buf2_ptr = state_buf2.get();
+                aclTensor* acl_buf_v = ggml_cann_create_tensor(buf2_ptr, ggml_cann_type_mapping(k->type), ggml_type_size(k->type), ne_s, nb_s, 2);
+                // repeat counts that tile k and v up to D x D
+                int64_t k_rep[2] = {1, D};
+                int64_t v_rep[2] = {D, 1};
+                aclIntArray* acl_k_rep = aclCreateIntArray(k_rep, 2);
+                aclIntArray* acl_v_rep = aclCreateIntArray(v_rep, 2);
+                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_k, acl_k_rep, acl_buf_k);
+                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_v, acl_v_rep, acl_buf_v);
+                // kv outer product, stored in place in acl_buf_k
+                aclnn_mul(ctx, acl_buf_k, acl_buf_v, nullptr);
+                // gate the state: broadcast g to D x D, reusing acl_buf_v
+                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_g, acl_k_rep, acl_buf_v);
+                aclnn_mul(ctx, acl_s_new, acl_buf_v, nullptr);
+                // add the kv contribution to the gated state
+                aclnn_add(ctx, acl_s_new, acl_buf_k, nullptr);
+                // output: transpose the state into acl_buf_k, then o = scale * (s_newᵀ q)
+                int64_t newdim[2] = {1, 0};
+                aclnn_permute(ctx, acl_s_new, acl_buf_k, newdim, 2);
+                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_buf_k, acl_q, acl_o, 1);
+                aclnn_muls(ctx, acl_o, scale, nullptr, true);
+                ggml_cann_release_resources(ctx, acl_q, acl_k, acl_v, acl_o, acl_g, acl_buf_k, acl_buf_v, acl_k_rep, acl_v_rep);
+            }
+            ggml_cann_release_resources(ctx, acl_s, acl_s_new);
+        }
+    }
+}
+
 void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
 
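For reference, here is a minimal CPU sketch of the recurrence the loop nest above implements, for one (batch, head) slice. The function name and flat float-array layout are illustrative only, not part of this commit; it assumes a row-major D x D state with g and k broadcast along rows and v along columns, which is intended to mirror ggml's CPU path for GGML_OP_GATED_LINEAR_ATTN:

```cpp
#include <cstdint>

// gla_ref: hypothetical reference for one (batch, head) slice.
// k, v, q, g: L x D per-token vectors; S: D x D state; out: L x D outputs.
static void gla_ref(const float * k, const float * v, const float * q,
                    const float * g, float * S, float * out,
                    int64_t L, int64_t D, float scale) {
    for (int64_t l = 0; l < L; l++) {
        const float * kt = k + l * D;
        const float * vt = v + l * D;
        const float * qt = q + l * D;
        const float * gt = g + l * D;
        float       * ot = out + l * D;
        // state update: S[i][j] = g[i] * S[i][j] + k[i] * v[j]
        for (int64_t i = 0; i < D; i++) {
            for (int64_t j = 0; j < D; j++) {
                S[i * D + j] = gt[i] * S[i * D + j] + kt[i] * vt[j];
            }
        }
        // output: o[j] = scale * sum_i q[i] * S[i][j]  (the s_newᵀ q mat-vec)
        for (int64_t j = 0; j < D; j++) {
            float sum = 0.0f;
            for (int64_t i = 0; i < D; i++) {
                sum += qt[i] * S[i * D + j];
            }
            ot[j] = sum * scale;
        }
    }
}
```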