|
51 | 51 | #include <aclnnop/aclnn_triu.h> |
52 | 52 | #include <aclnnop/aclnn_upsample_nearest_2d.h> |
53 | 53 | #include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h> |
| 54 | +#include <aclnnop/aclnn_argmax.h> |
54 | 55 | #include <float.h> |
55 | 56 |
|
56 | 57 | #include <cmath> |
@@ -3440,3 +3441,46 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { |
3440 | 3441 | ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); |
3441 | 3442 | ACL_CHECK(aclDestroyTensor(acl_dst)); |
3442 | 3443 | } |
| 3444 | + |
| 3445 | + |
| 3446 | + void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst){ |
| 3447 | + ggml_tensor * src0 = dst->src[0]; |
| 3448 | + |
| 3449 | + aclTensor* acl_src = ggml_cann_create_tensor(src0); |
| 3450 | + aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3); |
| 3451 | + |
| 3452 | + uint64_t workspaceSize = 0; |
| 3453 | + aclOpExecutor* executor; |
| 3454 | + void* workspaceAddr = nullptr; |
| 3455 | + |
| 3456 | + ACL_CHECK(aclnnArgMaxGetWorkspaceSize(acl_src, 3, false, acl_dst, |
| 3457 | + &workspaceSize, &executor)); |
| 3458 | + if (workspaceSize > 0) { |
| 3459 | + ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); |
| 3460 | + workspaceAddr = workspace_allocator.get(); |
| 3461 | + } |
| 3462 | + ACL_CHECK(aclnnArgMax(workspaceAddr, workspaceSize, executor, ctx.stream())); |
| 3463 | + |
| 3464 | + ACL_CHECK(aclDestroyTensor(acl_src)); |
| 3465 | + ACL_CHECK(aclDestroyTensor(acl_dst)); |
| 3466 | +} |
| 3467 | + |
| 3468 | +void ggml_cann_cos(ggml_backend_cann_context& ctx, ggml_tensor* dst){ |
| 3469 | + ggml_tensor * src0 = dst->src[0]; |
| 3470 | + |
| 3471 | + aclTensor* acl_src = ggml_cann_create_tensor(src0); |
| 3472 | + aclTensor* acl_dst = ggml_cann_create_tensor(dst); |
| 3473 | + aclnn_cos(ctx, acl_src, acl_dst); |
| 3474 | + ACL_CHECK(aclDestroyTensor(acl_src)); |
| 3475 | + ACL_CHECK(aclDestroyTensor(acl_dst)); |
| 3476 | +} |
| 3477 | + |
| 3478 | +void ggml_cann_sin(ggml_backend_cann_context& ctx, ggml_tensor* dst){ |
| 3479 | + ggml_tensor * src0 = dst->src[0]; |
| 3480 | + |
| 3481 | + aclTensor* acl_src = ggml_cann_create_tensor(src0); |
| 3482 | + aclTensor* acl_dst = ggml_cann_create_tensor(dst); |
| 3483 | + aclnn_sin(ctx, acl_src, acl_dst); |
| 3484 | + ACL_CHECK(aclDestroyTensor(acl_src)); |
| 3485 | + ACL_CHECK(aclDestroyTensor(acl_dst)); |
| 3486 | +} |
0 commit comments