Skip to content

Commit def7d45

Browse files
author
noemotiovon
committed
[CANN] support sin, cos, argmax
Signed-off-by: noemotiovon <[email protected]>
1 parent 42eb248 commit def7d45

File tree

3 files changed

+134
-0
lines changed

3 files changed

+134
-0
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#include <aclnnop/aclnn_triu.h>
5252
#include <aclnnop/aclnn_upsample_nearest_2d.h>
5353
#include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
54+
#include <aclnnop/aclnn_argmax.h>
5455
#include <float.h>
5556

5657
#include <cmath>
@@ -180,6 +181,55 @@ static void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
180181
ACL_CHECK(aclDestroyScalar(alpha));
181182
}
182183

184+
/**
185+
* @brief Computes the argmax of a tensor along the specified dimension using the CANN backend.
186+
*
187+
* This function performs the argmax operation on the input tensor (`acl_src`)
188+
* and stores the result in the destination tensor (`acl_dst`). The argmax is
189+
* computed along a specified axis, and the result is the index of the maximum value
190+
* along that axis. The operation is performed using the CANN backend, and
191+
* necessary memory allocation is handled automatically.
192+
*
193+
* @param ctx The context for CANN backend operations.
194+
* @param acl_src The source tensor on which the argmax operation will be performed.
195+
* @param acl_dst The destination tensor that will hold the resulting indices.
196+
* @param dst The destination tensor object that stores the result after the argmax operation.
197+
*/
198+
static void aclnn_argmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
199+
aclTensor* acl_dst, ggml_tensor* dst) {
200+
ggml_cann_pool_alloc dst_buffer_allocator(
201+
ctx.pool(), ggml_nelements(dst) * ggml_type_size(dst->type));
202+
void* buffer = dst_buffer_allocator.get();
203+
int64_t dst_buffer_ne[4] = {1, dst->ne[0], dst->ne[1], dst->ne[2]};
204+
size_t dst_buffer_nb[GGML_MAX_DIMS];
205+
dst_buffer_nb[0] = ggml_type_size(dst->type);
206+
for (int i = 1; i < GGML_MAX_DIMS; i++) {
207+
dst_buffer_nb[i] = dst_buffer_nb[i - 1] * dst_buffer_ne[i - 1];
208+
}
209+
210+
aclTensor* dst_buffer_tensor =
211+
ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
212+
dst_buffer_ne, dst_buffer_nb, 4);
213+
214+
uint64_t workspaceSize = 0;
215+
aclOpExecutor* executor;
216+
void* workspaceAddr = nullptr;
217+
218+
219+
ACL_CHECK(aclnnArgMaxGetWorkspaceSize(acl_src, 3, true, dst_buffer_tensor,
220+
&workspaceSize, &executor));
221+
if (workspaceSize > 0) {
222+
ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
223+
workspaceAddr = workspace_allocator.get();
224+
}
225+
226+
ACL_CHECK(aclnnArgMax(workspaceAddr, workspaceSize, executor, ctx.stream()));
227+
228+
size_t cpy_size = ggml_nbytes(dst);
229+
ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, buffer, cpy_size,
230+
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
231+
}
232+
183233
void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
184234
ggml_tensor* src0 = dst->src[0];
185235
ggml_tensor* src1 = dst->src[1];
@@ -3444,3 +3494,34 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
34443494
ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
34453495
ACL_CHECK(aclDestroyTensor(acl_dst));
34463496
}
3497+
3498+
3499+
/**
 * @brief Entry point for GGML_OP_ARGMAX on the CANN backend.
 *
 * Wraps the input (dst->src[0]) and destination ggml tensors as ACL
 * tensors, delegates the actual computation to aclnn_argmax(), and
 * releases the ACL tensor handles afterwards.
 *
 * @param ctx The CANN backend context.
 * @param dst Destination tensor; dst->src[0] holds the input.
 */
void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst){
    ggml_tensor* src = dst->src[0];

    aclTensor* acl_input  = ggml_cann_create_tensor(src);
    aclTensor* acl_output = ggml_cann_create_tensor(dst);

    aclnn_argmax(ctx, acl_input, acl_output, dst);

    ACL_CHECK(aclDestroyTensor(acl_input));
    ACL_CHECK(aclDestroyTensor(acl_output));
}
3508+
3509+
/**
 * @brief Entry point for GGML_OP_COS on the CANN backend.
 *
 * Applies the element-wise cosine of dst->src[0] into dst by wrapping
 * both tensors as ACL tensors and delegating to aclnn_cos(), then
 * releasing the ACL tensor handles.
 *
 * @param ctx The CANN backend context.
 * @param dst Destination tensor; dst->src[0] holds the input.
 */
