
Commit c247d06

CANN: ROPE cache sin/cos repeat (ggml-org#15501)
Signed-off-by: noemotiovon <[email protected]>
1 parent: 043fb27

2 files changed: +138 -91 lines

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 120 additions & 81 deletions
@@ -1257,12 +1257,20 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
 
 void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                aclTensor* acl_dst) {
-    GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
+    if(acl_dst == nullptr) {
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src);
+    } else {
+        GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
+    }
 }
 
 void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                aclTensor* acl_dst) {
-    GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
+    if(acl_dst == nullptr) {
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src);
+    } else {
+        GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
+    }
 }
 
 void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
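
The change above routes a null destination to the in-place aclnn kernels (InplaceCos/InplaceSin) instead of the out-of-place Cos/Sin ops. A minimal CPU-side sketch of the same dispatch rule, assuming plain float buffers rather than aclTensor handles (apply_cos is a hypothetical helper, not part of the commit):

    #include <cmath>
    #include <cstddef>

    // Write cos(src[i]) into dst, or back into src when dst is nullptr (in-place path).
    static void apply_cos(float* src, float* dst, std::size_t n) {
        float* out = (dst == nullptr) ? src : dst;
        for (std::size_t i = 0; i < n; ++i) {
            out[i] = std::cos(src[i]);
        }
    }
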
@@ -2221,13 +2229,54 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
     ggml_cann_release_resources(ctx, acl_index, acl_value);
 }
 
+/**
+ * @brief Initializes and caches sine/cosine positional encoding values
+ *        (used in RoPE, Rotary Position Embedding) for attention layers.
+ *
+ * This function computes and caches the sin/cos values of
+ * θ = position * theta_scale for RoPE encoding. The cache is shared
+ * across attention layers, and only the first attention layer will
+ * trigger initialization. The cache includes repeated sin/cos values
+ * with different repeat methods depending on the @param is_neox flag.
+ *
+ * Steps performed by this function:
+ *   1. Identify whether the target tensor belongs to Q/K in attention
+ *      and restrict computation to the first layer only.
+ *   2. Initialize the theta scale array (arange → power → freq scaling).
+ *   3. Allocate sin/cos caches if the max prompt length increases.
+ *   4. Compute θ = position * theta_scale.
+ *   5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
+ *   6. Expand sin/cos values by repeat or repeat_interleave depending
+ *      on whether @param is_neox is enabled.
+ *   7. Store the computed values into persistent buffers
+ *      (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
+ *
+ * @param ctx          The CANN backend context, holding memory pool,
+ *                     stream, and persistent buffers for rope init/cache.
+ * @param dst          The destination ggml_tensor whose computation
+ *                     depends on the cached RoPE values (usually Qcur/Kcur).
+ * @param theta_scale  Scalar exponent base for computing theta scale values.
+ * @param freq_scale   Frequency scaling factor, applied to theta scale.
+ * @param attn_factor  Attention scaling factor, applied to sin/cos.
+ * @param is_neox      Whether to use Neox-style repeat strategy
+ *                     (dim expansion vs repeat_interleave).
+ */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
-                             aclTensor* acl_cos_repeat_tensor,
-                             aclTensor* acl_sin_repeat_tensor,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
     // int sin/cos cache, cache has different repeat method depond on
     // @param.is_neox
+    bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
+    bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
+
+    // used for accuracy testing
+    bool is_attention = is_q || is_k;
+
+    // just compute in first layer in attention
+    bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
+    if(is_attention && !is_fisrt_layer) {
+        return;
+    }
 
     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src1 = dst->src[1];  // position
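
The new gating at the top of aclnn_cache_init keys off the tensor name: only Q/K tensors count as attention tensors, and among those only the first layer's "Qcur-0" refreshes the shared cache, while non-attention tensors (e.g. in accuracy tests) always recompute. A small stand-alone sketch of that predicate, assuming the "Qcur-<layer>" / "Kcur-<layer>" naming scheme (rope_cache_should_init is a hypothetical helper):

    #include <cstdio>
    #include <cstring>

    // Recompute the shared RoPE cache only for non-attention tensors
    // or for the first layer's Q projection ("Qcur-0").
    static bool rope_cache_should_init(const char* name) {
        const bool is_q         = std::strncmp(name, "Qcur-", 5) == 0;
        const bool is_k         = std::strncmp(name, "Kcur-", 5) == 0;
        const bool is_attention = is_q || is_k;
        const bool is_first     = std::strcmp(name, "Qcur-0") == 0;
        return !is_attention || is_first;
    }

    int main() {
        std::printf("%d %d %d\n",
                    rope_cache_should_init("Qcur-0"),    // 1: first layer triggers init
                    rope_cache_should_init("Kcur-7"),    // 0: later layers reuse the cache
                    rope_cache_should_init("node_42"));  // 1: not attention, always computed
        return 0;
    }
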
@@ -2253,21 +2302,16 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
     }
 
-    bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
-    bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
-
-    // used for accuracy testing
-    bool is_attention = is_q || is_k;
-
-    if(ctx.init_ptr == nullptr || !is_attention) {
+    // init theta scale, just one time
+    if(ctx.rope_init_ptr == nullptr || !is_attention) {
         // theta_scale arange, [0,1,...,ne00/2 - 1]
-        if(ctx.init_ptr != nullptr){
-            ACL_CHECK(aclrtFree(ctx.init_ptr));
+        if(ctx.rope_init_ptr != nullptr){
+            ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
         }
-        ACL_CHECK(aclrtMalloc(&ctx.init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
 
         aclTensor* acl_theta_scale_tensor =
-            ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
+            ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
                                     theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
         float start = 0;
         float step = 1;
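
Step 2 of the doc comment (arange → power → freq scaling) builds a per-index scale table that is later multiplied by the positions. A CPU sketch of the table the on-device aclnn ops are expected to produce, assuming theta_scale is the scalar base and freq_scale is folded into the table as the doc comment states (make_theta_scale is a hypothetical helper):

    #include <cmath>
    #include <vector>

    // t[i] = theta_scale^i * freq_scale for i in [0, half_dim),
    // matching "arange -> power -> freq scaling" above.
    static std::vector<float> make_theta_scale(int half_dim, float theta_scale, float freq_scale) {
        std::vector<float> t(half_dim);
        for (int i = 0; i < half_dim; ++i) {
            t[i] = std::pow(theta_scale, static_cast<float>(i)) * freq_scale;
        }
        return t;
    }
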
@@ -2297,74 +2341,75 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
     }
 
-    if(ctx.sin_ptr == nullptr) {
-        int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
-        ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-    }
+    // init sin_repeat && cos_repeat, one token just init in 0 layer
     if(position_length > ctx.max_prompt_length) {
         ctx.max_prompt_length = position_length;
-        int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
-        ACL_CHECK(aclrtFree(ctx.sin_ptr));
-        ACL_CHECK(aclrtFree(ctx.cos_ptr));
-        ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
+        if(ctx.rope_sin_ptr != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
+            ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
+        }
+        ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
     }
 
-    bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
-
-    if(is_fisrt_layer || !is_attention) {
-
-        aclTensor* acl_theta_scale_tensor =
-            ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
+    aclTensor* acl_theta_scale_tensor =
+        ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
                                     theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
 
-        // position
-        aclTensor* acl_position_tensor = ggml_cann_create_tensor(
-            src1->data, ggml_cann_type_mapping(src1->type),
-            ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
-
-        // power * position
-        int64_t theta_length = theta_scale_length * position_length;
-        ggml_cann_pool_alloc theta_allocator(ctx.pool(),
-                                             theta_length * sizeof(float_t));
-        void* theta_buffer = theta_allocator.get();
-
-        aclTensor* acl_theta_tensor =
-            ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
-                                    theta_ne, theta_nb, GGML_MAX_DIMS);
-        aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
-                  acl_theta_tensor);
-
-        // sin/cos
-        aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
-            ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
-            GGML_MAX_DIMS, ACL_FORMAT_ND);
-        aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
-
-        aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
-            ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
-            GGML_MAX_DIMS, ACL_FORMAT_ND);
-        aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
-
-        // release
-        ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
-            acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
-    }
-
+    // position
+    aclTensor* acl_position_tensor = ggml_cann_create_tensor(
+        src1->data, ggml_cann_type_mapping(src1->type),
+        ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
+
+    // power * position
+    int64_t theta_length = theta_scale_length * position_length;
+    ggml_cann_pool_alloc theta_allocator(ctx.pool(),
+                                         theta_length * sizeof(float_t));
+    void* theta_buffer = theta_allocator.get();
+
+    aclTensor* acl_theta_tensor =
+        ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
+                                theta_ne, theta_nb, GGML_MAX_DIMS);
+    aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
+              acl_theta_tensor);
+
+    // sin/cos
+    ggml_cann_pool_alloc sin_allocator(ctx.pool(),
+                                       theta_length * sizeof(float_t));
+    void* sin_buffer = sin_allocator.get();
     aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
-        ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
-        GGML_MAX_DIMS, ACL_FORMAT_ND);
+        sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
+        GGML_MAX_DIMS, ACL_FORMAT_ND);
+    aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
+
+    ggml_cann_pool_alloc cos_allocator(ctx.pool(),
+                                       theta_length * sizeof(float_t));
+    void* cos_buffer = cos_allocator.get();
     aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
-        ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
-        GGML_MAX_DIMS, ACL_FORMAT_ND);
+        cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
+        GGML_MAX_DIMS, ACL_FORMAT_ND);
+    aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
 
     // attn_factor
     if (attn_factor != 1) {
         aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true);
         aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
     }
 
+    int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
+    size_t sin_reshape_nb[GGML_MAX_DIMS];
+    sin_reshape_nb[0] = sizeof(float_t);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
+    }
+    aclTensor* acl_sin_repeat_tensor =
+        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+                                sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
+    aclTensor* acl_cos_repeat_tensor =
+        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+                                sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
+
     // repeat
     if (is_neox) {
         int64_t repeatsArray[] = {1, 1, 1, 2};
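
The repeat step expands the half-dimension sin/cos rows to full rows before they are written into the persistent cache: a plain repeat along the last axis for the Neox path (repeatsArray {1, 1, 1, 2}) and repeat_interleave otherwise. A CPU sketch of the two layouts under that reading, assuming a half-dimension vector s of sin values (both helpers are hypothetical):

    #include <vector>

    // Neox path: the half-dim row is repeated as a whole -> [s0..s_{h-1}, s0..s_{h-1}]
    static std::vector<float> repeat_last_dim(const std::vector<float>& s) {
        std::vector<float> out(s);
        out.insert(out.end(), s.begin(), s.end());
        return out;
    }

    // non-Neox path: each element is duplicated in place -> [s0, s0, s1, s1, ...]
    static std::vector<float> repeat_interleave(const std::vector<float>& s) {
        std::vector<float> out;
        out.reserve(s.size() * 2);
        for (float v : s) {
            out.push_back(v);
            out.push_back(v);
        }
        return out;
    }

Either way the row doubles in length, which is why the persistent cache above is sized with the extra factor of 2.
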
@@ -2380,8 +2425,9 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                               num_repeats, output_size);
     }
 
-    // release
-    ggml_cann_release_resources(ctx, acl_sin_tensor, acl_cos_tensor);
+    ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
+        acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
+        acl_cos_repeat_tensor);
 }
 
 #ifdef __cplusplus
@@ -2435,13 +2481,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
-    // init cos/sin cache
-    ggml_cann_pool_alloc sin_allocator(
-        ctx.pool(), ne00 * ne02 * sizeof(float_t));
-    ggml_cann_pool_alloc cos_allocator(
-        ctx.pool(), ne00 * ne02 * sizeof(float_t));
-    void* sin_buffer = sin_allocator.get();
-    void* cos_buffer = cos_allocator.get();
+    // init ctx.rope_cos/rope_sin cache
+    aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
 
     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
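
With aclnn_cache_init run up front, ggml_cann_rope reads sin/cos directly from the persistent ctx.rope_sin_ptr/ctx.rope_cos_ptr buffers through an {ne00, 1, ne02, 1} view instead of per-call pool buffers. A small sizing check, assuming theta_scale_length == ne00 / 2 (per the arange comment above) and position_length == ne02, showing why the factor of 2 in repeat_theta_length makes the repeated cache large enough for that view:

    #include <cassert>
    #include <cstdint>

    int main() {
        // Illustrative sizes: ne00 = rotated head dimension, ne02 = number of positions.
        const int64_t ne00 = 128, ne02 = 512;
        const int64_t theta_scale_length  = ne00 / 2;  // "[0,1,...,ne00/2 - 1]"
        const int64_t max_prompt_length   = ne02;      // grown on demand by aclnn_cache_init
        const int64_t repeat_theta_length = theta_scale_length * max_prompt_length * 2;
        // The {ne00, 1, ne02, 1} reshape view reads ne00 * ne02 floats, which the cache covers.
        assert(repeat_theta_length >= ne00 * ne02);
        return 0;
    }
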
@@ -2450,13 +2491,11 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_reshape_tensor =
-        ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_reshape_tensor =
-        ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
-    aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
-                     theta_scale, freq_scale, attn_factor, is_neox);
 
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
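
The sin_reshape_nb strides are the usual contiguous (row-major over ne order) byte strides derived from the element size and the shape. A worked example with illustrative sizes, assuming float elements and the {ne00, 1, ne02, 1} shape used above:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Byte strides for the {ne00, 1, ne02, 1} view (illustrative sizes).
        const int64_t ne[4] = {128, 1, 512, 1};
        std::size_t nb[4];
        nb[0] = sizeof(float);
        for (int i = 1; i < 4; ++i) {
            nb[i] = nb[i - 1] * static_cast<std::size_t>(ne[i - 1]);
        }
        // Prints: nb = {4, 512, 512, 262144}
        std::printf("nb = {%zu, %zu, %zu, %zu}\n", nb[0], nb[1], nb[2], nb[3]);
        return 0;
    }
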

ggml/src/ggml-cann/common.h

Lines changed: 18 additions & 10 deletions
@@ -368,17 +368,19 @@ struct ggml_backend_cann_context {
     std::string name;                /**< Name of the device. */
     std::string description;         /**< Description of the device. */
     aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
-    void* init_ptr = nullptr;
-    void* sin_ptr = nullptr;
-    void* cos_ptr = nullptr;
-    int64_t max_prompt_length = 65536;
 #ifdef USE_ACL_GRAPH
     /// Cached CANN ACL graph used for executing the current ggml computation graph.
     std::unique_ptr<ggml_cann_graph> cann_graph;
 #endif
     cann_task_queue task_queue;
     bool async_mode;
     bool support_set_rows;
+    // Rope Cache
+    void* rope_init_ptr = nullptr;
+    void* rope_sin_ptr = nullptr;
+    void* rope_cos_ptr = nullptr;
+    int64_t max_prompt_length = 0;
+    // Constant Pool
     void* f32_zero_cache = nullptr;
     void* f32_one_cache = nullptr;
     int64_t f32_zero_cache_element = 0;
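
The renamed rope_* fields and the max_prompt_length default of 0 implement a grow-only cache owned by the backend context: nothing is allocated until the first RoPE call, and the buffers are reallocated only when a longer prompt arrives, never shrunk. A minimal sketch of that policy, using malloc/free as a stand-in for aclrtMalloc/aclrtFree (rope_cache is a hypothetical struct, not the commit's code):

    #include <cstddef>
    #include <cstdint>
    #include <cstdlib>

    struct rope_cache {
        float*  sin_ptr           = nullptr;
        float*  cos_ptr           = nullptr;
        int64_t max_prompt_length = 0;   // starts empty, grows on demand

        void ensure(int64_t half_dim, int64_t position_length) {
            if (position_length <= max_prompt_length) {
                return;  // the repeated buffers already fit this prompt
            }
            max_prompt_length = position_length;
            std::free(sin_ptr);
            std::free(cos_ptr);
            const std::size_t n = static_cast<std::size_t>(half_dim * max_prompt_length * 2);
            sin_ptr = static_cast<float*>(std::malloc(n * sizeof(float)));
            cos_ptr = static_cast<float*>(std::malloc(n * sizeof(float)));
        }

        ~rope_cache() {
            std::free(sin_ptr);
            std::free(cos_ptr);
        }
    };
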
@@ -422,14 +424,20 @@ struct ggml_backend_cann_context {
                 ACL_CHECK(aclrtDestroyStream(streams[i]));
             }
         }
-        if(init_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(init_ptr));
+        if(rope_init_ptr != nullptr) {
+            ACL_CHECK(aclrtFree(rope_init_ptr));
         }
-        if(sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(sin_ptr));
+        if(rope_sin_ptr != nullptr) {
+            ACL_CHECK(aclrtFree(rope_sin_ptr));
         }
-        if(cos_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(cos_ptr));
+        if(rope_cos_ptr != nullptr) {
+            ACL_CHECK(aclrtFree(rope_cos_ptr));
+        }
+        if(f32_zero_cache != nullptr) {
+            ACL_CHECK(aclrtFree(f32_zero_cache));
+        }
+        if(f32_one_cache != nullptr) {
+            ACL_CHECK(aclrtFree(f32_one_cache));
         }
     }
 