diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index bc33b99d96e..d39e8463ac3 100755
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -3184,6 +3184,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     ggml_tensor* src1 = dst->src[1]; // k, fp16
     ggml_tensor* src2 = dst->src[2]; // v, fp16
     ggml_tensor* src3 = dst->src[3]; // mask, fp16
+    ggml_tensor* src4 = dst->src[4]; // sinks
 
     float maxBias = 0.0f;
     float scaleValue = 1.0f;
@@ -3192,67 +3193,274 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float));
     memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float));
 
-    if(logitSoftcap == 0.0f){
+    {
         size_t faElemSize = sizeof(uint16_t);
-        auto faDataType = ACL_FLOAT16; //ACL_BF16;
+        auto faDataType = (src0->type == GGML_TYPE_BF16) ? ACL_BF16 : ACL_FLOAT16;
 
         aclTensor* acl_src0_f16_tensor = nullptr;
         aclTensor* acl_src1_f16_tensor = nullptr;
         aclTensor* acl_src2_f16_tensor = nullptr;
         aclTensor* acl_dst_f16_tensor = nullptr;
 
+        // Helper to ensure a tensor is FP16/BF16, dequantizing if necessary
+        auto ensure_tensor_dtype = [&](ggml_tensor* src, ggml_cann_pool_alloc& allocator) -> aclTensor* {
+            if (ggml_cann_type_mapping(src->type) == faDataType) {
+                return ggml_cann_create_tensor(src);
+            }
+
+            if (ggml_is_quantized(src->type)) {
+                size_t ne = ggml_nelements(src);
+
+                // Correctly handle quantized types where ne[0] might not be a multiple of the block size
+                size_t blck_size = ggml_blck_size(src->type);
+                size_t type_size = ggml_type_size(src->type);
+
+                // Calculate the row size in bytes, rounding up to full blocks
+                size_t row_size = ((src->ne[0] + blck_size - 1) / blck_size) * type_size;
+
+                // Total bytes = row_size * number of rows
+                size_t n_rows = src->ne[1] * src->ne[2] * src->ne[3];
+                size_t nbytes = row_size * n_rows;
+
+                // Ensure the host FP32 buffer covers full blocks to avoid overflow during
+                // dequantization; the padding has to be handled per row
+                int64_t ne0 = src->ne[0];
+                int64_t ne0_padded = ((ne0 + blck_size - 1) / blck_size) * blck_size;
+                size_t total_ne_padded = ne0_padded * n_rows;
+
+                std::vector<uint8_t> host_q(nbytes);
+                std::vector<float> host_f32_padded(total_ne_padded);
+
+                // Handle the data copy (contiguous vs non-contiguous).
+                // Synchronize the stream to ensure data is ready on device before copying to host.
+                ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+
+                if (ggml_is_contiguous(src)) {
+                    // If contiguous, everything can be copied at once. Note that ggml_nbytes(src)
+                    // may be smaller than the nbytes computed above (ggml_nbytes can round down,
+                    // nbytes rounds up); if the device tensor was allocated with the smaller size,
+                    // this copy could read out of bounds.
+                    ACL_CHECK(aclrtMemcpy(host_q.data(), nbytes, src->data, nbytes, ACL_MEMCPY_DEVICE_TO_HOST));
+                } else {
+                    // Handle the non-contiguous copy.
+                    // For quantized types src->nb[1] should equal the packed row size; it is
+                    // assumed to be at least row_size here.
+                    bool dim1_contiguous = (src->nb[1] == row_size); // stride matches the packed row size
+
+                    if (dim1_contiguous) {
+                        // Optimization: dims 0 and 1 are contiguous (common case)
+                        size_t block_size = row_size * src->ne[1];
+
+                        for (int64_t i3 = 0; i3 < src->ne[3]; ++i3) {
+                            for (int64_t i2 = 0; i2 < src->ne[2]; ++i2) {
+                                size_t src_offset = i3 * src->nb[3] + i2 * src->nb[2];
+                                size_t dst_offset = (i3 * src->ne[2] + i2) * block_size;
+
+                                void* src_ptr = (char*)src->data + src_offset;
+                                void* dst_ptr = (char*)host_q.data() + dst_offset;
+
+                                ACL_CHECK(aclrtMemcpy(dst_ptr, block_size, src_ptr, block_size, ACL_MEMCPY_DEVICE_TO_HOST));
+                            }
+                        }
+                    } else {
+                        // Fallback: copy row by row (slow but safe for arbitrary strides)
+                        for (int64_t i3 = 0; i3 < src->ne[3]; ++i3) {
+                            for (int64_t i2 = 0; i2 < src->ne[2]; ++i2) {
+                                for (int64_t i1 = 0; i1 < src->ne[1]; ++i1) {
+                                    size_t src_offset = i3 * src->nb[3] + i2 * src->nb[2] + i1 * src->nb[1];
+                                    size_t dst_offset = ((i3 * src->ne[2] + i2) * src->ne[1] + i1) * row_size;
+
+                                    void* src_ptr = (char*)src->data + src_offset;
+                                    void* dst_ptr = (char*)host_q.data() + dst_offset;
+
+                                    ACL_CHECK(aclrtMemcpy(dst_ptr, row_size, src_ptr, row_size, ACL_MEMCPY_DEVICE_TO_HOST));
+                                }
+                            }
+                        }
+                    }
+                }
+
+                // ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+
+                const auto * type_traits = ggml_get_type_traits(src->type);
+                // Use the padded element count so only full blocks are processed
+                type_traits->to_float(host_q.data(), host_f32_padded.data(), total_ne_padded);
+
+                // Sanitize the data to remove NaNs/Infs coming from uninitialized/garbage memory.
+                // Also clamp values to the FP16 range to prevent Inf generation during the cast.
+                float* f32_data = host_f32_padded.data();
+                const float MAX_F16 = 65504.0f;
+                for (size_t i = 0; i < total_ne_padded; ++i) {
+                    float val = f32_data[i];
+                    if (!std::isfinite(val)) {
+                        f32_data[i] = 0.0f;
+                    } else if (val > MAX_F16) {
+                        f32_data[i] = MAX_F16;
+                    } else if (val < -MAX_F16) {
+                        f32_data[i] = -MAX_F16;
+                    }
+                }
+
+                // Repack if necessary (remove the per-row padding)
+                void* host_f32_ptr = host_f32_padded.data();
+                std::vector<float> host_f32_dense;
+
+                if (ne0 != ne0_padded) {
+                    host_f32_dense.resize(ne);
+                    for (size_t i = 0; i < n_rows; ++i) {
+                        memcpy(host_f32_dense.data() + i * ne0, host_f32_padded.data() + i * ne0_padded, ne0 * sizeof(float));
+                    }
+                    host_f32_ptr = host_f32_dense.data();
+                }
+
+                // 1. Alloc the target tensor (faDataType) using the OUTPUT allocator.
+                //    Allocating the output first keeps it at the bottom of the pool stack.
+                void* device_target_buf = allocator.alloc(ne * faElemSize);
+
+                // 2. Alloc the device FP32 buffer using a TEMP allocator
+                ggml_cann_pool_alloc temp_alloc(ctx.pool());
+                void* device_f32_buf = temp_alloc.alloc(ne * sizeof(float));
+                // Use a synchronous copy because the host buffer is destroyed when this scope ends;
+                // only the actual valid elements (ne) are copied
+                ACL_CHECK(aclrtMemcpy(device_f32_buf, ne * sizeof(float), host_f32_ptr, ne * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE));
+
+                int64_t* ne_tensor = src->ne;
+                size_t nb_f32[GGML_MAX_DIMS];
+                nb_f32[0] = sizeof(float);
+                for(int i = 1; i < GGML_MAX_DIMS; ++i) nb_f32[i] = nb_f32[i - 1] * ne_tensor[i - 1];
+
+                aclTensor* acl_src_f32 = ggml_cann_create_tensor(device_f32_buf, ACL_FLOAT, sizeof(float), ne_tensor, nb_f32, GGML_MAX_DIMS);
+
+                size_t nb_target[GGML_MAX_DIMS];
+                nb_target[0] = faElemSize;
+                for(int i = 1; i < GGML_MAX_DIMS; ++i) nb_target[i] = nb_target[i - 1] * ne_tensor[i - 1];
+
+                aclTensor* acl_ret = ggml_cann_create_tensor(device_target_buf, faDataType, faElemSize, ne_tensor, nb_target, GGML_MAX_DIMS);
+
+                // Cast FP32 -> faDataType
+                aclnn_cast(ctx, acl_src_f32, acl_ret, faDataType);
+
+                // Synchronize to ensure the cast is complete before temp_alloc is released and reused
+                ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+
+                ggml_cann_release_resources(ctx, acl_src_f32);
+
+                return acl_ret;
+            }
+
+            // Fallback for F32 or other non-quantized types that need casting
+            aclTensor* acl_src_raw = ggml_cann_create_tensor(src);
+            void* buf = allocator.alloc(ggml_nelements(src) * faElemSize);
+
+            int64_t* ne = src->ne;
+            size_t nb[GGML_MAX_DIMS];
+            nb[0] = faElemSize;
+            for(int i = 1; i < GGML_MAX_DIMS; ++i) nb[i] = nb[i - 1] * ne[i - 1];
+
+            aclTensor* acl_ret = ggml_cann_create_tensor(buf, faDataType, faElemSize, ne, nb, GGML_MAX_DIMS);
+            aclnn_cast(ctx, acl_src_raw, acl_ret, faDataType);
+            ggml_cann_release_resources(ctx, acl_src_raw);
+            return acl_ret;
+        };
+
-        // Step 1: cast the src0 (Query) to fp16 if needed
-        ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
-        void* src0_f16_buffer = nullptr;
-
-        if(ggml_cann_type_mapping(src0->type) != faDataType){
-            aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
-            src0_f16_buffer = src0_f16_allocator.alloc(
-                ggml_nelements(src0) * faElemSize);
-
-            int64_t* src0_f16_ne = src0->ne;
-            size_t src0_f16_nb[GGML_MAX_DIMS];
-            src0_f16_nb[0] = sizeof(uint16_t);
-            for(int i = 1; i < GGML_MAX_DIMS; ++i){
-                src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1];
-            }
-
-            acl_src0_f16_tensor = ggml_cann_create_tensor(
-                src0_f16_buffer, faDataType, faElemSize,
-                src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS
-            );
-            aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
-            ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
-        }else{
-            acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
-        }
+        // Check for a head size mismatch and pad if necessary
+        int64_t head_dim_q = src0->ne[0];
+        int64_t head_dim_k = src1->ne[0];
+        int64_t head_dim_v = src2->ne[0];
+        int64_t max_head_dim = head_dim_q;
+        if (head_dim_k > max_head_dim) max_head_dim = head_dim_k;
+        if (head_dim_v > max_head_dim) max_head_dim = head_dim_v;
+
+        int64_t target_head_dim = (max_head_dim + 15) / 16 * 16;
+
+        // Helper lambda for padding
+        auto pad_to_max_dim = [&](ggml_tensor* src, int64_t current_dim, int64_t target_dim, ggml_cann_pool_alloc& allocator) -> aclTensor* {
+            if (current_dim == target_dim) {
+                return ensure_tensor_dtype(src, allocator);
+            } else {
+                // 1. Alloc the buffer for the padded tensor using the OUTPUT allocator FIRST
+                int64_t padded_ne[GGML_MAX_DIMS];
+                memcpy(padded_ne, src->ne, sizeof(int64_t) * GGML_MAX_DIMS);
+                padded_ne[0] = target_dim;
+
+                size_t padded_elements = padded_ne[0] * padded_ne[1] * padded_ne[2] * padded_ne[3];
+                void* padded_buf = allocator.alloc(padded_elements * faElemSize);
+
+                // Init with 0
+                ACL_CHECK(aclrtMemsetAsync(padded_buf, padded_elements * faElemSize, 0, padded_elements * faElemSize, ctx.stream()));
+
+                // 2. Padding is needed: use a temp allocator for the intermediate tensor
+                ggml_cann_pool_alloc temp_alloc(ctx.pool());
+                aclTensor* acl_src_f16 = ensure_tensor_dtype(src, temp_alloc);
+
+                // Create the padded tensor descriptor
+                size_t padded_nb[GGML_MAX_DIMS];
+                padded_nb[0] = faElemSize;
+                for(int i = 1; i < GGML_MAX_DIMS; ++i) padded_nb[i] = padded_nb[i - 1] * padded_ne[i - 1];
+
+                aclTensor* acl_padded = ggml_cann_create_tensor(padded_buf, faDataType, faElemSize, padded_ne, padded_nb, GGML_MAX_DIMS);
+
+                // View of the padded buffer with the original (unpadded) shape but padded strides
+                aclTensor* acl_padded_view = ggml_cann_create_tensor(padded_buf, faDataType, faElemSize, src->ne, padded_nb, GGML_MAX_DIMS);
+
+                // Copy the unpadded data into the padded buffer
+                aclnn_cast(ctx, acl_src_f16, acl_padded_view, faDataType);
+
+                // Synchronize to ensure the copy is complete before temp_alloc is released
+                ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+
+                ggml_cann_release_resources(ctx, acl_src_f16);
+                ggml_cann_release_resources(ctx, acl_padded_view);
+
+                return acl_padded;
+            }
+        };
+
+        // Step 1: cast the src0 (Query) to fp16 if needed
+        ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
+        acl_src0_f16_tensor = pad_to_max_dim(src0, head_dim_q, target_head_dim, src0_f16_allocator);
 
         // Step 2: create the acl tensors for src1 (Key), src2 (Value),
         // and the direct output from FusedInferAttention
-        acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
-        acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+        ggml_cann_pool_alloc src1_f16_allocator(ctx.pool());
+        acl_src1_f16_tensor = pad_to_max_dim(src1, head_dim_k, target_head_dim, src1_f16_allocator);
 
-        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
-        void* out_f16_buffer = out_f16_allocator.alloc(
-            ggml_nelements(dst) * faElemSize);
+        ggml_cann_pool_alloc src2_f16_allocator(ctx.pool());
+        acl_src2_f16_tensor = pad_to_max_dim(src2, head_dim_v, target_head_dim, src2_f16_allocator);
 
-        int64_t* out_f16_ne = src0->ne;
-        size_t out_f16_nb[GGML_MAX_DIMS];
-        out_f16_nb[0] = faElemSize;
+        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+
+        int64_t out_padded_ne[GGML_MAX_DIMS];
+        memcpy(out_padded_ne, src0->ne, sizeof(int64_t) * GGML_MAX_DIMS);
+        out_padded_ne[0] = target_head_dim; // the op outputs target_head_dim
+
+        size_t out_padded_elements = out_padded_ne[0] * out_padded_ne[1] * out_padded_ne[2] * out_padded_ne[3];
+        void* out_f16_buffer = out_f16_allocator.alloc(out_padded_elements * faElemSize);
+
+        size_t out_padded_nb[GGML_MAX_DIMS];
+        out_padded_nb[0] = faElemSize;
         for(int i = 1; i < GGML_MAX_DIMS; ++i){
-            out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+            out_padded_nb[i] = out_padded_nb[i - 1] * out_padded_ne[i - 1];
         }
         acl_dst_f16_tensor = ggml_cann_create_tensor(
             out_f16_buffer, faDataType, faElemSize,
-            out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+            out_padded_ne, out_padded_nb, GGML_MAX_DIMS
         );
+
         // Step 3: create the PSEShift tensor if needed
         // this tensor is considered as mask (f16) in the llama.cpp
         aclTensor* bcast_pse_tensor = nullptr;
+        aclTensor* acl_atten_mask_tensor = nullptr;
         int64_t bcast_pse_ne[GGML_MAX_DIMS];
         size_t bcast_pse_nb[GGML_MAX_DIMS];
         ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
@@ -3336,60 +3544,340 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             }
         }
 
