Skip to content

Commit dc6bf1a

Browse files
author
noemotiovon
committed
[CANN]code style adjustment
Signed-off-by: noemotiovon <[email protected]>
1 parent f34bf09 commit dc6bf1a

File tree

1 file changed

+36
-9
lines changed

1 file changed

+36
-9
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,19 @@ static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
9999
ACL_CHECK(aclDestroyIntArray(repeats));
100100
}
101101

102-
void ggml_cann_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
102+
/**
103+
* @brief Casts the elements of a tensor to a specified data type using the CANN backend.
104+
*
105+
* @details This function performs a type conversion on the elements of the input tensor `acl_src`
106+
* and stores the results in the destination tensor `acl_dst`. The conversion type is
107+
* determined based on the `dst` tensor's data type.
108+
*
109+
* @param ctx The context for the CANN backend operations.
110+
* @param acl_src The source tensor whose elements will be cast.
111+
* @param acl_dst The destination tensor that will store the casted elements.
112+
* @param dst The ggml tensor specifying the target data type.
113+
*/
114+
static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
103115
aclTensor* acl_dst, ggml_tensor* dst) {
104116
uint64_t workspaceSize = 0;
105117
aclOpExecutor* executor;
@@ -914,7 +926,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
914926
if (dst->type == src0->type) {
915927
cann_copy(ctx, acl_src, acl_dst);
916928
} else {
917-
ggml_cann_cast(ctx, acl_src, acl_dst, dst);
929+
aclnn_cast(ctx, acl_src, acl_dst, dst);
918930
}
919931
} else {
920932
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
@@ -939,7 +951,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
939951
ggml_type_size(dst->type), src0->ne, src_trans_nb,
940952
GGML_MAX_DIMS);
941953

942-
ggml_cann_cast(ctx, acl_src, src_trans_tensor, dst);
954+
aclnn_cast(ctx, acl_src, src_trans_tensor, dst);
943955
size_t cpy_size = ggml_nbytes(dst);
944956
ACL_CHECK(aclrtMemcpyAsync(
945957
dst->data, cpy_size, src_trans_buffer, cpy_size,
@@ -961,7 +973,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
961973
ggml_type_size(dst->type), src0->ne, src_trans_nb,
962974
GGML_MAX_DIMS);
963975

964-
ggml_cann_cast(ctx, acl_src, src_trans_tensor, dst);
976+
aclnn_cast(ctx, acl_src, src_trans_tensor, dst);
965977

966978
size_t cpy_size = ggml_nbytes(dst);
967979
ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src_trans_buffer,
@@ -2298,7 +2310,22 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
22982310
ACL_CHECK(aclDestroyTensor(tmp_mask_tensor));
22992311
}
23002312

2301-
void ggml_cann_embedding_4d(ggml_backend_cann_context& ctx, void* src_buffer,
2313+
/**
2314+
* @brief Performs embedding operation on a 4D tensor using the CANN backend.
2315+
*
2316+
* @details This function extracts slices from the source tensor (`src_buffer`), index tensor (`index`),
2317+
* and destination tensor (`dst`), and performs an embedding operation on them. The embedding
2318+
* operation is applied by iterating over the last two dimensions of the source tensor, creating
2319+
* the necessary tensors for the source, index, and output, and executing the embedding operation.
2320+
*
2321+
* @param ctx The context for CANN backend operations.
2322+
* @param src_buffer The source buffer holding the data for the source tensor.
2323+
* @param src_ne The dimensions of the source tensor.
2324+
* @param src_nb The strides (byte offsets) of the source tensor.
2325+
* @param index The index tensor used in the embedding operation.
2326+
* @param dst The destination tensor where the result will be stored.
2327+
*/
2328+
static void aclnn_embedding_4d(ggml_backend_cann_context& ctx, void* src_buffer,
23022329
int64_t* src_ne, size_t* src_nb, ggml_tensor* index,
23032330
ggml_tensor* dst) {
23042331
for (int64_t i = 0; i < src_ne[3]; i++) {
@@ -2356,7 +2383,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
23562383

23572384
switch (src0->type) {
23582385
case GGML_TYPE_F32: {
2359-
ggml_cann_embedding_4d(ctx, src0->data, src0->ne, src0->nb, src1,
2386+
aclnn_embedding_4d(ctx, src0->data, src0->ne, src0->nb, src1,
23602387
dst);
23612388
break;
23622389
}
@@ -2373,8 +2400,8 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
23732400
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
23742401
src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
23752402
src0->ne, src_trans_nb, GGML_MAX_DIMS);
2376-
ggml_cann_cast(ctx, acl_src0, src_trans_tensor, dst);
2377-
ggml_cann_embedding_4d(ctx, src_trans_buffer, src0->ne,
2403+
aclnn_cast(ctx, acl_src0, src_trans_tensor, dst);
2404+
aclnn_embedding_4d(ctx, src_trans_buffer, src0->ne,
23782405
src_trans_nb, src1, dst);
23792406
ACL_CHECK(aclDestroyTensor(acl_src0));
23802407
ACL_CHECK(aclDestroyTensor(src_trans_tensor));
@@ -2436,7 +2463,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
24362463
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
24372464
}
24382465

2439-
ggml_cann_embedding_4d(ctx, dequant_buffer_allocator.get(),
2466+
aclnn_embedding_4d(ctx, dequant_buffer_allocator.get(),
24402467
dequant_ne, dequant_nb, src1, dst);
24412468

24422469
ACL_CHECK(aclDestroyTensor(dequant_tensor));

0 commit comments

Comments
 (0)