@@ -62,10 +62,12 @@ aclDataType ggml_cann_type_mapping(ggml_type type);
6262 * @param   offset      Offset in bytes for the ACL tensor data. Defaults to 0. 
6363 * @return  Pointer to the created ACL tensor. 
6464 */  
65- aclTensor* ggml_cann_create_tensor (const  ggml_tensor* tensor, int64_t * ne = nullptr ,
66-                              size_t * nb = nullptr , int64_t  dims = 0 ,
67-                              aclFormat format = ACL_FORMAT_ND,
68-                              size_t  offset = 0 );
65+ aclTensor * ggml_cann_create_tensor (const  ggml_tensor * tensor,
66+                                     int64_t  *           ne     = nullptr ,
67+                                     size_t  *            nb     = nullptr ,
68+                                     int64_t              dims   = 0 ,
69+                                     aclFormat           format = ACL_FORMAT_ND,
70+                                     size_t               offset = 0 );
6971
7072/* *
7173 * @brief   Template for creating an ACL tensor from provided parameters. typename TYPE 
@@ -87,12 +89,15 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = null
8789 * @param   offset      Offset in bytes for the ACL tensor data. Defaults to 0. 
8890 * @return  Pointer to the created ACL tensor. 
8991 */  
// NOTE(review): the lines below are a rendered git diff, not compilable C++ —
// each line carries fused page line-numbers and +/- change markers, and both
// the pre- and post-change forms of the template signature are present.
90- template <typename  TYPE>
91- aclTensor* ggml_cann_create_tensor (void * data_ptr, aclDataType dtype,
92-                                    TYPE type_size, int64_t * ne, TYPE* nb,
93-                                    int64_t  dims,
94-                                    aclFormat format = ACL_FORMAT_ND,
95-                                    size_t  offset = 0 ) {
92+ template  <typename  TYPE>
93+ aclTensor * ggml_cann_create_tensor (void  *      data_ptr,
94+                                     aclDataType dtype,
95+                                     TYPE        type_size,
96+                                     int64_t  *   ne,
97+                                     TYPE *      nb,
98+                                     int64_t      dims,
99+                                     aclFormat   format = ACL_FORMAT_ND,
100+                                     size_t       offset = 0 ) {
96101    int64_t  tmp_ne[GGML_MAX_DIMS * 2 ];
97102    int64_t  tmp_stride[GGML_MAX_DIMS * 2 ];
98103
// NOTE(review): the hunk header below elides the middle of this function body
// (original lines ~99-113 of the header) — the definition is incomplete in
// this view, so it is left byte-identical rather than reconstructed.
@@ -109,9 +114,8 @@ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
109114    std::reverse (tmp_ne, tmp_ne + dims);
110115    std::reverse (tmp_stride, tmp_stride + dims);
111116
112-     aclTensor* acl_tensor =
113-         aclCreateTensor (tmp_ne, dims, dtype, tmp_stride, offset / type_size,
114-                         format, &acl_storage_len, 1 , data_ptr);
117+     aclTensor * acl_tensor =
118+         aclCreateTensor (tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1 , data_ptr);
115119
116120    return  acl_tensor;
117121}
@@ -132,7 +136,7 @@ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
132136 *          to 1. If such a dimension is found, broadcasting is required to align t1 
133137 *          with t0 for element-wise operations. 
134138 */  
135- bool  ggml_cann_need_bcast (const  ggml_tensor* t0, const  ggml_tensor* t1);
139+ bool  ggml_cann_need_bcast (const  ggml_tensor  * t0, const  ggml_tensor  * t1);
136140
137141/* *
138142 * @brief   Computes broadcast shapes and strides for two ggml_tensors. 
@@ -187,19 +191,21 @@ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1);
187191 *  dim1 is an inserted dim; nb for dim1 should be added, 
188192 *  and all other nb values move to the next position in order. 
189193 */  
190- int64_t  ggml_cann_get_bcast_shape (const  ggml_tensor* src0, const  ggml_tensor* src1,
191-                         int64_t * bcast_ne_src0, int64_t * bcast_ne_src1,
192-                         size_t * bcast_nb_src0, size_t * bcast_nb_src1);
194+ int64_t  ggml_cann_get_bcast_shape (const  ggml_tensor * src0,
195+                                   const  ggml_tensor * src1,
196+                                   int64_t  *           bcast_ne_src0,
197+                                   int64_t  *           bcast_ne_src1,
198+                                   size_t  *            bcast_nb_src0,
199+                                   size_t  *            bcast_nb_src1);
193200
194201//  Bcast macro to avoid duplicate code.
// Declares bcast_<src0>_ne/_nb, bcast_<src1>_ne/_nb and bcast_dims in the
// current scope, filled in by ggml_cann_get_bcast_shape().
#define  BCAST_SHAPE(src0, src1)                                                                     \
    int64_t bcast_##src0##_ne[GGML_MAX_DIMS * 2];                                                    \
    int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2];                                                    \
    size_t  bcast_##src0##_nb[GGML_MAX_DIMS * 2];                                                    \
    size_t  bcast_##src1##_nb[GGML_MAX_DIMS * 2];                                                    \
    int64_t bcast_dims = ggml_cann_get_bcast_shape(src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, \
                                                   bcast_##src0##_nb, bcast_##src1##_nb);
203209
// Expands to the three arguments produced by BCAST_SHAPE for `tensor`:
// its broadcast ne array, broadcast nb array, and the shared bcast_dims.
#define  BCAST_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
205211
@@ -233,26 +239,31 @@ int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* sr
233239 *       before cast dim. 
234240 * @sa ggml_cann_get_bcast_shape 
235241 */  
int64_t ggml_cann_get_mulmat_bcast_shape(const int64_t * input_ne,
                                         const int64_t * weight_ne,
                                         const int64_t * dst_ne,
                                         const size_t *  input_nb,
                                         const size_t *  weight_nb,
                                         const size_t *  dst_nb,
                                         int64_t *       bcast_input_ne,
                                         int64_t *       bcast_weight_ne,
                                         int64_t *       bcast_dst_ne,
                                         size_t *        bcast_input_nb,
                                         size_t *        bcast_weight_nb,
                                         size_t *        bcast_dst_nb);
241254
242255//  Bcast macro to avoid duplicate code.
// Declares bcast_<input/weight/dst>_ne/_nb and bcast_dims in the current
// scope, filled in by ggml_cann_get_mulmat_bcast_shape() from the three
// tensors' ne/nb arrays.
#define  BCAST_MUL_MAT_SHAPE(input, weight, dst)                                                                 \
    int64_t bcast_##input##_ne[GGML_MAX_DIMS * 2];                                                               \
    int64_t bcast_##weight##_ne[GGML_MAX_DIMS * 2];                                                              \
    int64_t bcast_##dst##_ne[GGML_MAX_DIMS * 2];                                                                 \
    size_t  bcast_##input##_nb[GGML_MAX_DIMS * 2];                                                               \
    size_t  bcast_##weight##_nb[GGML_MAX_DIMS * 2];                                                              \
    size_t  bcast_##dst##_nb[GGML_MAX_DIMS * 2];                                                                 \
    int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape(                                                       \
        input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, bcast_##input##_ne, bcast_##weight##_ne, \
        bcast_##dst##_ne, bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb);
254266
// Expands to the three arguments produced by BCAST_MUL_MAT_SHAPE for `tensor`:
// its broadcast ne array, broadcast nb array, and the shared bcast_dims.
#define  BCAST_MUL_MAT_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
257268
258269#endif   //  CANN_ACL_TENSOR_H
0 commit comments