@@ -885,71 +885,66 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
 }
 
 /**
- * @brief Get or expand cached float32 tensors filled with scalar values.
- *
- * This function manages a cache of float32 tensors (zero-filled and one-filled).
- * If the cache does not exist, it will initialize the cache with a zero tensor
- * and a one tensor. If the requested tensor size exceeds the current cache
- * capacity, the cache will be expanded accordingly. The function then returns
- * an aclTensor created from the cached memory (either zero-filled or one-filled),
- * depending on the input `value`.
- *
- * @param ctx The CANN backend context that manages cache memory.
- * @param ne The tensor shape array (number of elements in each dimension).
- * @param nb The stride size for each dimension.
- * @param dims The number of tensor dimensions.
- * @param value The scalar value (only supports 0 or 1) used to determine whether
- *              to return the zero-cache tensor or the one-cache tensor.
- * @return An aclTensor pointer corresponding to the cached tensor.
+ * @brief Get or expand a cached float32 tensor filled with a scalar value.
+ *
+ * This function manages cached device memory for float32 tensors. If the current
+ * cache size is insufficient for the requested tensor shape, the old memory will
+ * be released and new memory will be allocated. The allocated buffer is then
+ * initialized either with zeros (when @p value == 0.0f) or with the given scalar
+ * value using CANN operations. Finally, an aclTensor object is created from the
+ * cached memory and returned.
+ *
+ * @param ctx The CANN backend context that manages device memory.
+ * @param buffer A pointer to the cached device buffer (will be allocated
+ *               or reallocated if necessary).
+ * @param cache_element The current number of cached elements. This will be
+ *                      updated when the cache is expanded.
+ * @param ne The tensor shape array (number of elements in each dimension).
+ * @param nb The stride size for each dimension.
+ * @param dims The number of tensor dimensions.
+ * @param value The scalar value used to fill the tensor (supports zero
+ *              initialization via memset or arbitrary values via fill_scalar).
+ * @return An aclTensor pointer created from the cached buffer.
  */
-static aclTensor* get_f32_zero_or_one_cache_acl_tensor(ggml_backend_cann_context& ctx,
-                                                       int64_t* ne, size_t* nb,
-                                                       int64_t dims, int64_t value) {
-    // just support one and zero cache
-    GGML_ASSERT(value == 1 || value == 0);
-
-    // Cache init or expansion
+static aclTensor* get_f32_cache_acl_tensor(
+        ggml_backend_cann_context& ctx,
+        void** buffer,
+        int64_t &cache_element,
+        int64_t* ne,
+        size_t* nb,
+        int64_t dims,
+        float value) {
+    // Calculate total number of elements
     int64_t n_element = 1;
-    for (int i = 0; i < dims; i++) {
-        n_element = n_element * ne[i];
+    for (int i = 0; i < dims; i++) {
+        n_element *= ne[i];
     }
     size_t size = n_element * sizeof(float);
-    void* cache;
-    if (value == 0) {
-        if (ctx.f32_zero_cache_element < n_element){
-            // free old mem
-            if (ctx.f32_zero_cache != nullptr){
-                aclrtFree(ctx.f32_zero_cache);
-            }
-
-            // init zero cache
-            ctx.f32_zero_cache_element = n_element;
-            ACL_CHECK(aclrtMalloc(&ctx.f32_zero_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
-            ACL_CHECK(aclrtMemsetAsync(ctx.f32_zero_cache, size, 0, size, ctx.stream()));
+
+    // Allocate or expand cache if needed
+    if (cache_element < n_element) {
+        if (*buffer != nullptr) {
+            aclrtFree(*buffer);
+            *buffer = nullptr;
         }
-        cache = ctx.f32_zero_cache;
-    } else {
-        if (ctx.f32_one_cache_element < n_element){
-            // free old mem
-            if (ctx.f32_one_cache != nullptr){
-                aclrtFree(ctx.f32_one_cache);
-            }
-
-            // init one cache
-            ctx.f32_one_cache_element = n_element;
+
+        ACL_CHECK(aclrtMalloc(buffer, size, ACL_MEM_MALLOC_HUGE_FIRST));
+        cache_element = n_element;
+
+        // Initialize cache
+        if (value == 0.0f) {
+            ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
+        } else {
            int64_t pool_ne[1] = { n_element };
            size_t pool_nb[1] = { sizeof(float) };
-            ACL_CHECK(aclrtMalloc(&ctx.f32_one_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
-            aclTensor* acl_one = ggml_cann_create_tensor(
-                ctx.f32_one_cache, ACL_FLOAT, sizeof(float), pool_ne, pool_nb,
-                1);
-            aclnn_fill_scalar(ctx, 1, acl_one);
-            ggml_cann_release_resources(ctx, acl_one);
+            aclTensor* acl_value = ggml_cann_create_tensor(
+                *buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
+            aclnn_fill_scalar(ctx, value, acl_value);
+            ggml_cann_release_resources(ctx, acl_value);
        }
-        cache = ctx.f32_one_cache;
    }
 
-    return ggml_cann_create_tensor(cache, ACL_FLOAT, sizeof(float), ne, nb, dims);
+    return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
 }
 
 void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
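Note that the refactored helper only re-fills the buffer when it has to grow, so every caller sharing a given cache slot is expected to request the same fill value. A minimal usage sketch, written as if it sat inside this file (the shapes below are made up for illustration; only get_f32_cache_acl_tensor and the ctx.f32_one_cache / ctx.f32_one_cache_element members come from the diff):

    // Hypothetical call pattern: the first request allocates and fills 4096 floats
    // with 1.0f; the second, smaller request reuses the buffer without refilling,
    // which is only correct because both callers ask for the same value (1.0f).
    int64_t ne_big[1]   = { 4096 };           // made-up size
    int64_t ne_small[1] = { 1024 };           // made-up size
    size_t  nb[1]       = { sizeof(float) };

    aclTensor* ones_big   = get_f32_cache_acl_tensor(ctx, &ctx.f32_one_cache,
                                                     ctx.f32_one_cache_element,
                                                     ne_big, nb, 1, 1.0f);
    aclTensor* ones_small = get_f32_cache_acl_tensor(ctx, &ctx.f32_one_cache,
                                                     ctx.f32_one_cache_element,
                                                     ne_small, nb, 1, 1.0f);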
@@ -967,15 +962,31 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
     }
-    aclTensor* acl_gamma = get_f32_zero_or_one_cache_acl_tensor(ctx, src->ne, acl_gamma_nb, 1, 1);
+    aclTensor* acl_gamma = get_f32_cache_acl_tensor(
+        ctx,
+        &ctx.f32_one_cache,
+        ctx.f32_one_cache_element,
+        src->ne,
+        acl_gamma_nb,
+        1,      // dims
+        1.0f    // value
+    );
 
     // build rstd, zero...
     size_t acl_rstd_nb[GGML_MAX_DIMS];
     acl_rstd_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         acl_rstd_nb[i] = acl_rstd_nb[i - 1] * src->ne[i - 1];
     }
-    aclTensor* acl_rstd = get_f32_zero_or_one_cache_acl_tensor(ctx, src->ne, acl_rstd_nb, GGML_MAX_DIMS, 0);
+    aclTensor* acl_rstd = get_f32_cache_acl_tensor(
+        ctx,
+        &ctx.f32_zero_cache,
+        ctx.f32_zero_cache_element,
+        src->ne,
+        acl_rstd_nb,
+        GGML_MAX_DIMS,
+        0.0f    // value
+    );
 
     GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
     ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
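For context, the call sites pass the cache slots by address, which implies the backend context carries roughly the following members. This is a sketch inferred from the names used above; the real declarations live in the CANN context header and may differ in layout and initializers:

    // Sketch only: member names are taken from the diff, initializers are assumed.
    struct ggml_backend_cann_context {
        // ...
        void*   f32_zero_cache         = nullptr;  // device buffer kept filled with 0.0f
        int64_t f32_zero_cache_element = 0;        // number of floats it currently holds
        void*   f32_one_cache          = nullptr;  // device buffer kept filled with 1.0f
        int64_t f32_one_cache_element  = 0;
        // ...
    };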
@@ -996,7 +1007,7 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
 
     aclTensor* mask_tensor = ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(src->type),
                                                      ggml_type_size(src->type), src->ne, src->nb, GGML_MAX_DIMS);
-
+
     aclnn_fill_scalar(ctx, value, mask_tensor);
 
     aclScalar* alpha = nullptr;