From c5e77de526f450c220aba44454215d206f5049dd Mon Sep 17 00:00:00 2001 From: hipudding Date: Thu, 28 Aug 2025 07:36:59 +0000 Subject: [PATCH 1/5] CANN: fix RoPE cache issue on multi-device RoPE cache only needs to be computed once per token. However, in multi-device scenarios, not every device starts computation from layer 0, which may lead to unallocated memory issues and precision errors. This commit records the first layer of each device to avoid the above issues. --- ggml/src/ggml-cann/aclnn_ops.cpp | 120 +++++++++++++++++-------------- ggml/src/ggml-cann/common.h | 58 +++++++++------ 2 files changed, 100 insertions(+), 78 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index c42871c575822..51f2aec80e71c 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -964,8 +964,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { } aclTensor* acl_gamma = get_f32_cache_acl_tensor( ctx, - &ctx.f32_one_cache, - ctx.f32_one_cache_element, + &ctx.rms_norm_one_tensor_cache.cache, + ctx.rms_norm_one_tensor_cache.size, src->ne, acl_gamma_nb, 1, // dims @@ -980,8 +980,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { } aclTensor* acl_rstd = get_f32_cache_acl_tensor( ctx, - &ctx.f32_zero_cache, - ctx.f32_zero_cache_element, + &ctx.rms_norm_zero_tensor_cache.cache, + ctx.rms_norm_zero_tensor_cache.size, src->ne, acl_rstd_nb, GGML_MAX_DIMS, @@ -2249,7 +2249,7 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, * 6. Expand sin/cos values by repeat or repeat_interleave depending * on whether @param is_neox is enabled. * 7. Store the computed values into persistent buffers - * (ctx.rope_sin_ptr / ctx.rope_cos_ptr). + * (ctx.rope_cache.sin_cache / ctx.rope_cache.cos_cache). * * @param ctx The CANN backend context, holding memory pool, * stream, and persistent buffers for rope init/cache. @@ -2266,25 +2266,30 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, float attn_factor, bool is_neox) { // int sin/cos cache, cache has different repeat method depond on // @param.is_neox - bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0); - bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0); - - // used for accuracy testing - bool is_attention = is_q || is_k; - - // just compute in first layer in attention - bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0); - if(is_attention && !is_fisrt_layer) { - return; - } ggml_tensor* src0 = dst->src[0]; // input ggml_tensor* src1 = dst->src[1]; // position ggml_tensor* src2 = dst->src[2]; // freq_factors - GGML_TENSOR_BINARY_OP_LOCALS + // get first layer in current device. + int layer = 0; + const char* dash = std::strchr(dst->name, '-'); + if (dash) { + layer = std::strtol(dash + 1, nullptr, 10); + } + + // remember the first layer. + if(ctx.rope_cache.first_layer == -1) + ctx.rope_cache.first_layer = layer; - int64_t theta_scale_length = ne00 / 2; + // only init cache when freq_factors is not null or first layer. + // dash == nullptr means we are in test-backend-ops + if(dash != nullptr && src2 == nullptr && layer != ctx.rope_cache.first_layer) { + // use cache. + return; + } + + int64_t theta_scale_length = src0->ne[0] / 2; int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1}; size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t), theta_scale_length * sizeof(float_t)}; @@ -2302,21 +2307,24 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1]; } + // theta_scale arange, [0,1,...,ne00/2 - 1] + aclTensor* acl_theta_scale_tensor = nullptr; // init theta scale, just one time - if(ctx.rope_init_ptr == nullptr || !is_attention) { - // theta_scale arange, [0,1,...,ne00/2 - 1] - if(ctx.rope_init_ptr != nullptr){ - ACL_CHECK(aclrtFree(ctx.rope_init_ptr)); + // dash == nullptr means we are in test-backend-ops + if (ctx.rope_cache.theta_scale_cache == nullptr || dash == nullptr) { + if (ctx.rope_cache.theta_scale_cache != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache)); } - ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); + + acl_theta_scale_tensor = + ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); - aclTensor* acl_theta_scale_tensor = - ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); float start = 0; float step = 1; - float stop = ne00 / 2; - float n_elements = ne00 / 2; + float stop = src0->ne[0] / 2; + float n_elements = src0->ne[0] / 2; aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements); // power @@ -2328,35 +2336,37 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, if (freq_scale != 1) { aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true); } + ggml_cann_release_resources(ctx, acl_theta_scale); + } else { + // use cache + acl_theta_scale_tensor = + ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + } - // freq_factors - if (src2) { - aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor( - src2->data, ggml_cann_type_mapping(src2->type), - ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); - aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor); - ggml_cann_release_resources(ctx, acl_freq_factors_tensor); - } - // release - ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale); + // freq_factors + if (src2) { + aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor( + src2->data, ggml_cann_type_mapping(src2->type), + ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor); + ggml_cann_release_resources(ctx, acl_freq_factors_tensor); } // init sin_repeat && cos_repeat, one token just init in 0 layer - if(position_length > ctx.max_prompt_length) { - ctx.max_prompt_length = position_length; - int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2; - if(ctx.rope_sin_ptr != nullptr) { - ACL_CHECK(aclrtFree(ctx.rope_sin_ptr)); - ACL_CHECK(aclrtFree(ctx.rope_cos_ptr)); + if (position_length > ctx.rope_cache.position_length) { + ctx.rope_cache.position_length = position_length; + if (ctx.rope_cache.sin_cache != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache)); + } + if (ctx.rope_cache.cos_cache != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache)); } - ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); + int64_t repeat_theta_length = theta_scale_length * position_length * 2; + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); } - aclTensor* acl_theta_scale_tensor = - ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); - // position aclTensor* acl_position_tensor = ggml_cann_create_tensor( src1->data, ggml_cann_type_mapping(src1->type), @@ -2397,17 +2407,17 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true); } - int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; + int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1}; size_t sin_reshape_nb[GGML_MAX_DIMS]; sin_reshape_nb[0] = sizeof(float_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } aclTensor* acl_sin_repeat_tensor = - ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_cos_repeat_tensor = - ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); // repeat @@ -2491,10 +2501,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } aclTensor* acl_sin_reshape_tensor = - ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_cos_reshape_tensor = - ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_src = ggml_cann_create_tensor(src0); diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 33794062f565d..d6bcde2fe73b3 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -360,6 +360,37 @@ struct ggml_cann_graph { }; #endif // USE_ACL_GRAPH +struct ggml_cann_rope_cache { + ~ggml_cann_rope_cache() { + if(theta_scale_cache != nullptr) { + ACL_CHECK(aclrtFree(theta_scale_cache)); + } + if(sin_cache != nullptr) { + ACL_CHECK(aclrtFree(sin_cache)); + } + if(cos_cache != nullptr) { + ACL_CHECK(aclrtFree(cos_cache)); + } + } + + void* theta_scale_cache = nullptr; + void* sin_cache = nullptr; + void* cos_cache = nullptr; + int first_layer = -1; + int64_t position_length = 0; +}; + +struct ggml_cann_tensor_cache { + ~ggml_cann_tensor_cache() { + if(cache != nullptr) { + ACL_CHECK(aclrtFree(cache)); + } + } + + void* cache = nullptr; + int64_t size = 0; +}; + /** * @brief Context for managing CANN backend operations. */ @@ -376,15 +407,11 @@ struct ggml_backend_cann_context { bool async_mode; bool support_set_rows; // Rope Cache - void* rope_init_ptr = nullptr; - void* rope_sin_ptr = nullptr; - void* rope_cos_ptr = nullptr; - int64_t max_prompt_length = 0; + ggml_cann_rope_cache rope_cache; // Constant Pool - void* f32_zero_cache = nullptr; - void* f32_one_cache = nullptr; - int64_t f32_zero_cache_element = 0; - int64_t f32_one_cache_element = 0; + ggml_cann_tensor_cache rms_norm_one_tensor_cache; + ggml_cann_tensor_cache rms_norm_zero_tensor_cache; + aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */ @@ -424,21 +451,6 @@ struct ggml_backend_cann_context { ACL_CHECK(aclrtDestroyStream(streams[i])); } } - if(rope_init_ptr != nullptr) { - ACL_CHECK(aclrtFree(rope_init_ptr)); - } - if(rope_sin_ptr != nullptr) { - ACL_CHECK(aclrtFree(rope_sin_ptr)); - } - if(rope_cos_ptr != nullptr) { - ACL_CHECK(aclrtFree(rope_cos_ptr)); - } - if(f32_zero_cache != nullptr) { - ACL_CHECK(aclrtFree(f32_zero_cache)); - } - if(f32_one_cache != nullptr) { - ACL_CHECK(aclrtFree(f32_one_cache)); - } } /** From 24d43d305802ecf656721b4d7c152ceac7edc8fc Mon Sep 17 00:00:00 2001 From: hipudding Date: Fri, 29 Aug 2025 02:16:13 +0000 Subject: [PATCH 2/5] CANN: Optimize first-layer detection method --- ggml/src/ggml-cann/aclnn_ops.cpp | 44 +++++++++++++++----------------- ggml/src/ggml-cann/common.h | 5 +++- ggml/src/ggml-cann/ggml-cann.cpp | 4 +++ 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 51f2aec80e71c..a5c4e08ee1375 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2271,24 +2271,14 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1 = dst->src[1]; // position ggml_tensor* src2 = dst->src[2]; // freq_factors - // get first layer in current device. - int layer = 0; - const char* dash = std::strchr(dst->name, '-'); - if (dash) { - layer = std::strtol(dash + 1, nullptr, 10); - } - - // remember the first layer. - if(ctx.rope_cache.first_layer == -1) - ctx.rope_cache.first_layer = layer; - - // only init cache when freq_factors is not null or first layer. - // dash == nullptr means we are in test-backend-ops - if(dash != nullptr && src2 == nullptr && layer != ctx.rope_cache.first_layer) { + if(src2 == nullptr && ctx.rope_cache.cached) { // use cache. return; } + // Other layers use cache except first layer. + ctx.rope_cache.cached = true; + int64_t theta_scale_length = src0->ne[0] / 2; int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1}; size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t), @@ -2309,22 +2299,30 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, // theta_scale arange, [0,1,...,ne00/2 - 1] aclTensor* acl_theta_scale_tensor = nullptr; - // init theta scale, just one time - // dash == nullptr means we are in test-backend-ops - if (ctx.rope_cache.theta_scale_cache == nullptr || dash == nullptr) { + // cache theta scale + if (src2 != nullptr || ctx.rope_cache.theta_scale_length != theta_scale_length || + // theta_scale and freq_scale should not change during the current token inference process, + // so we can directly use == here instead of comparing the absolute difference. + ctx.rope_cache.theta_scale != theta_scale || + ctx.rope_cache.freq_scale != freq_scale) { + + ctx.rope_cache.theta_scale_length = theta_scale_length; + ctx.rope_cache.theta_scale = theta_scale; + ctx.rope_cache.freq_scale = freq_scale; + if (ctx.rope_cache.theta_scale_cache != nullptr) { ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache)); } ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); acl_theta_scale_tensor = - ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); float start = 0; float step = 1; - float stop = src0->ne[0] / 2; - float n_elements = src0->ne[0] / 2; + float stop = theta_scale_length; + float n_elements = theta_scale_length; aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements); // power @@ -2340,8 +2338,8 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, } else { // use cache acl_theta_scale_tensor = - ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); } // freq_factors diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index d6bcde2fe73b3..2abb377bc6bc2 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -376,8 +376,11 @@ struct ggml_cann_rope_cache { void* theta_scale_cache = nullptr; void* sin_cache = nullptr; void* cos_cache = nullptr; - int first_layer = -1; + bool cached = false; int64_t position_length = 0; + int64_t theta_scale_length = 0; + float theta_scale = 0.0f; + float freq_scale = 0.0f; }; struct ggml_cann_tensor_cache { diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 81215425618a3..b79d8f98d6814 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2247,6 +2247,10 @@ static enum ggml_status ggml_backend_cann_graph_compute( (ggml_backend_cann_context*)backend->context; ggml_cann_set_device(cann_ctx->device); release_nz_workspace(); + + // calculate rope cache for fist layer in current device. + cann_ctx->rope_cache.cached = false; + #ifdef USE_ACL_GRAPH bool use_cann_graph = true; bool cann_graph_update_required = false; From 602563c5b62c0f20bb7d8199910a60e679fef1a2 Mon Sep 17 00:00:00 2001 From: hipudding Date: Fri, 29 Aug 2025 08:32:30 +0000 Subject: [PATCH 3/5] CANN: Remove trailing whitespace --- ggml/src/ggml-cann/aclnn_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index a5c4e08ee1375..4e2c68de49138 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2301,7 +2301,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, aclTensor* acl_theta_scale_tensor = nullptr; // cache theta scale if (src2 != nullptr || ctx.rope_cache.theta_scale_length != theta_scale_length || - // theta_scale and freq_scale should not change during the current token inference process, + // theta_scale and freq_scale should not change during the current token inference process, // so we can directly use == here instead of comparing the absolute difference. ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.freq_scale != freq_scale) { From 1321c2c4e762824dc169efe4bf1f338294ac2556 Mon Sep 17 00:00:00 2001 From: hipudding Date: Sat, 30 Aug 2025 02:12:24 +0000 Subject: [PATCH 4/5] CANN: Only cache the data that can be determined as unchanged through the parameters. --- ggml/src/ggml-cann/aclnn_ops.cpp | 55 +++++++++++++++----------------- ggml/src/ggml-cann/common.h | 10 ------ ggml/src/ggml-cann/ggml-cann.cpp | 3 -- 3 files changed, 25 insertions(+), 43 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 4e2c68de49138..49be815d2df37 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2262,6 +2262,7 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, * (dim expansion vs repeat_interleave). */ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, + void* sin_tensor_buffer, void* cos_tensor_buffer, float theta_scale, float freq_scale, float attn_factor, bool is_neox) { // int sin/cos cache, cache has different repeat method depond on @@ -2271,14 +2272,6 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1 = dst->src[1]; // position ggml_tensor* src2 = dst->src[2]; // freq_factors - if(src2 == nullptr && ctx.rope_cache.cached) { - // use cache. - return; - } - - // Other layers use cache except first layer. - ctx.rope_cache.cached = true; - int64_t theta_scale_length = src0->ne[0] / 2; int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1}; size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t), @@ -2300,7 +2293,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, // theta_scale arange, [0,1,...,ne00/2 - 1] aclTensor* acl_theta_scale_tensor = nullptr; // cache theta scale - if (src2 != nullptr || ctx.rope_cache.theta_scale_length != theta_scale_length || + if (ctx.rope_cache.theta_scale_length != theta_scale_length || // theta_scale and freq_scale should not change during the current token inference process, // so we can directly use == here instead of comparing the absolute difference. ctx.rope_cache.theta_scale != theta_scale || @@ -2342,27 +2335,20 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); } + ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool()); // freq_factors if (src2) { + freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float_t)); + void* freq_fac_res_ptr = freq_fac_res_allocator.get(); aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor( src2->data, ggml_cann_type_mapping(src2->type), ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); - aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor); - ggml_cann_release_resources(ctx, acl_freq_factors_tensor); - } - - // init sin_repeat && cos_repeat, one token just init in 0 layer - if (position_length > ctx.rope_cache.position_length) { - ctx.rope_cache.position_length = position_length; - if (ctx.rope_cache.sin_cache != nullptr) { - ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache)); - } - if (ctx.rope_cache.cos_cache != nullptr) { - ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache)); - } - int64_t repeat_theta_length = theta_scale_length * position_length * 2; - ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); + aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor( + freq_fac_res_ptr, ACL_FLOAT, sizeof(float_t), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor); + std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor); + ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor); } // position @@ -2412,10 +2398,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } aclTensor* acl_sin_repeat_tensor = - ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_cos_repeat_tensor = - ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); // repeat @@ -2457,6 +2443,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // TODO: use ascendc // Only test with LLAMA model. ggml_tensor* src0 = dst->src[0]; // input + ggml_tensor* src1 = dst->src[1]; // param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; @@ -2489,8 +2476,16 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + // sin/cos tensor length. + int64_t repeat_theta_length = src0->ne[0] * src1->ne[0]; + ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float)); + ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float)); + void *sin_tensor_buffer = sin_tensor_allocator.get(); + void *cos_tensor_buffer = cos_tensor_allocator.get(); + // init ctx.rope_cos/rope_sin cache - aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox); + aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, + theta_scale, freq_scale, attn_factor, is_neox); int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; size_t sin_reshape_nb[GGML_MAX_DIMS]; @@ -2499,10 +2494,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } aclTensor* acl_sin_reshape_tensor = - ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_cos_reshape_tensor = - ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_src = ggml_cann_create_tensor(src0); diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 2abb377bc6bc2..6e0745d6202f8 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -365,19 +365,9 @@ struct ggml_cann_rope_cache { if(theta_scale_cache != nullptr) { ACL_CHECK(aclrtFree(theta_scale_cache)); } - if(sin_cache != nullptr) { - ACL_CHECK(aclrtFree(sin_cache)); - } - if(cos_cache != nullptr) { - ACL_CHECK(aclrtFree(cos_cache)); - } } void* theta_scale_cache = nullptr; - void* sin_cache = nullptr; - void* cos_cache = nullptr; - bool cached = false; - int64_t position_length = 0; int64_t theta_scale_length = 0; float theta_scale = 0.0f; float freq_scale = 0.0f; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index b79d8f98d6814..cf450e7724ab7 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2248,9 +2248,6 @@ static enum ggml_status ggml_backend_cann_graph_compute( ggml_cann_set_device(cann_ctx->device); release_nz_workspace(); - // calculate rope cache for fist layer in current device. - cann_ctx->rope_cache.cached = false; - #ifdef USE_ACL_GRAPH bool use_cann_graph = true; bool cann_graph_update_required = false; From ce56f8091536b51864a23c97f770e396a5523f9b Mon Sep 17 00:00:00 2001 From: hipudding Date: Sat, 30 Aug 2025 02:27:32 +0000 Subject: [PATCH 5/5] CANN: Update function comment --- ggml/src/ggml-cann/aclnn_ops.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 49be815d2df37..1f1d489ffa411 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2248,18 +2248,18 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor. * 6. Expand sin/cos values by repeat or repeat_interleave depending * on whether @param is_neox is enabled. - * 7. Store the computed values into persistent buffers - * (ctx.rope_cache.sin_cache / ctx.rope_cache.cos_cache). - * - * @param ctx The CANN backend context, holding memory pool, - * stream, and persistent buffers for rope init/cache. - * @param dst The destination ggml_tensor whose computation - * depends on the cached RoPE values (usually Qcur/Kcur). - * @param theta_scale Scalar exponent base for computing theta scale values. - * @param freq_scale Frequency scaling factor, applied to theta scale. - * @param attn_factor Attention scaling factor, applied to sin/cos. - * @param is_neox Whether to use Neox-style repeat strategy - * (dim expansion vs repeat_interleave). + * + * @param ctx The CANN backend context, holding memory pool, + * stream, and persistent buffers for rope init/cache. + * @param dst The destination ggml_tensor whose computation + * depends on the RoPE values (usually Qcur/Kcur). + * @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values. + * @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values. + * @param theta_scale Scalar exponent base for computing theta scale values. + * @param freq_scale Frequency scaling factor, applied to theta scale. + * @param attn_factor Attention scaling factor, applied to sin/cos. + * @param is_neox Whether to use Neox-style repeat strategy + * (dim expansion vs repeat_interleave). */ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, void* sin_tensor_buffer, void* cos_tensor_buffer,