-        // Step 4: set the inputs for FusedInferAttention.
-        int kvTensorNum = 1;
-        aclTensor* acl_q_tensor = acl_src0_f16_tensor;
-        aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
-        aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
-        auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
-        auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
-
-        int64_t numHeads = src0->ne[2]; // N
-        int64_t numKeyValueHeads = src1->ne[2];
-        // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
-        int64_t preTokens = 65535;
-        int64_t nextTokens = 65535;
-        char layout[5] = {'B', 'N', 'S', 'D', 0};
-        int64_t sparseMode = 0;
-        int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
-        int64_t blockSize = 0;
-        int64_t antiquantMode = 0;
-        bool softmaxLseFlag = false;
-        int64_t keyAntiquantMode = 0;
-        int64_t valueAntiquantMode = 0;
-
-        // Step 5: launch the FusedInferAttentionScoreV2 kernel.
-        // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
-
-        GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
-            acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
-            bcast_pse_tensor, nullptr, // pse, mask
-            nullptr, nullptr, // actSeqLen, actSeqLenkv
-            nullptr, nullptr, // deqScale1, quantScale1
-            nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2
-            nullptr, nullptr, // antiquantScale, antiquantOffset
-            nullptr, // blockTable
-            nullptr, nullptr, // qPadSize, kvPadSize
-            nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset
-            nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset
-            nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen
-            numHeads, scaleValue, // heads, scaleValue
-            preTokens, nextTokens, // preTokens, nextTokens
-            layout, // inputLayout
-            numKeyValueHeads, // numKVHeads
-            sparseMode, innerPrecise, // sparseMode, innerPrecise
-            blockSize, antiquantMode, // blockSize, antiquantMode
-            softmaxLseFlag, // softmaxLseFlag
-            keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
-            acl_dst_f16_tensor, // attentionOut
-            nullptr // softmaxLse
-        );
+        if (src4 != nullptr) {
+            ggml_cann_pool_alloc atten_mask_allocator(ctx.pool());
+            aclTensor* acl_src4_tensor = ggml_cann_create_tensor(src4);
+
+            int64_t valid_len = src4->ne[0];
+            int64_t full_len = src1->ne[1];
+            int64_t target_len = (valid_len < full_len) ? full_len : valid_len;
+
+            size_t mask_size = target_len * sizeof(int8_t);
+            void* mask_buffer = atten_mask_allocator.alloc(mask_size);
+
+            if (valid_len < full_len) {
+                ACL_CHECK(aclrtMemsetAsync(mask_buffer, mask_size, 1, mask_size, ctx.stream()));
+            }
+
+            int64_t partial_ne[GGML_MAX_DIMS] = {valid_len, 1, 1, 1};
+            size_t partial_nb[GGML_MAX_DIMS];
+            partial_nb[0] = sizeof(int8_t);
+            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                partial_nb[i] = partial_nb[i - 1] * partial_ne[i - 1];
+            }
+
+            aclTensor* acl_partial_mask_tensor = ggml_cann_create_tensor(
+                mask_buffer, ACL_BOOL, sizeof(int8_t),
+                partial_ne, partial_nb, GGML_MAX_DIMS
+            );
+
+            aclnn_cast(ctx, acl_src4_tensor, acl_partial_mask_tensor, ACL_BOOL);
+            ggml_cann_release_resources(ctx, acl_partial_mask_tensor);
+            ggml_cann_release_resources(ctx, acl_src4_tensor);
+
+            int64_t full_ne[GGML_MAX_DIMS] = {target_len, src0->ne[1], src0->ne[2], src0->ne[3]};
+            size_t full_nb[GGML_MAX_DIMS] = {sizeof(int8_t), 0, 0, 0};
+
+            acl_atten_mask_tensor = ggml_cann_create_tensor(
+                mask_buffer, ACL_BOOL, sizeof(int8_t),
+                full_ne, full_nb, GGML_MAX_DIMS
+            );
+        }
+
+        // Step 4: dispatch to FusedInferAttentionScoreV2, or fall back if head_dim > 512
+        if (target_head_dim > 512) {
+            // Fallback implementation using primitive ops with FP32 precision
+            int64_t permute_dims_k[] = {0, 1, 3, 2};
+            aclIntArray* permute_dims_k_array = aclCreateIntArray(permute_dims_k, 4);
+
+            auto permute_tensor = [&](aclTensor* src, int64_t* ne, aclIntArray* dims, int64_t* dims_val, void** out_buf) -> aclTensor* {
+                // ne is [D, N, H, B] (padded).
+                // Compute new_ne for ggml_cann_create_tensor (which reverses it to get the logical shape):
+                //   Logical Output[i] = Logical Input[dims[i]]
+                //   Logical Input[k]  = ne[3 - k]
+                //   => Logical Output[i] = ne[3 - dims[i]]
+                //   => new_ne[j] = Logical Output[3 - j] = ne[3 - dims[3 - j]]
+                int64_t new_ne[4];
+                for(int i = 0; i < 4; ++i) new_ne[i] = ne[3 - dims_val[3 - i]];
+
+                ggml_cann_pool_alloc temp_alloc(ctx.pool());
+                size_t numel = new_ne[0] * new_ne[1] * new_ne[2] * new_ne[3];
+                void* buf = temp_alloc.alloc(numel * faElemSize);
+                *out_buf = buf;
+
+                // Strides for a contiguous, compact tensor
+                size_t new_nb[4];
+                new_nb[3] = faElemSize;
+                new_nb[2] = new_nb[3] * new_ne[3];
+                new_nb[1] = new_nb[2] * new_ne[2];
+                new_nb[0] = new_nb[1] * new_ne[1];
+
+                aclTensor* dst = ggml_cann_create_tensor(buf, faDataType, faElemSize, new_ne, new_nb, 4);
+                GGML_CANN_CALL_ACLNN_OP(ctx, Permute, src, dims, dst);
+                return dst;
+            };
+
+            auto create_3d_tensor = [&](void* ptr, int64_t* ne_4d, aclDataType dtype, size_t dsize) -> aclTensor* {
+                // Reshape [D, N, H, B] -> [D, N, H*B] (GGML order),
+                // i.e. logical [B, H, N, D] -> [B*H, N, D]
+                int64_t ne_3d[] = {ne_4d[0], ne_4d[1], ne_4d[2] * ne_4d[3]};
+                size_t nb_3d[3];
+                nb_3d[0] = dsize;
+                nb_3d[1] = nb_3d[0] * ne_3d[0];
+                nb_3d[2] = nb_3d[1] * ne_3d[1];
+                return ggml_cann_create_tensor(ptr, dtype, dsize, ne_3d, nb_3d, 3);
+            };
+
+            auto cast_to_f32 = [&](aclTensor* src, int64_t* ne_4d, void** out_buf) -> aclTensor* {
+                int64_t ne_3d[] = {ne_4d[0], ne_4d[1], ne_4d[2] * ne_4d[3]};
+                size_t numel = ne_3d[0] * ne_3d[1] * ne_3d[2];
+
+                ggml_cann_pool_alloc temp_alloc(ctx.pool());
+                void* buf = temp_alloc.alloc(numel * sizeof(float));
+                *out_buf = buf;
+
+                aclTensor* dst = create_3d_tensor(buf, ne_4d, ACL_FLOAT, sizeof(float));
+                aclnn_cast(ctx, src, dst, ACL_FLOAT);
+                return dst;
+            };
+
+            int64_t q_ne[] = {target_head_dim, src0->ne[1], src0->ne[2], src0->ne[3]};
+            int64_t k_ne[] = {target_head_dim, src1->ne[1], src1->ne[2], src1->ne[3]};
+            int64_t v_ne[] = {target_head_dim, src2->ne[1], src2->ne[2], src2->ne[3]};
+
+            void* Q_ptr = src0_f16_allocator.get() ? src0_f16_allocator.get() : src0->data;
+            aclTensor* Q_3d_f16 = create_3d_tensor(Q_ptr, q_ne, faDataType, faElemSize);
+
+            void* K_T_ptr = nullptr;
+            aclTensor* K_T = permute_tensor(acl_src1_f16_tensor, k_ne, permute_dims_k_array, permute_dims_k, &K_T_ptr);
+
+            int64_t k_t_ne[] = {k_ne[1], k_ne[0], k_ne[2], k_ne[3]};
+
+            // Handle GQA: repeat K and V if n_head_q > n_head_kv
+            int64_t n_head_q = src0->ne[2];
+            int64_t n_head_kv = src1->ne[2];
+            int64_t n_rep = n_head_q / n_head_kv;
+
+            aclTensor* K_T_rep = nullptr;
+            aclTensor* V_rep = nullptr;
+            void* v_rep_ptr_val = nullptr;
+
+            if (n_rep > 1) {
+                auto repeat_tensor = [&](aclTensor* src, int64_t* src_ne, void** out_buf) -> aclTensor* {
+                    int64_t dst_ne[] = {src_ne[0], src_ne[1], src_ne[2] * n_rep, src_ne[3]};
+
+                    ggml_cann_pool_alloc temp_alloc(ctx.pool());
+                    size_t numel = dst_ne[0] * dst_ne[1] * dst_ne[2] * dst_ne[3];
+                    void* buf = temp_alloc.alloc(numel * faElemSize);
+                    *out_buf = buf;
+
+                    size_t dst_nb[4];
+                    dst_nb[3] = faElemSize;
+                    dst_nb[2] = dst_nb[3] * dst_ne[3];
+                    dst_nb[1] = dst_nb[2] * dst_ne[2];
+                    dst_nb[0] = dst_nb[1] * dst_ne[1];
+
+                    aclTensor* dst = ggml_cann_create_tensor(buf, faDataType, faElemSize, dst_ne, dst_nb, 4);
+
+                    // Logical repeats: [1, n_rep, 1, 1] in (B, H, N, D) order -> repeat H at index 1
+                    int64_t repeats_logical[] = {1, n_rep, 1, 1};
+                    aclIntArray* repeat_array = aclCreateIntArray(repeats_logical, 4);
+
+                    GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, src, repeat_array, dst);
+
+                    aclDestroyIntArray(repeat_array);
+                    return dst;
+                };
+
+                void* k_t_rep_ptr = nullptr;
+                K_T_rep = repeat_tensor(K_T, k_t_ne, &k_t_rep_ptr);
+                // Update pointers and dimensions for the 3D views
+                K_T_ptr = k_t_rep_ptr;
+                k_t_ne[2] *= n_rep;
+
+                // v_ne is [D, M, H_kv, B]; reuse acl_src2_f16_tensor as the input for Repeat
+                void* v_rep_ptr = nullptr;
+                V_rep = repeat_tensor(acl_src2_f16_tensor, v_ne, &v_rep_ptr);
+
+                // Update V info
+                v_rep_ptr_val = v_rep_ptr;
+                v_ne[2] *= n_rep;
+            }
+
+            aclTensor* K_T_3d_f16 = create_3d_tensor(K_T_ptr, k_t_ne, faDataType, faElemSize);
+
+            void* final_v_ptr = (V_rep) ? v_rep_ptr_val : (src2_f16_allocator.get() ? src2_f16_allocator.get() : src2->data);
+            aclTensor* V_3d_f16 = create_3d_tensor(final_v_ptr, v_ne, faDataType, faElemSize);
+
+            // Cast inputs to FP32
+            void* Q_f32_buf = nullptr;
+            aclTensor* Q_3d_f32 = cast_to_f32(Q_3d_f16, q_ne, &Q_f32_buf);
+
+            void* K_T_f32_buf = nullptr;
+            aclTensor* K_T_3d_f32 = cast_to_f32(K_T_3d_f16, k_t_ne, &K_T_f32_buf);
+
+            void* V_f32_buf = nullptr;
+            aclTensor* V_3d_f32 = cast_to_f32(V_3d_f16, v_ne, &V_f32_buf);
+
+            // Score [B, H, N, M] -> [B*H, N, M]
+            // In GGML order: [M, N, H, B] -> [M, N, H*B]
+            int64_t score_ne[] = {k_ne[1], q_ne[1], q_ne[2], q_ne[3]}; // M, N, H, B
+            size_t score_nb[4];
+            score_nb[3] = sizeof(float);
+            score_nb[2] = score_nb[3] * score_ne[3];
+            score_nb[1] = score_nb[2] * score_ne[2];
+            score_nb[0] = score_nb[1] * score_ne[1];
+
+            ggml_cann_pool_alloc score_alloc(ctx.pool());
+            void* score_buf = score_alloc.alloc(score_ne[0] * score_ne[1] * score_ne[2] * score_ne[3] * sizeof(float));
+            aclTensor* Score_3d_f32 = create_3d_tensor(score_buf, score_ne, ACL_FLOAT, sizeof(float));
+
+            int8_t cubeMathType = 0;
+            GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, Q_3d_f32, K_T_3d_f32, Score_3d_f32, cubeMathType);
+
+            // Scale
+            if (logitSoftcap != 0.0f) {
+                scaleValue /= logitSoftcap;
+            }
+            aclScalar* acl_scale = aclCreateScalar(&scaleValue, ACL_FLOAT);
+            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, Score_3d_f32, acl_scale);
+            aclDestroyScalar(acl_scale);
+
+            // Mask
+            if (bcast_pse_tensor) {
+                // bcast_pse_tensor is [M, N, H, B]; a 3D view [M, N, H*B] is needed.
+                // Use score_ne so the shape matches Score (the N dimension depends on src0->ne[1]).
+                int64_t* mask_ne = score_ne;
+                void* mask_ptr = bcast_pse_buffer;
+
+                // The mask is usually ACL_FLOAT16 and needs a cast to ACL_FLOAT
+                aclTensor* Mask_3d_f16 = create_3d_tensor(mask_ptr, mask_ne, ACL_FLOAT16, sizeof(uint16_t));
+
+                void* mask_f32_buf = nullptr;
+                aclTensor* Mask_3d_f32 = cast_to_f32(Mask_3d_f16, mask_ne, &mask_f32_buf);
+
+                float one = 1.0f;
+                aclScalar* alpha = aclCreateScalar(&one, ACL_FLOAT);
+                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, Score_3d_f32, Mask_3d_f32, alpha);
+                aclDestroyScalar(alpha);
+                ggml_cann_release_resources(ctx, Mask_3d_f16, Mask_3d_f32);
+            }
+
+            // Softcap
+            if (logitSoftcap != 0.0f) {
+                GGML_CANN_CALL_ACLNN_OP(ctx, Tanh, Score_3d_f32, Score_3d_f32);
+                aclScalar* acl_cap = aclCreateScalar(&logitSoftcap, ACL_FLOAT);
+                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, Score_3d_f32, acl_cap);
+                aclDestroyScalar(acl_cap);
+            }
+
+            // Softmax
+            GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, Score_3d_f32, (int64_t)-1, Score_3d_f32);
+
+            // Out [B, H, N, D] -> [B*H, N, D]; the result is cast into the acl_dst_f16_tensor buffer below
+            ggml_cann_pool_alloc out_f32_allocator(ctx.pool());
+            size_t out_elements = out_padded_ne[0] * out_padded_ne[1] * out_padded_ne[2] * out_padded_ne[3];
+            void* out_f32_buf = out_f32_allocator.alloc(out_elements * sizeof(float));
+
+            aclTensor* Out_3d_f32 = create_3d_tensor(out_f32_buf, out_padded_ne, ACL_FLOAT, sizeof(float));
+
+            GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, Score_3d_f32, V_3d_f32, Out_3d_f32, cubeMathType);
+
+            // Cast output FP32 -> faDataType
+            void* out_ptr = out_f16_buffer; // this buffer is always allocated above
+            aclTensor* Out_3d_f16 = create_3d_tensor(out_ptr, out_padded_ne, faDataType, faElemSize);
+            aclnn_cast(ctx, Out_3d_f32, Out_3d_f16, faDataType);
+
+            aclDestroyIntArray(permute_dims_k_array);
+            ggml_cann_release_resources(ctx, K_T, Q_3d_f16, K_T_3d_f16, V_3d_f16, Score_3d_f32, Out_3d_f32, Out_3d_f16, Q_3d_f32, K_T_3d_f32, V_3d_f32);
+            if (K_T_rep) ggml_cann_release_resources(ctx, K_T_rep);
+            if (V_rep) ggml_cann_release_resources(ctx, V_rep);
+        } else {
+            // Step 4: set the inputs for FusedInferAttention.
+            int kvTensorNum = 1;
+            aclTensor* acl_q_tensor = acl_src0_f16_tensor;
+            aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
+            aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
+            auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
+            auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
+
+            int64_t numHeads = src0->ne[2]; // N
+            int64_t numKeyValueHeads = src1->ne[2];
+            // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
+            int64_t preTokens = 65535;
+            int64_t nextTokens = 65535;
+            char layout[5] = {'B', 'N', 'S', 'D', 0};
+            int64_t sparseMode = 0;
+            int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
+            int64_t blockSize = 0;
+            int64_t antiquantMode = 0;
+            bool softmaxLseFlag = false;
+            int64_t keyAntiquantMode = 0;
+            int64_t valueAntiquantMode = 0;
+
+            // Step 5: launch the FusedInferAttentionScoreV2 kernel.
+            // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+
+            GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
+                acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
+                bcast_pse_tensor, nullptr, // pse, mask
+                nullptr, nullptr, // actSeqLen, actSeqLenkv
+                nullptr, nullptr, // deqScale1, quantScale1
+                nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2
+                nullptr, nullptr, // antiquantScale, antiquantOffset
+                nullptr, // blockTable
+                nullptr, nullptr, // qPadSize, kvPadSize
+                nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset
+                nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset
+                nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen
+                numHeads, scaleValue, // heads, scaleValue
+                preTokens, nextTokens, // preTokens, nextTokens
+                layout, // inputLayout
+                numKeyValueHeads, // numKVHeads
+                sparseMode, innerPrecise, // sparseMode, innerPrecise
+                blockSize, antiquantMode, // blockSize, antiquantMode
+                softmaxLseFlag, // softmaxLseFlag
+                keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
+                acl_dst_f16_tensor, // attentionOut
+                nullptr // softmaxLse
+            );
+        }
 
         // Step 6: post-processing, permute and cast to f32
         int64_t new_dim[] = {0, 2, 1, 3};
         aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+
