@@ -902,57 +902,50 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
  * to return the zero-cache tensor or the one-cache tensor.
  * @return An aclTensor pointer corresponding to the cached tensor.
  */
-static aclTensor* get_f32_cache_acl_tensor(ggml_backend_cann_context& ctx,
+static aclTensor* get_f32_zero_or_one_cache_acl_tensor(ggml_backend_cann_context& ctx,
                                            int64_t* ne, size_t* nb,
                                            int64_t dims, int64_t value) {
-    // init cache
-    if (ctx.f32_zero_cache == nullptr) {
-        // zero-cache pool init
-        size_t size = ctx.f32_cache_element * sizeof(float);
-        ACL_CHECK(aclrtMalloc(&ctx.f32_zero_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMemsetAsync(ctx.f32_zero_cache, size, 0, size, ctx.stream()));
-
-        // one-cache pool init
-        int64_t pool_ne[1] = { ctx.f32_cache_element };
-        size_t pool_nb[1] = { sizeof(float) };
-        ACL_CHECK(aclrtMalloc(&ctx.f32_one_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
-        aclTensor* acl_one = ggml_cann_create_tensor(
-            ctx.f32_one_cache, ACL_FLOAT, sizeof(float), pool_ne, pool_nb,
-            1);
-        aclnn_fill_scalar(ctx, 1, acl_one);
-        ggml_cann_release_resources(ctx, acl_one);
-    }
+    // only the zero cache and the one cache are supported
+    GGML_ASSERT(value == 1 || value == 0);
 
-    // Cache expansion
+    // cache init or expansion
     int64_t n_element = 1;
     for (int i = 0; i < dims; i++) {
         n_element = n_element * ne[i];
     }
-    if (ctx.f32_cache_element < n_element) {
-        // free old mem
-        aclrtFree(ctx.f32_zero_cache);
-        aclrtFree(ctx.f32_one_cache);
-        // init zero cache
-        ctx.f32_cache_element = n_element;
-        size_t size = n_element * sizeof(float);
-        ACL_CHECK(aclrtMalloc(&ctx.f32_zero_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMemsetAsync(ctx.f32_zero_cache, size, 0, size, ctx.stream()));
-
-        // one-cache pool init
-        int64_t pool_ne[1] = { n_element };
-        size_t pool_nb[1] = { sizeof(float) };
-        ACL_CHECK(aclrtMalloc(&ctx.f32_one_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
-        aclTensor* acl_one = ggml_cann_create_tensor(
-            ctx.f32_one_cache, ACL_FLOAT, sizeof(float), pool_ne, pool_nb,
-            1);
-        aclnn_fill_scalar(ctx, 1, acl_one);
-        ggml_cann_release_resources(ctx, acl_one);
-    }
-
+    size_t size = n_element * sizeof(float);
     void* cache;
     if (value == 0) {
+        if (ctx.f32_zero_cache_element < n_element) {
+            // free the old, smaller buffer
+            if (ctx.f32_zero_cache != nullptr) {
+                aclrtFree(ctx.f32_zero_cache);
+            }
+
+            // allocate and zero-initialize the larger cache
+            ctx.f32_zero_cache_element = n_element;
+            ACL_CHECK(aclrtMalloc(&ctx.f32_zero_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
+            ACL_CHECK(aclrtMemsetAsync(ctx.f32_zero_cache, size, 0, size, ctx.stream()));
+        }
        cache = ctx.f32_zero_cache;
     } else {
+        if (ctx.f32_one_cache_element < n_element) {
+            // free the old, smaller buffer
+            if (ctx.f32_one_cache != nullptr) {
+                aclrtFree(ctx.f32_one_cache);
+            }
+
+            // allocate the larger cache and fill it with ones
+            ctx.f32_one_cache_element = n_element;
+            int64_t pool_ne[1] = { n_element };
+            size_t pool_nb[1] = { sizeof(float) };
+            ACL_CHECK(aclrtMalloc(&ctx.f32_one_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
+            aclTensor* acl_one = ggml_cann_create_tensor(
+                ctx.f32_one_cache, ACL_FLOAT, sizeof(float), pool_ne, pool_nb,
+                1);
+            aclnn_fill_scalar(ctx, 1, acl_one);
+            ggml_cann_release_resources(ctx, acl_one);
+        }
         cache = ctx.f32_one_cache;
     }
 
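This hunk replaces the shared `f32_cache_element` counter with a per-buffer capacity (`f32_zero_cache_element`, `f32_one_cache_element`), so each constant cache is initialized lazily and grows independently, and only the cache that is actually requested ever gets allocated. A minimal host-side sketch of that grow-only pattern, with plain `malloc`/`free` standing in for `aclrtMalloc`/`aclrtFree` (the `ConstCache` type and `get_const_cache` helper below are hypothetical, not part of the PR):

```cpp
#include <cstdlib>
#include <cstdint>

// Hypothetical host-side analogue of the per-value grow-only cache.
// One ConstCache instance per constant, mirroring the zero/one split.
struct ConstCache {
    void*   data     = nullptr;  // device buffer in the real backend
    int64_t capacity = 0;        // high-water mark, in elements
};

// Return a buffer holding at least n_element floats equal to `value`.
// The buffer is reallocated and refilled only when the request exceeds
// the cached capacity; smaller requests reuse the existing buffer.
static float* get_const_cache(ConstCache& c, int64_t n_element, float value) {
    if (c.capacity < n_element) {
        std::free(c.data);  // no-op on first use (data == nullptr)
        c.data     = std::malloc(n_element * sizeof(float));
        c.capacity = n_element;
        float* p = static_cast<float*>(c.data);
        for (int64_t i = 0; i < n_element; i++) {
            p[i] = value;   // stands in for aclrtMemsetAsync / fill_scalar
        }
    }
    return static_cast<float*>(c.data);
}
```

Callers that need fewer elements than the current capacity simply view a prefix of the buffer, which is presumably why the helper takes the caller's `ne`/`nb` and builds the returned aclTensor from those rather than from the cache's full extent.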
@@ -974,15 +967,15 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
     }
-    aclTensor* acl_gamma = get_f32_cache_acl_tensor(ctx, src->ne, acl_gamma_nb, 1, 1);
+    aclTensor* acl_gamma = get_f32_zero_or_one_cache_acl_tensor(ctx, src->ne, acl_gamma_nb, 1, 1);
 
     // build rstd, zero...
     size_t acl_rstd_nb[GGML_MAX_DIMS];
     acl_rstd_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         acl_rstd_nb[i] = acl_rstd_nb[i - 1] * src->ne[i - 1];
     }
-    aclTensor* acl_rstd = get_f32_cache_acl_tensor(ctx, src->ne, acl_rstd_nb, GGML_MAX_DIMS, 0);
+    aclTensor* acl_rstd = get_f32_zero_or_one_cache_acl_tensor(ctx, src->ne, acl_rstd_nb, GGML_MAX_DIMS, 0);
 
     GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
     ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
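Both call sites now go through the renamed helper: `acl_gamma` is a 1-D all-ones vector of length `ne[0]` (the `dims = 1` argument), and `acl_rstd` is a zero-filled buffer shaped like `src`. Assuming aclnnRmsNorm follows the standard RMSNorm definition, the kernel computes:

```latex
% Standard RMSNorm (assumed semantics of aclnnRmsNorm):
\[
\operatorname{rms}(x) = \sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^{2} + \varepsilon},
\qquad
y_i = \gamma_i \cdot \frac{x_i}{\operatorname{rms}(x)},
\qquad
\text{rstd} = \frac{1}{\operatorname{rms}(x)}
\]
```

With gamma fixed to ones the kernel returns the bare normalization, matching GGML's RMS_NORM semantics where the learned weight is applied by a separate MUL op; the zero cache merely pre-allocates the `rstd` output tensor.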
@@ -998,14 +991,13 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
 
     const int n_past = ((int32_t *) dst->op_params)[0];
 
-    size_t one_tensor_n_bytes = src->ne[0] * src->ne[1] * src->ne[2] *
-                                src->ne[3] * ggml_element_size(src);
-    ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
+    ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), ggml_nbytes(src));
+    void* buffer = one_tensor_allocator.get();
 
-    aclTensor* mask_tensor =
-        aclnn_values(ctx, one_tensor_allocator.get(), one_tensor_n_bytes,
-                     src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
-                     ggml_element_size(src), value);
+    aclTensor* mask_tensor = ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(src->type),
+                                                     ggml_type_size(src->type), src->ne, src->nb, GGML_MAX_DIMS);
+
+    aclnn_fill_scalar(ctx, value, mask_tensor);
 
     aclScalar* alpha = nullptr;
     float alphaValue = 1.0f;
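Here `ggml_nbytes(src)` replaces the hand-rolled `ne[0] * ne[1] * ne[2] * ne[3] * ggml_element_size(src)` product; unlike the product, it is computed from the strides, so it also stays correct for non-contiguous views. A sketch of that stride-based computation for types with a block size of 1 (simplified, and an assumption about ggml's internals rather than a verbatim copy):

```cpp
#include <cstdint>
#include <cstddef>

// Simplified, block-size-1 sketch of what ggml_nbytes() computes:
// the byte offset of the last element plus one element's size. For a
// contiguous tensor this equals ne[0]*ne[1]*ne[2]*ne[3]*type_size.
static size_t nbytes_sketch(const int64_t ne[4], const size_t nb[4],
                            size_t type_size) {
    size_t nbytes = type_size;          // room for the first element
    for (int i = 0; i < 4; i++) {
        nbytes += (ne[i] - 1) * nb[i];  // stride to the last index in dim i
    }
    return nbytes;
}
```

With the size settled, the new code views the pool allocation as an aclTensor via `ggml_cann_create_tensor` and fills it with `aclnn_fill_scalar`, instead of going through the removed `aclnn_values` helper.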