
Commit ef2af57

CANN: Support ext_factor in rope (ggml-org#15710)
1 parent 5d804a4 commit ef2af57

2 files changed: +55 −22 lines changed
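ext_factor is the YaRN extrapolation-mix factor of ggml's extended RoPE op; until this commit the CANN backend reported such nodes as unsupported and they were split out to another backend. As a rough usage sketch (illustrative tensor names and parameter values, not taken from this commit), a graph requests a YaRN-style rope like this:

    // Illustrative only: Qcur and inp_pos are placeholder tensors, values are made up.
    // The call shape follows ggml_rope_ext() from ggml.h.
    struct ggml_tensor * rotated = ggml_rope_ext(
            ctx, Qcur, inp_pos, /*freq_factors*/ NULL,
            /*n_dims*/ 128, GGML_ROPE_TYPE_NEOX, /*n_ctx_orig*/ 4096,
            /*freq_base*/ 10000.0f, /*freq_scale*/ 0.25f,
            /*ext_factor*/ 1.0f,   // a non-zero value was previously rejected by the CANN backend
            /*attn_factor*/ 1.0f, /*beta_fast*/ 32.0f, /*beta_slow*/ 1.0f);

With this change such a node can stay on the CANN device instead of falling back.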

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 55 additions & 6 deletions
@@ -70,6 +70,8 @@
 #include <aclnnop/aclnn_zero.h>
 #include <aclnnop/aclnn_index_copy.h>
 #include <aclnnop/aclnn_index_select.h>
+#include <aclnnop/aclnn_clamp.h>
+#include <aclnnop/aclnn_threshold.h>
 #include <float.h>

 #include <cmath>
@@ -2263,6 +2265,7 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                              void* sin_tensor_buffer, void* cos_tensor_buffer,
+                             float* corr_dims, float ext_factor,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
    // int sin/cos cache, cache has different repeat method depond on
@@ -2318,16 +2321,60 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         float n_elements = theta_scale_length;
         aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);

+        ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
+        aclTensor* acl_yarn_ramp_tensor = nullptr;
+        if (ext_factor != 0) {
+            // -rope_yarn_ramp
+            // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+            // return MIN(1, MAX(0, y)) - 1;
+            yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
+            void* yarn_ramp_buffer = yarn_ramp_allocator.get();
+            acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t),
+                                                           theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+            float zero_value = 0, one_value = 1;
+            float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
+            aclScalar* low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
+            aclScalar* zero = aclCreateScalar(&zero_value, aclDataType::ACL_FLOAT);
+            aclScalar* one = aclCreateScalar(&one_value, aclDataType::ACL_FLOAT);
+            aclScalar* denom_safe = aclCreateScalar(&denom_safe_value, aclDataType::ACL_FLOAT);
+            aclScalar* ext_factor_sc = aclCreateScalar(&ext_factor, aclDataType::ACL_FLOAT);
+
+            GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor, low, one, acl_yarn_ramp_tensor);
+            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor, denom_safe);
+            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor, zero, zero);
+            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor, one);
+            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor, one, one);
+            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, ext_factor_sc);
+
+            // theta_interp = freq_scale * theta_extrap;
+            // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+            // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
+            // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
+            // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
+            //
+            // We cache (freq_scale - freq_scale * ramp_mix + ramp_mix). Since the ramp buffer computed
+            // above holds -ramp_mix, this is obtained as freq_scale + (freq_scale - 1) * (-ramp_mix).
+            float freq_scale_1 = freq_scale - 1;
+            aclScalar* freq_scale_sc = aclCreateScalar(&freq_scale, aclDataType::ACL_FLOAT);
+            aclScalar* freq_scale_1_sc = aclCreateScalar(&freq_scale_1, aclDataType::ACL_FLOAT);
+            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, freq_scale_1_sc);
+            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor, freq_scale_sc, one);
+
+            ggml_cann_release_resources(ctx, low, zero, one, denom_safe, ext_factor_sc, freq_scale_sc, freq_scale_1_sc);
+        }
+
         // power
         aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
         GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
                                 acl_theta_scale_tensor);

-        // freq_scale
-        if (freq_scale != 1) {
+        if (ext_factor != 0) {
+            aclnn_mul(ctx, acl_theta_scale_tensor, acl_yarn_ramp_tensor);
+        } else if (freq_scale != 1) {
             aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
         }
-        ggml_cann_release_resources(ctx, acl_theta_scale);
+
+        ggml_cann_release_resources(ctx, acl_yarn_ramp_tensor, acl_theta_scale);
     } else {
         // use cache
         acl_theta_scale_tensor =
@@ -2385,6 +2432,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                               GGML_MAX_DIMS, ACL_FORMAT_ND);
     aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);

+    if (ext_factor != 0) {
+        attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+
     // attn_factor
     if (attn_factor != 1) {
         aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true);
@@ -2465,8 +2516,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // TODO: n_dims <= ne0
     GGML_ASSERT(n_dims == ne0);
     GGML_ASSERT(n_dims % 2 == 0);
-    // TODO: ext_factor != 0
-    GGML_ASSERT(ext_factor == 0);

     const float theta_scale = powf(freq_base, -2.0f / n_dims);

@@ -2484,7 +2533,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     void *cos_tensor_buffer = cos_tensor_allocator.get();

     // init ctx.rope_cos/rope_sin cache
-    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer,
+    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
                      theta_scale, freq_scale, attn_factor, is_neox);

     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
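The tensor-level sequence above (Subs, InplaceDivs, InplaceThreshold, InplaceClampMax, InplaceSubs, InplaceMuls) reproduces the scalar YaRN ramp from ggml's CPU rope path; a compact sketch of that reference (paraphrased here, helper names are illustrative and may differ from the tree) shows the per-element math the cached multiplier stands in for:

    #include <math.h>

    // Scalar reference for the cached per-frequency factor (paraphrased from ggml's
    // CPU rope path; exact helper names and signatures in the tree may differ).
    static float rope_yarn_ramp(float low, float high, int i0) {
        const float y = (i0 / 2 - low) / fmaxf(0.001f, high - low);
        return 1.0f - fminf(1.0f, fmaxf(0.0f, y));   // the CANN kernel stores the negated value
    }

    static float rope_yarn_theta(float theta_extrap, float freq_scale,
                                 const float corr_dims[2], int i0, float ext_factor) {
        const float theta_interp = freq_scale * theta_extrap;
        if (ext_factor == 0.0f) {
            return theta_interp;   // plain interpolation: the only case CANN handled before
        }
        const float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        // == theta_extrap * (freq_scale + (1 - freq_scale) * ramp_mix): the per-frequency
        // factor that aclnn_cache_init now bakes into acl_theta_scale_tensor
        return theta_interp * (1.0f - ramp_mix) + theta_extrap * ramp_mix;
    }

The added attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale) applies the matching magnitude correction for the ext_factor != 0 case.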

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 0 additions & 16 deletions
@@ -2401,16 +2401,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
24012401
}
24022402
case GGML_OP_ROPE: {
24032403
// TODO: with ops-test v == 1
2404-
float ext_factor = 0.0f;
2405-
memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
24062404
// TODO: n_dims <= ne0
24072405
if (op->src[0]->ne[0] != op->op_params[1]) {
24082406
return false;
24092407
}
2410-
// TODO: ext_factor != 0
2411-
if (ext_factor != 0) {
2412-
return false;
2413-
}
24142408

24152409
const int mode = ((const int32_t *) op->op_params)[2];
24162410
if (mode & GGML_ROPE_TYPE_MROPE) {
@@ -2420,9 +2414,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 return false;
             }

-            if(!ggml_is_contiguous(op->src[0])){
-                return false;
-            }
             return true;
         }
         case GGML_OP_UPSCALE: {
@@ -2523,13 +2514,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 // different head sizes of K and V are not supported yet
                 return false;
             }
-            if (op->src[0]->ne[0] == 192) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 576) {
-                // DeepSeek MLA
-                return false;
-            }
             if (op->src[0]->ne[0] % 16 != 0) {
                 // TODO: padding to support
                 return false;
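The dropped guard read ext_factor straight out of the op's parameter block. Kept here only as a sketch of that pattern (assuming ggml.h for struct ggml_tensor), should a backend ever need to gate on one of the other rope float parameters:

    #include <string.h>   // memcpy

    // Sketch mirroring the removed check: offset 7 in op_params is where
    // ggml_rope_ext packs ext_factor as a float (see the removed lines above).
    static bool rope_uses_yarn(const struct ggml_tensor * op) {
        float ext_factor = 0.0f;
        memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
        return ext_factor != 0.0f;
    }

After this change, the remaining gates in the GGML_OP_ROPE case are the n_dims check and the rope-mode check visible above.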