+        aclTensor* acl_result_to_permute = acl_dst_f16_tensor;
+        aclTensor* acl_sliced_view = nullptr;
+
+        if (dst->ne[0] != target_head_dim) {
+            int64_t fake_ggml_ne[4] = {dst->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]};
+
+            acl_sliced_view = ggml_cann_create_tensor(
+                out_f16_buffer, faDataType, faElemSize,
+                fake_ggml_ne, out_padded_nb, GGML_MAX_DIMS
+            );
+
+            acl_result_to_permute = acl_sliced_view;
+        }
+
+        bool need_permute = (dst->ne[1] != src0->ne[1]);
+
         if(ggml_cann_type_mapping(dst->type) != faDataType){
             ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
             perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
@@ -3404,14 +3892,26 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
                 perm_out_f16_buffer, faDataType, faElemSize,
                 perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
-            aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
+
+            if (need_permute) {
+                aclnn_permute(ctx, acl_result_to_permute, acl_perm_out_f16_tensor, new_dim, 4);
+            } else {
+                aclnn_cast(ctx, acl_result_to_permute, acl_perm_out_f16_tensor, faDataType);
+            }
+
             aclnn_cast(ctx, acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
             ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
         }else{
-            // only need to permute
-            aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
+            if (need_permute) {
+                aclnn_permute(ctx, acl_result_to_permute, acl_dst_tensor, new_dim, 4);
+            } else {
+                aclnn_cast(ctx, acl_result_to_permute, acl_dst_tensor, faDataType);
+            }
         }
+
+        if (acl_sliced_view) ggml_cann_release_resources(ctx, acl_sliced_view);
+
         ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
                                          acl_src1_f16_tensor,
                                          acl_src2_f16_tensor,
@@ -3420,7 +3920,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         if(src3 != nullptr){
             ggml_cann_release_resources(ctx, bcast_pse_tensor);
         }
-    }else{
-        GGML_ABORT("Function is not implemented.");
+        if(src4 != nullptr && src3 == nullptr){
+            ggml_cann_release_resources(ctx, bcast_pse_tensor);
+        }
     }
 }