|
51 | 51 | #include <aclnnop/aclnn_triu.h> |
52 | 52 | #include <aclnnop/aclnn_upsample_nearest_2d.h> |
53 | 53 | #include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h> |
| 54 | +#include <aclnnop/aclnn_argmax.h> |
54 | 55 | #include <float.h> |
55 | 56 |
|
56 | 57 | #include <cmath> |
@@ -180,6 +181,55 @@ static void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0, |
180 | 181 | ACL_CHECK(aclDestroyScalar(alpha)); |
181 | 182 | } |
182 | 183 |
|
| 184 | +/** |
| 185 | + * @brief Computes the argmax of a tensor along the specified dimension using the CANN backend. |
| 186 | + * |
| 187 | + * This function performs the argmax operation on the input tensor (`acl_src`) |
| 188 | + * and stores the result in the destination tensor (`acl_dst`). The argmax is |
| 189 | + * computed along a specified axis, and the result is the index of the maximum value |
| 190 | + * along that axis. The operation is performed using the CANN backend, and |
| 191 | + * necessary memory allocation is handled automatically. |
| 192 | + * |
| 193 | + * @param ctx The context for CANN backend operations. |
| 194 | + * @param acl_src The source tensor on which the argmax operation will be performed. |
| 195 | + * @param acl_dst The destination tensor that will hold the resulting indices. |
| 196 | + * @param dst The destination tensor object that stores the result after the argmax operation. |
| 197 | + */ |
| 198 | +static void aclnn_argmax(ggml_backend_cann_context& ctx, aclTensor* acl_src, |
| 199 | + aclTensor* acl_dst, ggml_tensor* dst) { |
| 200 | + ggml_cann_pool_alloc dst_buffer_allocator( |
| 201 | + ctx.pool(), ggml_nelements(dst) * ggml_type_size(dst->type)); |
| 202 | + void* buffer = dst_buffer_allocator.get(); |
| 203 | + int64_t dst_buffer_ne[4] = {1, dst->ne[0], dst->ne[1], dst->ne[2]}; |
| 204 | + size_t dst_buffer_nb[GGML_MAX_DIMS]; |
| 205 | + dst_buffer_nb[0] = ggml_type_size(dst->type); |
| 206 | + for (int i = 1; i < GGML_MAX_DIMS; i++) { |
| 207 | + dst_buffer_nb[i] = dst_buffer_nb[i - 1] * dst_buffer_ne[i - 1]; |
| 208 | + } |
| 209 | + |
| 210 | + aclTensor* dst_buffer_tensor = |
| 211 | + ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), |
| 212 | + dst_buffer_ne, dst_buffer_nb, 4); |
| 213 | + |
| 214 | + uint64_t workspaceSize = 0; |
| 215 | + aclOpExecutor* executor; |
| 216 | + void* workspaceAddr = nullptr; |
| 217 | + |
| 218 | + |
| 219 | + ACL_CHECK(aclnnArgMaxGetWorkspaceSize(acl_src, 3, true, dst_buffer_tensor, |
| 220 | + &workspaceSize, &executor)); |
| 221 | + if (workspaceSize > 0) { |
| 222 | + ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); |
| 223 | + workspaceAddr = workspace_allocator.get(); |
| 224 | + } |
| 225 | + |
| 226 | + ACL_CHECK(aclnnArgMax(workspaceAddr, workspaceSize, executor, ctx.stream())); |
| 227 | + |
| 228 | + size_t cpy_size = ggml_nbytes(dst); |
| 229 | + ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, buffer, cpy_size, |
| 230 | + ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); |
| 231 | +} |
| 232 | + |
183 | 233 | void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) { |
184 | 234 | ggml_tensor* src0 = dst->src[0]; |
185 | 235 | ggml_tensor* src1 = dst->src[1]; |
@@ -3444,3 +3494,34 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { |
3444 | 3494 | ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); |
3445 | 3495 | ACL_CHECK(aclDestroyTensor(acl_dst)); |
3446 | 3496 | } |
| 3497 | + |
| 3498 | + |
| 3499 | + void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst){ |
| 3500 | + ggml_tensor * src0 = dst->src[0]; |
| 3501 | + |
| 3502 | + aclTensor* acl_src = ggml_cann_create_tensor(src0); |
| 3503 | + aclTensor* acl_dst = ggml_cann_create_tensor(dst); |
| 3504 | + aclnn_argmax(ctx, acl_src, acl_dst, dst); |
| 3505 | + ACL_CHECK(aclDestroyTensor(acl_src)); |
| 3506 | + ACL_CHECK(aclDestroyTensor(acl_dst)); |
| 3507 | +} |
| 3508 | + |
| 3509 | +void ggml_cann_cos(ggml_backend_cann_context& ctx, ggml_tensor* dst){ |
| 3510 | + ggml_tensor * src0 = dst->src[0]; |
| 3511 | + |
| 3512 | + aclTensor* acl_src = ggml_cann_create_tensor(src0); |
| 3513 | + aclTensor* acl_dst = ggml_cann_create_tensor(dst); |
| 3514 | + aclnn_cos(ctx, acl_src, acl_dst); |
| 3515 | + ACL_CHECK(aclDestroyTensor(acl_src)); |
| 3516 | + ACL_CHECK(aclDestroyTensor(acl_dst)); |
| 3517 | +} |
| 3518 | + |
| 3519 | +void ggml_cann_sin(ggml_backend_cann_context& ctx, ggml_tensor* dst){ |
| 3520 | + ggml_tensor * src0 = dst->src[0]; |
| 3521 | + |
| 3522 | + aclTensor* acl_src = ggml_cann_create_tensor(src0); |
| 3523 | + aclTensor* acl_dst = ggml_cann_create_tensor(dst); |
| 3524 | + aclnn_sin(ctx, acl_src, acl_dst); |
| 3525 | + ACL_CHECK(aclDestroyTensor(acl_src)); |
| 3526 | + ACL_CHECK(aclDestroyTensor(acl_dst)); |
| 3527 | +} |
0 commit comments