@@ -99,7 +99,19 @@ static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
9999 ACL_CHECK (aclDestroyIntArray (repeats));
100100}
101101
102- void ggml_cann_cast (ggml_backend_cann_context& ctx, aclTensor* acl_src,
102+ /* *
103+ * @brief Casts the elements of a tensor to a specified data type using the CANN backend.
104+ *
105+ * @details This function performs a type conversion on the elements of the input tensor `acl_src`
106+ * and stores the results in the destination tensor `acl_dst`. The conversion type is
107+ * determined based on the `dst` tensor's data type.
108+ *
109+ * @param ctx The context for the CANN backend operations.
110+ * @param acl_src The source tensor whose elements will be cast.
111+ * @param acl_dst The destination tensor that will store the casted elements.
112+ * @param dst The ggml tensor specifying the target data type.
113+ */
114+ static void aclnn_cast (ggml_backend_cann_context& ctx, aclTensor* acl_src,
103115 aclTensor* acl_dst, ggml_tensor* dst) {
104116 uint64_t workspaceSize = 0 ;
105117 aclOpExecutor* executor;
@@ -914,7 +926,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
914926 if (dst->type == src0->type ) {
915927 cann_copy (ctx, acl_src, acl_dst);
916928 } else {
917- ggml_cann_cast (ctx, acl_src, acl_dst, dst);
929+ aclnn_cast (ctx, acl_src, acl_dst, dst);
918930 }
919931 } else {
920932 if (ggml_is_contiguous (src0) && ggml_is_contiguous (dst)) {
@@ -939,7 +951,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
939951 ggml_type_size (dst->type ), src0->ne , src_trans_nb,
940952 GGML_MAX_DIMS);
941953
942- ggml_cann_cast (ctx, acl_src, src_trans_tensor, dst);
954+ aclnn_cast (ctx, acl_src, src_trans_tensor, dst);
943955 size_t cpy_size = ggml_nbytes (dst);
944956 ACL_CHECK (aclrtMemcpyAsync (
945957 dst->data , cpy_size, src_trans_buffer, cpy_size,
@@ -961,7 +973,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
961973 ggml_type_size (dst->type ), src0->ne , src_trans_nb,
962974 GGML_MAX_DIMS);
963975
964- ggml_cann_cast (ctx, acl_src, src_trans_tensor, dst);
976+ aclnn_cast (ctx, acl_src, src_trans_tensor, dst);
965977
966978 size_t cpy_size = ggml_nbytes (dst);
967979 ACL_CHECK (aclrtMemcpyAsync (dst->data , cpy_size, src_trans_buffer,
@@ -2298,7 +2310,22 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
22982310 ACL_CHECK (aclDestroyTensor (tmp_mask_tensor));
22992311}
23002312
2301- void ggml_cann_embedding_4d (ggml_backend_cann_context& ctx, void * src_buffer,
2313+ /* *
2314+ * @brief Performs embedding operation on a 4D tensor using the CANN backend.
2315+ *
2316+ * @details This function extracts slices from the source tensor (`src_buffer`), index tensor (`index`),
2317+ * and destination tensor (`dst`), and performs an embedding operation on them. The embedding
2318+ * operation is applied by iterating over the last two dimensions of the source tensor, creating
2319+ * the necessary tensors for the source, index, and output, and executing the embedding operation.
2320+ *
2321+ * @param ctx The context for CANN backend operations.
2322+ * @param src_buffer The source buffer holding the data for the source tensor.
2323+ * @param src_ne The dimensions of the source tensor.
2324+ * @param src_nb The strides (byte offsets) of the source tensor.
2325+ * @param index The index tensor used in the embedding operation.
2326+ * @param dst The destination tensor where the result will be stored.
2327+ */
2328+ static void aclnn_embedding_4d (ggml_backend_cann_context& ctx, void * src_buffer,
23022329 int64_t * src_ne, size_t * src_nb, ggml_tensor* index,
23032330 ggml_tensor* dst) {
23042331 for (int64_t i = 0 ; i < src_ne[3 ]; i++) {
@@ -2356,7 +2383,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
23562383
23572384 switch (src0->type ) {
23582385 case GGML_TYPE_F32: {
2359- ggml_cann_embedding_4d (ctx, src0->data , src0->ne , src0->nb , src1,
2386+ aclnn_embedding_4d (ctx, src0->data , src0->ne , src0->nb , src1,
23602387 dst);
23612388 break ;
23622389 }
@@ -2373,8 +2400,8 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
23732400 aclTensor* src_trans_tensor = ggml_cann_create_tensor (
23742401 src_trans_buffer, ACL_FLOAT, ggml_type_size (dst->type ),
23752402 src0->ne , src_trans_nb, GGML_MAX_DIMS);
2376- ggml_cann_cast (ctx, acl_src0, src_trans_tensor, dst);
2377- ggml_cann_embedding_4d (ctx, src_trans_buffer, src0->ne ,
2403+ aclnn_cast (ctx, acl_src0, src_trans_tensor, dst);
2404+ aclnn_embedding_4d (ctx, src_trans_buffer, src0->ne ,
23782405 src_trans_nb, src1, dst);
23792406 ACL_CHECK (aclDestroyTensor (acl_src0));
23802407 ACL_CHECK (aclDestroyTensor (src_trans_tensor));
@@ -2436,7 +2463,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
24362463 dequant_nb[i] = dequant_nb[i - 1 ] * src0->ne [i - 1 ];
24372464 }
24382465
2439- ggml_cann_embedding_4d (ctx, dequant_buffer_allocator.get (),
2466+ aclnn_embedding_4d (ctx, dequant_buffer_allocator.get (),
24402467 dequant_ne, dequant_nb, src1, dst);
24412468
24422469 ACL_CHECK (aclDestroyTensor (dequant_tensor));
0 commit comments