Commit 3dc7397

CANN: fix RoPE cache issue on multi-device (ggml-org#15629)
* CANN: fix RoPE cache issue on multi-device

  RoPE cache only needs to be computed once per token. However, in
  multi-device scenarios, not every device starts computation from layer 0,
  which may lead to unallocated memory issues and precision errors. This
  commit records the first layer of each device to avoid the above issues.

* CANN: Optimize first-layer detection method

* CANN: Remove trailing whitespace

* CANN: Only cache the data that can be determined as unchanged through the parameters.

* CANN: Update function comment
1 parent e92d53b commit 3dc7397
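
The final approach keys the theta-scale cache on the parameters that determine its contents rather than on which layer a device happens to start from, so each device recomputes only when something the cached data depends on actually changes. A minimal sketch of that invalidation rule, condensed from the aclnn_cache_init change in this diff (the standalone rope_cache_state / needs_recompute names here are illustrative only, not part of the commit):

// Illustrative sketch only: condensed from the cache check added in
// aclnn_cache_init(); struct and function names are hypothetical.
struct rope_cache_state {
    void*   theta_scale_cache  = nullptr;  // device buffer of theta-scale values
    int64_t theta_scale_length = 0;
    float   theta_scale        = 0.0f;
    float   freq_scale         = 0.0f;
};

// Recompute only when a parameter that determines the cached data changes.
// Exact equality is sufficient: these values are fixed within a token's inference.
static bool needs_recompute(const rope_cache_state& c,
                            int64_t length, float theta_scale, float freq_scale) {
    return c.theta_scale_length != length ||
           c.theta_scale != theta_scale ||
           c.freq_scale  != freq_scale;
}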

File tree

3 files changed: +103 additions, -94 deletions

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 74 additions & 71 deletions
@@ -964,8 +964,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_gamma = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_one_cache,
-        ctx.f32_one_cache_element,
+        &ctx.rms_norm_one_tensor_cache.cache,
+        ctx.rms_norm_one_tensor_cache.size,
         src->ne,
         acl_gamma_nb,
         1, // dims
@@ -980,8 +980,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_rstd = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_zero_cache,
-        ctx.f32_zero_cache_element,
+        &ctx.rms_norm_zero_tensor_cache.cache,
+        ctx.rms_norm_zero_tensor_cache.size,
         src->ne,
         acl_rstd_nb,
         GGML_MAX_DIMS,
@@ -2248,43 +2248,31 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
  * 6. Expand sin/cos values by repeat or repeat_interleave depending
  *    on whether @param is_neox is enabled.
- * 7. Store the computed values into persistent buffers
- *    (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
- *
- * @param ctx         The CANN backend context, holding memory pool,
- *                    stream, and persistent buffers for rope init/cache.
- * @param dst         The destination ggml_tensor whose computation
- *                    depends on the cached RoPE values (usually Qcur/Kcur).
- * @param theta_scale Scalar exponent base for computing theta scale values.
- * @param freq_scale  Frequency scaling factor, applied to theta scale.
- * @param attn_factor Attention scaling factor, applied to sin/cos.
- * @param is_neox     Whether to use Neox-style repeat strategy
- *                    (dim expansion vs repeat_interleave).
+ *
+ * @param ctx               The CANN backend context, holding memory pool,
+ *                          stream, and persistent buffers for rope init/cache.
+ * @param dst               The destination ggml_tensor whose computation
+ *                          depends on the RoPE values (usually Qcur/Kcur).
+ * @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values.
+ * @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values.
+ * @param theta_scale       Scalar exponent base for computing theta scale values.
+ * @param freq_scale        Frequency scaling factor, applied to theta scale.
+ * @param attn_factor       Attention scaling factor, applied to sin/cos.
+ * @param is_neox           Whether to use Neox-style repeat strategy
+ *                          (dim expansion vs repeat_interleave).
  */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+                             void* sin_tensor_buffer, void* cos_tensor_buffer,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
     // int sin/cos cache, cache has different repeat method depond on
     // @param.is_neox
-    bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
-    bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
-
-    // used for accuracy testing
-    bool is_attention = is_q || is_k;
-
-    // just compute in first layer in attention
-    bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
-    if(is_attention && !is_fisrt_layer) {
-        return;
-    }
 
     ggml_tensor* src0 = dst->src[0]; // input
     ggml_tensor* src1 = dst->src[1]; // position
     ggml_tensor* src2 = dst->src[2]; // freq_factors
 
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    int64_t theta_scale_length = ne00 / 2;
+    int64_t theta_scale_length = src0->ne[0] / 2;
     int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
     size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
                                theta_scale_length * sizeof(float_t)};
@@ -2302,21 +2290,32 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
     }
 
-    // init theta scale, just one time
-    if(ctx.rope_init_ptr == nullptr || !is_attention) {
-        // theta_scale arange, [0,1,...,ne00/2 - 1]
-        if(ctx.rope_init_ptr != nullptr){
-            ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
+    // theta_scale arange, [0,1,...,ne00/2 - 1]
+    aclTensor* acl_theta_scale_tensor = nullptr;
+    // cache theta scale
+    if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
+        // theta_scale and freq_scale should not change during the current token inference process,
+        // so we can directly use == here instead of comparing the absolute difference.
+        ctx.rope_cache.theta_scale != theta_scale ||
+        ctx.rope_cache.freq_scale != freq_scale) {
+
+        ctx.rope_cache.theta_scale_length = theta_scale_length;
+        ctx.rope_cache.theta_scale = theta_scale;
+        ctx.rope_cache.freq_scale = freq_scale;
+
+        if (ctx.rope_cache.theta_scale_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
         }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
 
-        aclTensor* acl_theta_scale_tensor =
-            ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
                                     theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+
         float start = 0;
         float step = 1;
-        float stop = ne00 / 2;
-        float n_elements = ne00 / 2;
+        float stop = theta_scale_length;
+        float n_elements = theta_scale_length;
         aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
 
         // power
@@ -2328,35 +2327,30 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         if (freq_scale != 1) {
             aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
         }
-
-        // freq_factors
-        if (src2) {
-            aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
-                src2->data, ggml_cann_type_mapping(src2->type),
-                ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-            aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
-            ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
-        }
-        // release
-        ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
+        ggml_cann_release_resources(ctx, acl_theta_scale);
+    } else {
+        // use cache
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
+                                    theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
     }
 
-    // init sin_repeat && cos_repeat, one token just init in 0 layer
-    if(position_length > ctx.max_prompt_length) {
-        ctx.max_prompt_length = position_length;
-        int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
-        if(ctx.rope_sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
-            ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
-        }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+    ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
+    // freq_factors
+    if (src2) {
+        freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float_t));
+        void* freq_fac_res_ptr = freq_fac_res_allocator.get();
+        aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
+            src2->data, ggml_cann_type_mapping(src2->type),
+            ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor(
+            freq_fac_res_ptr, ACL_FLOAT, sizeof(float_t),
+            theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
+        std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
+        ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
     }
 
-    aclTensor* acl_theta_scale_tensor =
-        ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
-                                theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-
     // position
     aclTensor* acl_position_tensor = ggml_cann_create_tensor(
         src1->data, ggml_cann_type_mapping(src1->type),
@@ -2397,17 +2391,17 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
     }
 
-    int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
+    int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float_t);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     // repeat
@@ -2449,6 +2443,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // TODO: use ascendc
     // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0]; // input
+    ggml_tensor* src1 = dst->src[1];
 
     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2481,8 +2476,16 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
+    // sin/cos tensor length.
+    int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
+    ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
+    ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
+    void *sin_tensor_buffer = sin_tensor_allocator.get();
+    void *cos_tensor_buffer = cos_tensor_allocator.get();
+
     // init ctx.rope_cos/rope_sin cache
-    aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
+    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer,
+                     theta_scale, freq_scale, attn_factor, is_neox);
 
     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
@@ -2491,10 +2494,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_reshape_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_reshape_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
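
With the persistent ctx.rope_sin_ptr / ctx.rope_cos_ptr buffers gone, each ggml_cann_rope call now draws its sin/cos scratch space from the context's memory pool, sized from the current tensor shapes, so a device that starts at a layer other than 0 can never touch an unallocated or stale buffer. A rough sketch of that per-call allocation pattern, assuming only the ggml_cann_pool_alloc interface shown in the hunks above (the wrapper function name is illustrative):

// Sketch of the per-call scratch allocation used above; the buffer size
// follows the hunk: one float per (head_dim x n_positions) element.
static void rope_scratch_example(ggml_backend_cann_context& ctx,
                                 ggml_tensor* src0, ggml_tensor* src1) {
    int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];

    // RAII: the allocations are returned to the pool when the allocators
    // go out of scope at the end of the operator.
    ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
    ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));

    void* sin_tensor_buffer = sin_tensor_allocator.get();
    void* cos_tensor_buffer = cos_tensor_allocator.get();

    // ... sin/cos values are written into these buffers by aclnn_cache_init()
    //     and consumed by the rest of ggml_cann_rope() in the same call.
    (void)sin_tensor_buffer;
    (void)cos_tensor_buffer;
}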

ggml/src/ggml-cann/common.h

Lines changed: 28 additions & 23 deletions
@@ -360,6 +360,30 @@ struct ggml_cann_graph {
 };
 #endif // USE_ACL_GRAPH
 
+struct ggml_cann_rope_cache {
+    ~ggml_cann_rope_cache() {
+        if(theta_scale_cache != nullptr) {
+            ACL_CHECK(aclrtFree(theta_scale_cache));
+        }
+    }
+
+    void* theta_scale_cache = nullptr;
+    int64_t theta_scale_length = 0;
+    float theta_scale = 0.0f;
+    float freq_scale = 0.0f;
+};
+
+struct ggml_cann_tensor_cache {
+    ~ggml_cann_tensor_cache() {
+        if(cache != nullptr) {
+            ACL_CHECK(aclrtFree(cache));
+        }
+    }
+
+    void* cache = nullptr;
+    int64_t size = 0;
+};
+
 /**
  * @brief Context for managing CANN backend operations.
  */
@@ -375,15 +399,11 @@ struct ggml_backend_cann_context {
     cann_task_queue task_queue;
     bool async_mode;
     // Rope Cache
-    void* rope_init_ptr = nullptr;
-    void* rope_sin_ptr = nullptr;
-    void* rope_cos_ptr = nullptr;
-    int64_t max_prompt_length = 0;
+    ggml_cann_rope_cache rope_cache;
     // Constant Pool
-    void* f32_zero_cache = nullptr;
-    void* f32_one_cache = nullptr;
-    int64_t f32_zero_cache_element = 0;
-    int64_t f32_one_cache_element = 0;
+    ggml_cann_tensor_cache rms_norm_one_tensor_cache;
+    ggml_cann_tensor_cache rms_norm_zero_tensor_cache;
+
 
     aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
 
@@ -415,21 +435,6 @@ struct ggml_backend_cann_context {
                 ACL_CHECK(aclrtDestroyStream(streams[i]));
             }
         }
-        if(rope_init_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(rope_init_ptr));
-        }
-        if(rope_sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(rope_sin_ptr));
-        }
-        if(rope_cos_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(rope_cos_ptr));
-        }
-        if(f32_zero_cache != nullptr) {
-            ACL_CHECK(aclrtFree(f32_zero_cache));
-        }
-        if(f32_one_cache != nullptr) {
-            ACL_CHECK(aclrtFree(f32_one_cache));
-        }
     }
 
     /**
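
Because ggml_cann_rope_cache and ggml_cann_tensor_cache free their device buffers in their own destructors, the context destructor no longer needs the per-pointer aclrtFree calls removed above; the members clean themselves up when the context is destroyed. A minimal standalone illustration of the same ownership pattern, with hypothetical names and plain malloc/free standing in for the ACL allocation calls:

#include <cstdint>
#include <cstdlib>

// Hypothetical stand-in for the RAII idea used by ggml_cann_tensor_cache:
// the owning struct releases its buffer, so the enclosing context needs no
// explicit cleanup code for it.
struct device_buffer_cache {
    void*   data = nullptr;
    int64_t size = 0;

    ~device_buffer_cache() {
        if (data != nullptr) {
            std::free(data);  // the real code calls ACL_CHECK(aclrtFree(...))
        }
    }
};

struct backend_context_example {
    device_buffer_cache rms_norm_one_cache;
    device_buffer_cache rms_norm_zero_cache;
    // No destructor body needed: member destructors run automatically when
    // the context is destroyed.
};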

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 1 addition & 0 deletions
@@ -2247,6 +2247,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
         (ggml_backend_cann_context*)backend->context;
     ggml_cann_set_device(cann_ctx->device);
     release_nz_workspace();
+
 #ifdef USE_ACL_GRAPH
     bool use_cann_graph = true;
     bool cann_graph_update_required = false;
