Skip to content

Commit c24b995

Browse files
committed
fix review comment
Signed-off-by: noemotiovon <[email protected]>
1 parent 9b0ec0e commit c24b995

File tree

2 files changed

+43
-50
lines changed

2 files changed

+43
-50
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 41 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -902,57 +902,50 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
902902
* to return the zero-cache tensor or the one-cache tensor.
903903
* @return An aclTensor pointer corresponding to the cached tensor.
904904
*/
905-
static aclTensor* get_f32_cache_acl_tensor(ggml_backend_cann_context& ctx,
905+
static aclTensor* get_f32_zero_or_one_cache_acl_tensor(ggml_backend_cann_context& ctx,
906906
int64_t* ne, size_t* nb,
907907
int64_t dims, int64_t value) {
908-
// init cache
909-
if(ctx.f32_zero_cache == nullptr) {
910-
// zero-cache pool init
911-
size_t size = ctx.f32_cache_element * sizeof(float);
912-
ACL_CHECK(aclrtMalloc(&ctx.f32_zero_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
913-
ACL_CHECK(aclrtMemsetAsync(ctx.f32_zero_cache, size, 0, size, ctx.stream()));
914-
915-
// one-cache pool init
916-
int64_t pool_ne[1] = { ctx.f32_cache_element };
917-
size_t pool_nb[1] = { sizeof(float) };
918-
ACL_CHECK(aclrtMalloc(&ctx.f32_one_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
919-
aclTensor* acl_one = ggml_cann_create_tensor(
920-
ctx.f32_one_cache, ACL_FLOAT, sizeof(float), pool_ne, pool_nb,
921-
1);
922-
aclnn_fill_scalar(ctx, 1, acl_one);
923-
ggml_cann_release_resources(ctx, acl_one);
924-
}
908+
// Only the zero cache (value == 0) and the one cache (value == 1) are supported.
909+
GGML_ASSERT(value == 1 || value == 0);
925910

926-
// Cache expansion
911+
// Cache init or expansion
927912
int64_t n_element = 1;
928913
for(int i = 0; i < dims; i++) {
929914
n_element = n_element * ne[i];
930915
}
931-
if (ctx.f32_cache_element < n_element) {
932-
// free old mem
933-
aclrtFree(ctx.f32_zero_cache);
934-
aclrtFree(ctx.f32_one_cache);
935-
// init zero cache
936-
ctx.f32_cache_element = n_element;
937-
size_t size = n_element * sizeof(float);
938-
ACL_CHECK(aclrtMalloc(&ctx.f32_zero_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
939-
ACL_CHECK(aclrtMemsetAsync(ctx.f32_zero_cache, size, 0, size, ctx.stream()));
940-
941-
// one-cache pool init
942-
int64_t pool_ne[1] = { n_element };
943-
size_t pool_nb[1] = { sizeof(float) };
944-
ACL_CHECK(aclrtMalloc(&ctx.f32_one_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
945-
aclTensor* acl_one = ggml_cann_create_tensor(
946-
ctx.f32_one_cache, ACL_FLOAT, sizeof(float), pool_ne, pool_nb,
947-
1);
948-
aclnn_fill_scalar(ctx, 1, acl_one);
949-
ggml_cann_release_resources(ctx, acl_one);
950-
}
951-
916+
size_t size = n_element * sizeof(float);
952917
void* cache;
953918
if (value == 0) {
919+
if(ctx.f32_zero_cache_element < n_element){
920+
//free old mem
921+
if(ctx.f32_zero_cache != nullptr){
922+
aclrtFree(ctx.f32_zero_cache);
923+
}
924+
925+
//init zero cache
926+
ctx.f32_zero_cache_element = n_element;
927+
ACL_CHECK(aclrtMalloc(&ctx.f32_zero_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
928+
ACL_CHECK(aclrtMemsetAsync(ctx.f32_zero_cache, size, 0, size, ctx.stream()));
929+
}
954930
cache = ctx.f32_zero_cache;
955931
} else {
932+
if(ctx.f32_one_cache_element < n_element){
933+
//free old mem
934+
if(ctx.f32_one_cache != nullptr){
935+
aclrtFree(ctx.f32_one_cache);
936+
}
937+
938+
//init one cache
939+
ctx.f32_one_cache_element = n_element;
940+
int64_t pool_ne[1] = { n_element };
941+
size_t pool_nb[1] = { sizeof(float) };
942+
ACL_CHECK(aclrtMalloc(&ctx.f32_one_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
943+
aclTensor* acl_one = ggml_cann_create_tensor(
944+
ctx.f32_one_cache, ACL_FLOAT, sizeof(float), pool_ne, pool_nb,
945+
1);
946+
aclnn_fill_scalar(ctx, 1, acl_one);
947+
ggml_cann_release_resources(ctx, acl_one);
948+
}
956949
cache = ctx.f32_one_cache;
957950
}
958951

@@ -974,15 +967,15 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
974967
for (int i = 1; i < GGML_MAX_DIMS; i++) {
975968
acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
976969
}
977-
aclTensor* acl_gamma = get_f32_cache_acl_tensor(ctx, src->ne, acl_gamma_nb, 1, 1);
970+
aclTensor* acl_gamma = get_f32_zero_or_one_cache_acl_tensor(ctx, src->ne, acl_gamma_nb, 1, 1);
978971

979972
// build rstd, zero...
980973
size_t acl_rstd_nb[GGML_MAX_DIMS];
981974
acl_rstd_nb[0] = sizeof(float);
982975
for (int i = 1; i < GGML_MAX_DIMS; i++) {
983976
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * src->ne[i - 1];
984977
}
985-
aclTensor* acl_rstd = get_f32_cache_acl_tensor(ctx, src->ne, acl_rstd_nb, GGML_MAX_DIMS, 0);
978+
aclTensor* acl_rstd = get_f32_zero_or_one_cache_acl_tensor(ctx, src->ne, acl_rstd_nb, GGML_MAX_DIMS, 0);
986979

987980
GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
988981
ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
@@ -998,14 +991,13 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
998991

999992
const int n_past = ((int32_t*)dst->op_params)[0];
1000993

1001-
size_t one_tensor_n_bytes = src->ne[0] * src->ne[1] * src->ne[2] *
1002-
src->ne[3] * ggml_element_size(src);
1003-
ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
994+
ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), ggml_nbytes(src));
995+
void* buffer = one_tensor_allocator.get();
1004996

1005-
aclTensor* mask_tensor =
1006-
aclnn_values(ctx, one_tensor_allocator.get(), one_tensor_n_bytes,
1007-
src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
1008-
ggml_element_size(src), value);
997+
aclTensor* mask_tensor = ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(src->type),
998+
ggml_type_size(src->type), src->ne, src->nb, GGML_MAX_DIMS);
999+
1000+
aclnn_fill_scalar(ctx, value, mask_tensor);
10091001

10101002
aclScalar* alpha = nullptr;
10111003
float alphaValue = 1.0f;

ggml/src/ggml-cann/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,8 @@ struct ggml_backend_cann_context {
381381
bool support_set_rows;
382382
void* f32_zero_cache = nullptr;
383383
void* f32_one_cache = nullptr;
384-
int64_t f32_cache_element = 1024 * 1024;
384+
int64_t f32_zero_cache_element = 0;
385+
int64_t f32_one_cache_element = 0;
385386

386387
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
387388

0 commit comments

Comments
 (0)