Commit c5e77de
CANN: fix RoPE cache issue on multi-device
The RoPE sin/cos cache only needs to be computed once per token. In multi-device scenarios, however, not every device starts computation from layer 0, which can lead to reads from unallocated memory and precision errors. This commit records the first layer computed on each device so the cache is initialized exactly where each device starts.
1 parent 5a0e3ef commit c5e77de
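To make the fix concrete, here is a minimal standalone sketch of the first-layer bookkeeping this commit introduces. The tensor-name parsing and the reuse condition mirror the diff below; `rope_cache_state`, `needs_cache_init`, and the driver in `main` are illustrative stand-ins, not the backend's actual API.

```cpp
// Sketch of the first-layer tracking added in this commit (illustrative
// names; the real logic lives in aclnn_cache_init and ggml_cann_rope_cache).
#include <cstdio>
#include <cstdlib>
#include <cstring>

struct rope_cache_state {
    int first_layer = -1;  // first layer computed on this device, -1 = unset
};

// Returns true when the sin/cos cache must be (re)computed for this tensor.
static bool needs_cache_init(rope_cache_state& cache, const char* name,
                             bool has_freq_factors) {
    // Parse the layer index from tensor names like "Qcur-12" or "Kcur-12".
    const char* dash = std::strchr(name, '-');
    int layer = dash ? (int) std::strtol(dash + 1, nullptr, 10) : 0;

    // Remember the first layer this device ever computes; with a
    // multi-device split this is not necessarily layer 0.
    if (cache.first_layer == -1) {
        cache.first_layer = layer;
    }

    // No dash means test-backend-ops; freq_factors forces a recompute.
    if (dash != nullptr && !has_freq_factors && layer != cache.first_layer) {
        return false;  // later layer on this device: reuse the cached values
    }
    return true;
}

int main() {
    rope_cache_state cache;
    // A device in a pipeline split may start at layer 20 rather than layer 0.
    const char* names[] = {"Qcur-20", "Kcur-20", "Qcur-21", "Kcur-21"};
    for (const char* n : names) {
        std::printf("%s -> %s\n", n,
                    needs_cache_init(cache, n, /*has_freq_factors=*/false)
                        ? "compute" : "reuse");
    }
    return 0;
}
```

Under the old `Qcur-0` check, a device whose first layer is, say, 20 would never initialize its cache and would read unallocated memory; keying the check on the recorded first layer avoids that.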

File tree

2 files changed: +100 −78 lines

ggml/src/ggml-cann/aclnn_ops.cpp
ggml/src/ggml-cann/common.h

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 65 additions & 55 deletions
@@ -964,8 +964,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_gamma = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_one_cache,
-        ctx.f32_one_cache_element,
+        &ctx.rms_norm_one_tensor_cache.cache,
+        ctx.rms_norm_one_tensor_cache.size,
         src->ne,
         acl_gamma_nb,
         1,  // dims
@@ -980,8 +980,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_rstd = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_zero_cache,
-        ctx.f32_zero_cache_element,
+        &ctx.rms_norm_zero_tensor_cache.cache,
+        ctx.rms_norm_zero_tensor_cache.size,
         src->ne,
         acl_rstd_nb,
         GGML_MAX_DIMS,
@@ -2249,7 +2249,7 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  * 6. Expand sin/cos values by repeat or repeat_interleave depending
  *    on whether @param is_neox is enabled.
  * 7. Store the computed values into persistent buffers
- *    (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
+ *    (ctx.rope_cache.sin_cache / ctx.rope_cache.cos_cache).
  *
  * @param ctx    The CANN backend context, holding memory pool,
  *               stream, and persistent buffers for rope init/cache.
@@ -2266,25 +2266,30 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                              float attn_factor, bool is_neox) {
     // int sin/cos cache, cache has different repeat method depond on
     // @param.is_neox
-    bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
-    bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
-
-    // used for accuracy testing
-    bool is_attention = is_q || is_k;
-
-    // just compute in first layer in attention
-    bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
-    if(is_attention && !is_fisrt_layer) {
-        return;
-    }
 
     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src1 = dst->src[1];  // position
     ggml_tensor* src2 = dst->src[2];  // freq_factors
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    // get first layer in current device.
+    int layer = 0;
+    const char* dash = std::strchr(dst->name, '-');
+    if (dash) {
+        layer = std::strtol(dash + 1, nullptr, 10);
+    }
+
+    // remember the first layer.
+    if(ctx.rope_cache.first_layer == -1)
+        ctx.rope_cache.first_layer = layer;
 
-    int64_t theta_scale_length = ne00 / 2;
+    // only init cache when freq_factors is not null or first layer.
+    // dash == nullptr means we are in test-backend-ops
+    if(dash != nullptr && src2 == nullptr && layer != ctx.rope_cache.first_layer) {
+        // use cache.
+        return;
+    }
+
+    int64_t theta_scale_length = src0->ne[0] / 2;
     int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
     size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
                                theta_scale_length * sizeof(float_t)};
@@ -2302,21 +2307,24 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
     }
 
+    // theta_scale arange, [0,1,...,ne00/2 - 1]
+    aclTensor* acl_theta_scale_tensor = nullptr;
     // init theta scale, just one time
-    if(ctx.rope_init_ptr == nullptr || !is_attention) {
-        // theta_scale arange, [0,1,...,ne00/2 - 1]
-        if(ctx.rope_init_ptr != nullptr){
-            ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
+    // dash == nullptr means we are in test-backend-ops
+    if (ctx.rope_cache.theta_scale_cache == nullptr || dash == nullptr) {
+        if (ctx.rope_cache.theta_scale_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
         }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
+                                    theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
 
-        aclTensor* acl_theta_scale_tensor =
-            ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
-                                    theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
         float start = 0;
         float step = 1;
-        float stop = ne00 / 2;
-        float n_elements = ne00 / 2;
+        float stop = src0->ne[0] / 2;
+        float n_elements = src0->ne[0] / 2;
         aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
 
         // power
@@ -2328,35 +2336,37 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         if (freq_scale != 1) {
             aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
         }
+        ggml_cann_release_resources(ctx, acl_theta_scale);
+    } else {
+        // use cache
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
+                                    theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+    }
 
-        // freq_factors
-        if (src2) {
-            aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
-                src2->data, ggml_cann_type_mapping(src2->type),
-                ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-            aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
-            ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
-        }
-        // release
-        ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
+    // freq_factors
+    if (src2) {
+        aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
+            src2->data, ggml_cann_type_mapping(src2->type),
+            ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
+        ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
     }
 
     // init sin_repeat && cos_repeat, one token just init in 0 layer
-    if(position_length > ctx.max_prompt_length) {
-        ctx.max_prompt_length = position_length;
-        int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
-        if(ctx.rope_sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
-            ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
+    if (position_length > ctx.rope_cache.position_length) {
+        ctx.rope_cache.position_length = position_length;
+        if (ctx.rope_cache.sin_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache));
+        }
+        if (ctx.rope_cache.cos_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache));
         }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        int64_t repeat_theta_length = theta_scale_length * position_length * 2;
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
     }
 
-    aclTensor* acl_theta_scale_tensor =
-        ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
-                                theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-
     // position
     aclTensor* acl_position_tensor = ggml_cann_create_tensor(
         src1->data, ggml_cann_type_mapping(src1->type),
@@ -2397,17 +2407,17 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
     }
 
-    int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
+    int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float_t);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     // repeat
@@ -2491,10 +2501,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_reshape_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_reshape_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
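
For intuition about what the persistent `sin_cache`/`cos_cache` buffers hold, here is a hedged CPU reference for the values the hunks above compute on device. It assumes the usual GGML RoPE angle formula, theta(pos, i) = pos * freq_scale * freq_base^(-2i/n_dims), with an optional per-index division by `freq_factors` (the `aclnn_div` call above); the `attn_factor` scaling applied on device is omitted, and all names are illustrative.

```cpp
// CPU reference (illustrative, not the backend's API) for the sin/cos cache:
// one row per position, one entry per rotated pair of dimensions.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

static void rope_cache_reference(int n_dims, int n_pos, float freq_base,
                                 float freq_scale, const float* freq_factors,
                                 std::vector<float>& sin_cache,
                                 std::vector<float>& cos_cache) {
    const int half = n_dims / 2;
    sin_cache.assign((std::size_t) n_pos * half, 0.0f);
    cos_cache.assign((std::size_t) n_pos * half, 0.0f);
    for (int p = 0; p < n_pos; ++p) {
        for (int i = 0; i < half; ++i) {
            // arange -> power -> freq_scale mul -> optional freq_factors div,
            // mirroring the aclnn_arange / power / mul / div calls above.
            float theta_scale = std::pow(freq_base, -2.0f * i / n_dims) * freq_scale;
            if (freq_factors) {
                theta_scale /= freq_factors[i];
            }
            const float theta = (float) p * theta_scale;
            sin_cache[(std::size_t) p * half + i] = std::sin(theta);
            cos_cache[(std::size_t) p * half + i] = std::cos(theta);
        }
    }
}

int main() {
    std::vector<float> sin_cache, cos_cache;
    rope_cache_reference(/*n_dims=*/8, /*n_pos=*/4, /*freq_base=*/10000.0f,
                         /*freq_scale=*/1.0f, /*freq_factors=*/nullptr,
                         sin_cache, cos_cache);
    // pos = 1, pair index 0: theta = 1.0, so sin ~ 0.8415
    std::printf("sin[pos=1][0] = %f\n", sin_cache[4]);
    return 0;
}
```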

ggml/src/ggml-cann/common.h

Lines changed: 35 additions & 23 deletions
@@ -360,6 +360,37 @@ struct ggml_cann_graph {
 };
 #endif  // USE_ACL_GRAPH
 
+struct ggml_cann_rope_cache {
+    ~ggml_cann_rope_cache() {
+        if(theta_scale_cache != nullptr) {
+            ACL_CHECK(aclrtFree(theta_scale_cache));
+        }
+        if(sin_cache != nullptr) {
+            ACL_CHECK(aclrtFree(sin_cache));
+        }
+        if(cos_cache != nullptr) {
+            ACL_CHECK(aclrtFree(cos_cache));
+        }
+    }
+
+    void* theta_scale_cache = nullptr;
+    void* sin_cache = nullptr;
+    void* cos_cache = nullptr;
+    int first_layer = -1;
+    int64_t position_length = 0;
+};
+
+struct ggml_cann_tensor_cache {
+    ~ggml_cann_tensor_cache() {
+        if(cache != nullptr) {
+            ACL_CHECK(aclrtFree(cache));
+        }
+    }
+
+    void* cache = nullptr;
+    int64_t size = 0;
+};
+
 /**
  * @brief Context for managing CANN backend operations.
  */
@@ -376,15 +407,11 @@ struct ggml_backend_cann_context {
     bool async_mode;
     bool support_set_rows;
     // Rope Cache
-    void* rope_init_ptr = nullptr;
-    void* rope_sin_ptr = nullptr;
-    void* rope_cos_ptr = nullptr;
-    int64_t max_prompt_length = 0;
+    ggml_cann_rope_cache rope_cache;
     // Constant Pool
-    void* f32_zero_cache = nullptr;
-    void* f32_one_cache = nullptr;
-    int64_t f32_zero_cache_element = 0;
-    int64_t f32_one_cache_element = 0;
+    ggml_cann_tensor_cache rms_norm_one_tensor_cache;
+    ggml_cann_tensor_cache rms_norm_zero_tensor_cache;
+
 
     aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr};  /**< Array of streams for the device. */
 
@@ -424,21 +451,6 @@ struct ggml_backend_cann_context {
                 ACL_CHECK(aclrtDestroyStream(streams[i]));
             }
         }
-        if(rope_init_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(rope_init_ptr));
-        }
-        if(rope_sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(rope_sin_ptr));
-        }
-        if(rope_cos_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(rope_cos_ptr));
-        }
-        if(f32_zero_cache != nullptr) {
-            ACL_CHECK(aclrtFree(f32_zero_cache));
-        }
-        if(f32_one_cache != nullptr) {
-            ACL_CHECK(aclrtFree(f32_one_cache));
-        }
     }
 
     /**
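
The manual frees deleted above are now redundant: each cache struct releases its own buffer in its destructor, so cleanup follows the members automatically. A minimal sketch of that RAII pattern, with `malloc`/`free` standing in for `aclrtMalloc`/`aclrtFree` and illustrative struct names:

```cpp
// RAII sketch (illustrative): the context no longer needs explicit cleanup
// because each member cache frees its buffer when the context is destroyed.
#include <cstdint>
#include <cstdlib>

struct tensor_cache {
    void*   cache = nullptr;
    int64_t size  = 0;

    ~tensor_cache() {
        if (cache != nullptr) {
            std::free(cache);  // the real struct calls ACL_CHECK(aclrtFree(cache))
        }
    }
};

struct backend_context {
    // Destroyed automatically with the context, replacing the five
    // if/aclrtFree blocks this commit removes from the context destructor.
    tensor_cache rms_norm_one_tensor_cache;
    tensor_cache rms_norm_zero_tensor_cache;
};

int main() {
    backend_context ctx;
    ctx.rms_norm_one_tensor_cache.cache = std::malloc(1024);
    ctx.rms_norm_one_tensor_cache.size  = 1024;
    return 0;  // ~backend_context runs here; the buffer is freed, no leak
}
```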