void ggml_cann_cos(ggml_backend_cann_context& ctx, ggml_tensor* dst){
    ggml_tensor* src = dst->src[0];

    aclTensor* acl_input  = ggml_cann_create_tensor(src);
    aclTensor* acl_output = ggml_cann_create_tensor(dst);

    aclnn_cos(ctx, acl_input, acl_output);

    ACL_CHECK(aclDestroyTensor(acl_input));
    ACL_CHECK(aclDestroyTensor(acl_output));
}
3518+
3519+
/**
 * @brief Entry point for GGML_OP_SIN on the CANN backend.
 *
 * Applies the element-wise sine of dst->src[0] into dst by wrapping
 * both tensors as ACL tensors and delegating to aclnn_sin(), then
 * releasing the ACL tensor handles.
 *
 * @param ctx The CANN backend context.
 * @param dst Destination tensor; dst->src[0] holds the input.
 */
void ggml_cann_sin(ggml_backend_cann_context& ctx, ggml_tensor* dst){
    ggml_tensor* src = dst->src[0];

    aclTensor* acl_input  = ggml_cann_create_tensor(src);
    aclTensor* acl_output = ggml_cann_create_tensor(dst);

    aclnn_sin(ctx, acl_input, acl_output);

    ACL_CHECK(aclDestroyTensor(acl_input));
    ACL_CHECK(aclDestroyTensor(acl_output));
}

ggml/src/ggml-cann/aclnn_ops.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,47 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
484484
*/
485485
void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
486486

487+
/**
488+
* @brief Computes the index of the maximum value along the specified dimension
489+
* of a ggml tensor using the CANN backend.
490+
*
491+
* @details This function performs an argmax operation on the input tensor.
492+
* It finds the index of the maximum value along the specified axis
493+
* and stores these indices in the destination tensor `dst`. The
494+
* operation is executed using the CANN backend for optimized performance.
495+
*
496+
* @param ctx The CANN context used for operations.
497+
* @param dst The destination tensor where the indices of the maximum values will be stored.
498+
* dst->op is `GGML_OP_ARGMAX`.
499+
*/
500+
void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
501+
502+
/**
503+
* @brief Computes the cosine of each element in a ggml tensor using the CANN backend.
504+
*
505+
* @details This function applies the cosine function element-wise to the input tensor.
506+
* The computed cosine values are stored in the destination tensor `dst`.
507+
* The operation is optimized using the CANN backend for improved performance.
508+
*
509+
* @param ctx The CANN context used for operations.
510+
* @param dst The destination tensor where the cosine values will be stored.
511+
* dst->op is `GGML_OP_COS`.
512+
*/
513+
void ggml_cann_cos(ggml_backend_cann_context& ctx, ggml_tensor* dst);
514+
515+
/**
516+
* @brief Computes the sine of each element in a ggml tensor using the CANN backend.
517+
*
518+
* @details This function applies the sine function element-wise to the input tensor.
519+
* The computed sine values are stored in the destination tensor `dst`.
520+
* The operation is optimized using the CANN backend for improved performance.
521+
*
522+
* @param ctx The CANN context used for operations.
523+
* @param dst The destination tensor where the sine values will be stored.
524+
* dst->op is `GGML_OP_SIN`.
525+
*/
526+
void ggml_cann_sin(ggml_backend_cann_context& ctx, ggml_tensor* dst);
527+
487528
template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
488529
aclTensor*, uint64_t*, aclOpExecutor**),
489530
aclnnStatus execute(void*, uint64_t, aclOpExecutor*, aclrtStream)>

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,6 +1420,15 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
14201420
case GGML_OP_ARGSORT:
14211421
ggml_cann_argsort(ctx, dst);
14221422
break;
1423+
case GGML_OP_ARGMAX:
1424+
ggml_cann_argmax(ctx, dst);
1425+
break;
1426+
case GGML_OP_COS:
1427+
ggml_cann_cos(ctx, dst);
1428+
break;
1429+
case GGML_OP_SIN:
1430+
ggml_cann_sin(ctx, dst);
1431+
break;
14231432
default:
14241433
return false;
14251434
}
@@ -1794,6 +1803,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
17941803
case GGML_OP_ARANGE:
17951804
case GGML_OP_TIMESTEP_EMBEDDING:
17961805
case GGML_OP_LEAKY_RELU:
1806+
case GGML_OP_ARGMAX:
1807+
case GGML_OP_COS:
1808+
case GGML_OP_SIN:
17971809
return true;
17981810
default:
17991811
return false;

0 commit comments

Comments
 (0)