
Commit 3c87db4 (1 parent: c24b995)

codestyle adjustment

Signed-off-by: noemotiovon <[email protected]>

File tree

1 file changed (+69, -58)

ggml/src/ggml-cann/aclnn_ops.cpp: 69 additions & 58 deletions
@@ -885,71 +885,66 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
 }
 
 /**
- * @brief Get or expand cached float32 tensors filled with scalar values.
- *
- * This function manages a cache of float32 tensors (zero-filled and one-filled).
- * If the cache does not exist, it will initialize the cache with a zero tensor
- * and a one tensor. If the requested tensor size exceeds the current cache
- * capacity, the cache will be expanded accordingly. The function then returns
- * an aclTensor created from the cached memory (either zero-filled or one-filled),
- * depending on the input `value`.
- *
- * @param ctx The CANN backend context that manages cache memory.
- * @param ne The tensor shape array (number of elements in each dimension).
- * @param nb The stride size for each dimension.
- * @param dims The number of tensor dimensions.
- * @param value The scalar value (only supports 0 or 1) used to determine whether
- *              to return the zero-cache tensor or the one-cache tensor.
- * @return An aclTensor pointer corresponding to the cached tensor.
+ * @brief Get or expand a cached float32 tensor filled with a scalar value.
+ *
+ * This function manages cached device memory for float32 tensors. If the current
+ * cache size is insufficient for the requested tensor shape, the old memory will
+ * be released and new memory will be allocated. The allocated buffer is then
+ * initialized either with zeros (when @p value == 0.0f) or with the given scalar
+ * value using CANN operations. Finally, an aclTensor object is created from the
+ * cached memory and returned.
+ *
+ * @param ctx The CANN backend context that manages device memory.
+ * @param buffer A pointer to the cached device buffer (will be allocated
+ *               or reallocated if necessary).
+ * @param cache_element The current number of cached elements. This will be
+ *                      updated when the cache is expanded.
+ * @param ne The tensor shape array (number of elements in each dimension).
+ * @param nb The stride size for each dimension.
+ * @param dims The number of tensor dimensions.
+ * @param value The scalar value used to fill the tensor (supports zero
+ *              initialization via memset or arbitrary values via fill_scalar).
+ * @return An aclTensor pointer created from the cached buffer.
  */
-static aclTensor* get_f32_zero_or_one_cache_acl_tensor(ggml_backend_cann_context& ctx,
-                                                       int64_t* ne, size_t* nb,
-                                                       int64_t dims, int64_t value) {
-    // just support one and zero cache
-    GGML_ASSERT(value == 1 || value == 0);
-
-    // Cache init or expansion
+static aclTensor* get_f32_cache_acl_tensor(
+        ggml_backend_cann_context& ctx,
+        void** buffer,
+        int64_t &cache_element,
+        int64_t* ne,
+        size_t* nb,
+        int64_t dims,
+        float value) {
+    // Calculate total number of elements
     int64_t n_element = 1;
-    for(int i = 0; i < dims; i++) {
-        n_element = n_element * ne[i];
+    for (int i = 0; i < dims; i++) {
+        n_element *= ne[i];
     }
     size_t size = n_element * sizeof(float);
-    void* cache;
-    if (value == 0) {
-        if(ctx.f32_zero_cache_element < n_element){
-            //free old mem
-            if(ctx.f32_zero_cache != nullptr){
-                aclrtFree(ctx.f32_zero_cache);
-            }
-
-            //init zero cache
-            ctx.f32_zero_cache_element = n_element;
-            ACL_CHECK(aclrtMalloc(&ctx.f32_zero_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
-            ACL_CHECK(aclrtMemsetAsync(ctx.f32_zero_cache, size, 0, size, ctx.stream()));
+
+    // Allocate or expand cache if needed
+    if (cache_element < n_element) {
+        if (*buffer != nullptr) {
+            aclrtFree(*buffer);
+            *buffer = nullptr;
         }
-        cache = ctx.f32_zero_cache;
-    } else {
-        if(ctx.f32_one_cache_element < n_element){
-            //free old mem
-            if(ctx.f32_one_cache != nullptr){
-                aclrtFree(ctx.f32_one_cache);
-            }
-
-            //init one cache
-            ctx.f32_one_cache_element = n_element;
+
+        ACL_CHECK(aclrtMalloc(buffer, size, ACL_MEM_MALLOC_HUGE_FIRST));
+        cache_element = n_element;
+
+        // Initialize cache
+        if (value == 0.0f) {
+            ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
+        } else {
             int64_t pool_ne[1] = { n_element };
             size_t pool_nb[1] = { sizeof(float) };
-            ACL_CHECK(aclrtMalloc(&ctx.f32_one_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
-            aclTensor* acl_one = ggml_cann_create_tensor(
-                ctx.f32_one_cache, ACL_FLOAT, sizeof(float), pool_ne, pool_nb,
-                1);
-            aclnn_fill_scalar(ctx, 1, acl_one);
-            ggml_cann_release_resources(ctx, acl_one);
+            aclTensor* acl_value = ggml_cann_create_tensor(
+                *buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
+            aclnn_fill_scalar(ctx, 1, acl_value);
+            ggml_cann_release_resources(ctx, acl_value);
         }
-        cache = ctx.f32_one_cache;
     }
 
-    return ggml_cann_create_tensor(cache, ACL_FLOAT, sizeof(float), ne, nb, dims);
+    return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
 }
 
 void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
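
The refactor above collapses the two near-identical zero/one branches into one generic cache routine: the caller passes in the cache buffer and its element count, and the function only reallocates and refills when the requested size outgrows the cache. Below is a minimal host-side sketch of the same grow-only pattern, using std::malloc/std::free and std::fill in place of aclrtMalloc, aclrtMemsetAsync, and aclnn_fill_scalar; the F32Cache and get_f32_cache names are hypothetical, for illustration only.

#include <algorithm>
#include <cstdint>
#include <cstdlib>

// Hypothetical host-side stand-in for the per-context cache fields
// (ctx.f32_zero_cache / ctx.f32_one_cache and their *_element counters).
struct F32Cache {
    void*   buffer  = nullptr;  // cached allocation, reused across calls
    int64_t element = 0;        // capacity in floats; grows, never shrinks
};

// Return a buffer holding at least n_element floats equal to `value`.
// Reallocation and refilling happen only when the cache must grow.
static float* get_f32_cache(F32Cache& cache, int64_t n_element, float value) {
    if (cache.element < n_element) {
        std::free(cache.buffer);  // free(nullptr) is a no-op
        cache.buffer  = std::malloc(n_element * sizeof(float));
        cache.element = n_element;
        float* data = static_cast<float*>(cache.buffer);
        std::fill(data, data + n_element, value);  // memset / fill_scalar in the original
    }
    return static_cast<float*>(cache.buffer);
}

Note that the buffer is refilled only when it grows, so a given cache must always be requested with the same fill value; this is why the context keeps separate f32_zero_cache and f32_one_cache buffers rather than a single shared one.
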
@@ -967,15 +962,31 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
     }
-    aclTensor* acl_gamma = get_f32_zero_or_one_cache_acl_tensor(ctx, src->ne, acl_gamma_nb, 1, 1);
+    aclTensor* acl_gamma = get_f32_cache_acl_tensor(
+        ctx,
+        &ctx.f32_one_cache,
+        ctx.f32_one_cache_element,
+        src->ne,
+        acl_gamma_nb,
+        1,     // dims
+        1.0f   // value
+    );
 
     // build rstd, zero...
     size_t acl_rstd_nb[GGML_MAX_DIMS];
     acl_rstd_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         acl_rstd_nb[i] = acl_rstd_nb[i - 1] * src->ne[i - 1];
     }
-    aclTensor* acl_rstd = get_f32_zero_or_one_cache_acl_tensor(ctx, src->ne, acl_rstd_nb, GGML_MAX_DIMS, 0);
+    aclTensor* acl_rstd = get_f32_cache_acl_tensor(
+        ctx,
+        &ctx.f32_zero_cache,
+        ctx.f32_zero_cache_element,
+        src->ne,
+        acl_rstd_nb,
+        GGML_MAX_DIMS,
+        0.0f   // value
+    );
 
     GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
     ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
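
Both call sites build contiguous float strides the same way before requesting the cached tensor: nb[0] is sizeof(float), and each outer stride is the previous stride times the previous dimension. A standalone sketch of that computation follows; kMaxDims and contiguous_f32_strides are illustrative names only, with kMaxDims standing in for GGML_MAX_DIMS (which ggml defines as 4).

#include <cstddef>
#include <cstdint>

constexpr int kMaxDims = 4;  // stand-in for GGML_MAX_DIMS

// Contiguous strides for a float tensor: nb[i] = nb[i-1] * ne[i-1].
void contiguous_f32_strides(const int64_t* ne, size_t* nb) {
    nb[0] = sizeof(float);
    for (int i = 1; i < kMaxDims; i++) {
        nb[i] = nb[i - 1] * static_cast<size_t>(ne[i - 1]);
    }
}

The two calls then differ only in which cache they draw from and how many dimensions they expose: acl_gamma views the one-filled cache as a 1-D tensor over ne[0], while acl_rstd views the zero-filled cache over the full GGML_MAX_DIMS shape.
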
@@ -996,7 +1007,7 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
 
     aclTensor* mask_tensor = ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(src->type),
                                 ggml_type_size(src->type), src->ne, src->nb, GGML_MAX_DIMS);
-
+
     aclnn_fill_scalar(ctx, value, mask_tensor);
 
     aclScalar* alpha = nullptr;
