Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 55 additions & 6 deletions ggml/src/ggml-cann/aclnn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
#include <aclnnop/aclnn_zero.h>
#include <aclnnop/aclnn_index_copy.h>
#include <aclnnop/aclnn_index_select.h>
#include <aclnnop/aclnn_clamp.h>
#include <aclnnop/aclnn_threshold.h>
#include <float.h>

#include <cmath>
Expand Down Expand Up @@ -2263,6 +2265,7 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
*/
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
void* sin_tensor_buffer, void* cos_tensor_buffer,
float* corr_dims, float ext_factor,
float theta_scale, float freq_scale,
float attn_factor, bool is_neox) {
// int sin/cos cache, cache has different repeat method depond on
Expand Down Expand Up @@ -2318,16 +2321,60 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
float n_elements = theta_scale_length;
aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);

ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
aclTensor* acl_yarn_ramp_tensor = nullptr;
if (ext_factor != 0) {
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void* yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
aclScalar* low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
aclScalar* zero = aclCreateScalar(&zero_value, aclDataType::ACL_FLOAT);
aclScalar* one = aclCreateScalar(&one_value, aclDataType::ACL_FLOAT);
aclScalar* denom_safe = aclCreateScalar(&denom_safe_value, aclDataType::ACL_FLOAT);
aclScalar* ext_factor_sc = aclCreateScalar(&ext_factor, aclDataType::ACL_FLOAT);

GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor, low, one, acl_yarn_ramp_tensor);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor, denom_safe);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor, zero, zero);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor, one);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor, one, one);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, ext_factor_sc);

// theta_interp = freq_scale * theta_extrap;
// theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
// theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
//
// we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse
// cache freq_scale + (freq_scale - 1) * ramp_mix
float freq_scale_1 = freq_scale - 1;
aclScalar* freq_scale_sc = aclCreateScalar(&freq_scale, aclDataType::ACL_FLOAT);
aclScalar* freq_scale_1_sc = aclCreateScalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, freq_scale_1_sc);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor, freq_scale_sc, one);

ggml_cann_release_resources(ctx, low, zero, one, denom_safe, ext_factor_sc, freq_scale_sc, freq_scale_1_sc);
}

// power
aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
acl_theta_scale_tensor);

// freq_scale
if (freq_scale != 1) {
if (ext_factor != 0) {
aclnn_mul(ctx, acl_theta_scale_tensor, acl_yarn_ramp_tensor);
} else if (freq_scale != 1) {
aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
}
ggml_cann_release_resources(ctx, acl_theta_scale);

ggml_cann_release_resources(ctx, acl_yarn_ramp_tensor, acl_theta_scale);
} else {
// use cache
acl_theta_scale_tensor =
Expand Down Expand Up @@ -2385,6 +2432,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);

if (ext_factor != 0) {
attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}

// attn_factor
if (attn_factor != 1) {
aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true);
Expand Down Expand Up @@ -2465,8 +2516,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: n_dims <= ne0
GGML_ASSERT(n_dims == ne0);
GGML_ASSERT(n_dims % 2 == 0);
// TODO: ext_factor != 0
GGML_ASSERT(ext_factor == 0);

const float theta_scale = powf(freq_base, -2.0f / n_dims);

Expand All @@ -2484,7 +2533,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void *cos_tensor_buffer = cos_tensor_allocator.get();

// init ctx.rope_cos/rope_sin cache
aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer,
aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
theta_scale, freq_scale, attn_factor, is_neox);

int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
Expand Down
16 changes: 0 additions & 16 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2401,16 +2401,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
}
case GGML_OP_ROPE: {
// TODO: with ops-test v == 1
float ext_factor = 0.0f;
memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
// TODO: n_dims <= ne0
if (op->src[0]->ne[0] != op->op_params[1]) {
return false;
}
// TODO: ext_factor != 0
if (ext_factor != 0) {
return false;
}

const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
Expand All @@ -2420,9 +2414,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
return false;
}

if(!ggml_is_contiguous(op->src[0])){
return false;
}
return true;
}
case GGML_OP_UPSCALE: {
Expand Down Expand Up @@ -2523,13 +2514,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// different head sizes of K and V are not supported yet
return false;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
if (op->src[0]->ne[0] == 576) {
// DeepSeek MLA
return false;
}
if (op->src[0]->ne[0] % 16 != 0) {
// TODO: padding to support
return false;
Expand Down
Loading