Merged
2 changes: 1 addition & 1 deletion .clang-format
@@ -22,7 +22,7 @@ AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: true
BinPackArguments: false
BinPackArguments: true
BinPackParameters: false # OnePerLine
BitFieldColonSpacing: Both
BreakBeforeBraces: Custom # Attach
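For context, `BinPackArguments` controls whether call arguments may share wrapped lines: with `false`, a call that does not fit on one line gets exactly one argument per line; with `true`, clang-format packs as many arguments per line as the column limit allows. A small illustration of the difference (assuming the default column limit; `do_something` and its arguments are placeholders):

```cpp
// BinPackArguments: false -- once the call breaks, one argument per line
do_something(first_long_argument,
             second_long_argument,
             third_long_argument,
             fourth_long_argument);

// BinPackArguments: true -- arguments are packed up to the column limit
do_something(first_long_argument, second_long_argument,
             third_long_argument, fourth_long_argument);
```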
6 changes: 3 additions & 3 deletions common/arg.cpp
@@ -1548,11 +1548,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-fa", "--flash-attn"}, "FA",
string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", llama_flash_attn_type_name(params.flash_attn_type)),
[](common_params & params, const std::string & value) {
if (value == "on" || value == "enabled") {
if (value == "on" || value == "enabled" || value == "1") {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
} else if (value == "off" || value == "disabled") {
} else if (value == "off" || value == "disabled" || value == "0") {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
} else if (value == "auto") {
} else if (value == "auto" || value == "-1") {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
} else {
throw std::runtime_error(string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str()));
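With this change, numeric spellings become aliases for the existing keywords: `-fa 1` matches `on`/`enabled`, `-fa 0` matches `off`/`disabled`, and `-fa -1` matches `auto`.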
4 changes: 4 additions & 0 deletions docs/backend/CANN.md
@@ -314,3 +314,7 @@ Controls automatic cleanup of the memory pool. This option is only effective whe

Converting the matmul weight format from ND to NZ can significantly improve performance on the 310I DUO NPU.

### GGML_CANN_DISABLE_ACL_GRAPH

When this variable is set, ACL graph execution is disabled and operators are executed in an op-by-op (eager) mode.
This mode is mainly intended for debugging or for cases where the overhead of graph construction and execution is not desirable.
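For example, to force eager execution for a debugging session (assuming a CANN-enabled llama.cpp build; the other tools work the same way): `GGML_CANN_DISABLE_ACL_GRAPH=1 ./llama-cli -m model.gguf -p "hello"`.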
86 changes: 70 additions & 16 deletions ggml/src/ggml-cann/aclnn_ops.cpp
@@ -70,6 +70,8 @@
#include <aclnnop/aclnn_zero.h>
#include <aclnnop/aclnn_index_copy.h>
#include <aclnnop/aclnn_index_select.h>
#include <aclnnop/aclnn_clamp.h>
#include <aclnnop/aclnn_threshold.h>
#include <float.h>

#include <cmath>
@@ -1423,21 +1425,25 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
* @param start Starting exponent offset.
* @param stop Stopping exponent offset (exclusive).
* @param step Step size for the exponent increment.
* @param dtype Data type for slope tensor.
*/
static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
float m, int64_t size, float start, float stop, float step){
float m, int64_t size, float start, float stop, float step, ggml_type dtype){
aclDataType acl_type = ggml_cann_type_mapping(dtype);
size_t type_size = ggml_type_size(dtype);

int64_t ne[] = {size};
size_t nb[] = {sizeof(uint16_t)};
size_t nb[] = {type_size};

ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(uint16_t));
ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * type_size);
void* arange_buffer = arange_allocator.get();

aclTensor* arange_tensor = ggml_cann_create_tensor(
arange_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
arange_buffer, acl_type, type_size, ne, nb, 1);
aclnn_arange(ctx, arange_tensor, start, stop, step, size);

aclTensor* slope_tensor = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
slope_buffer, acl_type, type_size, ne, nb, 1);

aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);

@@ -1468,10 +1474,11 @@ static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_bu
* @param n_head Total number of attention heads.
* @param slope_buffer Pointer to the output buffer (float array) for storing slopes.
* @param max_bias Maximum bias value for slope computation.
* @param dtype Data type for slope tensor.
*
*/
static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
void* slope_buffer, float max_bias) {
void* slope_buffer, float max_bias, ggml_type dtype) {
const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

float m0 = powf(2.0f, -(max_bias) / n_head_log2);
@@ -1488,7 +1495,7 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
float step = 1;
float count = n_head_log2;
// end needs to be +1 because aclnn uses a left-closed, right-open interval.
aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step);
aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step, dtype);
if (n_head_log2 < n_head) {
// arange2
start = 2 * (n_head_log2 - n_head_log2) + 1;
@@ -1497,7 +1504,7 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
count = n_head - n_head_log2;
aclnn_get_slope_inner(
ctx, (char *) slope_buffer + n_head_log2 * sizeof(float),
m1, count, start, end + 1, step);
m1, count, start, end + 1, step, dtype);
}
}
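For intuition, the two aranges materialize the standard ALiBi slope table. A scalar sketch of what ends up in `slope_buffer`, assuming the usual ggml slope formula (not literal patch code):

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// first n_head_log2 heads: powf(m0, h + 1); remaining heads: powf(m1, 2*(h - n_head_log2) + 1)
static std::vector<float> alibi_slopes(int64_t n_head, float max_bias) {
    const int64_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
    const float m0 = powf(2.0f, -max_bias / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
    std::vector<float> slopes(n_head);
    for (int64_t h = 0; h < n_head; ++h) {
        slopes[h] = h < n_head_log2 ? powf(m0, h + 1)                      // arange1: 1..n_head_log2
                                    : powf(m1, 2 * (h - n_head_log2) + 1); // arange2: odd exponents
    }
    return slopes;
}
```

The new `dtype` parameter exists because the two call sites want different precisions: the soft-max ALiBi path below passes `GGML_TYPE_F32`, while the flash-attention path keeps `GGML_TYPE_F16`.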

@@ -1534,7 +1541,7 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask,
ggml_cann_pool_alloc bias_allocator(
ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst));
bias_buffer = bias_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias);
aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias, GGML_TYPE_F32);
}

// broadcast for mask, slope and dst;
@@ -2263,6 +2270,7 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
*/
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
void* sin_tensor_buffer, void* cos_tensor_buffer,
float* corr_dims, float ext_factor,
float theta_scale, float freq_scale,
float attn_factor, bool is_neox) {
// init sin/cos cache; the cache has a different repeat method depending on
@@ -2318,16 +2326,60 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
float n_elements = theta_scale_length;
aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);

ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
aclTensor* acl_yarn_ramp_tensor = nullptr;
if (ext_factor != 0) {
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void* yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
aclScalar* low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
aclScalar* zero = aclCreateScalar(&zero_value, aclDataType::ACL_FLOAT);
aclScalar* one = aclCreateScalar(&one_value, aclDataType::ACL_FLOAT);
aclScalar* denom_safe = aclCreateScalar(&denom_safe_value, aclDataType::ACL_FLOAT);
aclScalar* ext_factor_sc = aclCreateScalar(&ext_factor, aclDataType::ACL_FLOAT);

GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor, low, one, acl_yarn_ramp_tensor);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor, denom_safe);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor, zero, zero);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor, one);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor, one, one);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, ext_factor_sc);

// theta_interp = freq_scale * theta_extrap;
// theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
// theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
//
// we want to cache (freq_scale - freq_scale * ramp_mix + ramp_mix); since the ramp computed
// above is the negation of rope_yarn_ramp, we instead
// cache freq_scale + (freq_scale - 1) * ramp_mix (see the scalar reference sketch after this hunk)
float freq_scale_1 = freq_scale - 1;
aclScalar* freq_scale_sc = aclCreateScalar(&freq_scale, aclDataType::ACL_FLOAT);
aclScalar* freq_scale_1_sc = aclCreateScalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, freq_scale_1_sc);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor, freq_scale_sc, one);

ggml_cann_release_resources(ctx, low, zero, one, denom_safe, ext_factor_sc, freq_scale_sc, freq_scale_1_sc);
}

// power
aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
acl_theta_scale_tensor);

// freq_scale
if (freq_scale != 1) {
if (ext_factor != 0) {
aclnn_mul(ctx, acl_theta_scale_tensor, acl_yarn_ramp_tensor);
} else if (freq_scale != 1) {
aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
}
ggml_cann_release_resources(ctx, acl_theta_scale);

ggml_cann_release_resources(ctx, acl_yarn_ramp_tensor, acl_theta_scale);
} else {
// use cache
acl_theta_scale_tensor =
@@ -2385,6 +2437,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);

if (ext_factor != 0) {
attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}

// attn_factor
if (attn_factor != 1) {
aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true);
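For reference, this is the scalar YaRN recipe the tensor ops above reproduce, following the standard ggml CPU implementation (a sketch, not patch code): the cached factor equals the coefficient of `theta_extrap` below, and the `attn_factor` adjustment matches the `mscale` correction.

```cpp
#include <cmath>

// ramp is 1 below corr_dims[0] and 0 above corr_dims[1] (i0 indexes dimension pairs)
static float rope_yarn_ramp(float low, float high, int i0) {
    const float y = (i0 / 2 - low) / fmaxf(0.001f, high - low);
    return 1.0f - fminf(1.0f, fmaxf(0.0f, y));
}

static void rope_yarn(float theta_extrap, float freq_scale, const float corr_dims[2],
                      int i0, float ext_factor, float mscale,
                      float * cos_theta, float * sin_theta) {
    float theta = freq_scale * theta_extrap;  // theta_interp
    if (ext_factor != 0.0f) {
        const float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix
        theta = theta * (1.0f - ramp_mix) + theta_extrap * ramp_mix;
        // magnitude correction, applied to attn_factor in the code above
        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
    }
    *cos_theta = cosf(theta) * mscale;
    *sin_theta = sinf(theta) * mscale;
}
```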
@@ -2465,8 +2521,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: n_dims <= ne0
GGML_ASSERT(n_dims == ne0);
GGML_ASSERT(n_dims % 2 == 0);
// TODO: ext_factor != 0
GGML_ASSERT(ext_factor == 0);

const float theta_scale = powf(freq_base, -2.0f / n_dims);

@@ -2484,7 +2538,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void *cos_tensor_buffer = cos_tensor_allocator.get();

// init ctx.rope_cos/rope_sin cache
aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer,
aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
theta_scale, freq_scale, attn_factor, is_neox);

int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
@@ -3220,7 +3274,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
const int64_t n_heads = src0->ne[2];
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
void* slope_buffer = slope_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias, GGML_TYPE_F16);

int64_t slope_ne[] = {1, 1, n_heads, 1};
size_t slope_nb[GGML_MAX_DIMS];
9 changes: 8 additions & 1 deletion ggml/src/ggml-cann/common.h
@@ -395,6 +395,7 @@ struct ggml_backend_cann_context {
#ifdef USE_ACL_GRAPH
/// Cached CANN ACL graph used for executing the current ggml computation graph.
std::unique_ptr<ggml_cann_graph> cann_graph;
bool acl_graph_mode = true;
#endif
cann_task_queue task_queue;
bool async_mode;
@@ -404,7 +405,6 @@ struct ggml_backend_cann_context {
ggml_cann_tensor_cache rms_norm_one_tensor_cache;
ggml_cann_tensor_cache rms_norm_zero_tensor_cache;


aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */

/**
@@ -419,6 +419,13 @@ struct ggml_backend_cann_context {
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
device, async_mode ? "ON" : "OFF");
#ifdef USE_ACL_GRAPH
acl_graph_mode = !(parse_bool(get_env("GGML_CANN_DISABLE_ACL_GRAPH").value_or("")));
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n",
__func__, device,
acl_graph_mode ? "GRAPH" : "EAGER",
acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
#endif
}

/**
20 changes: 4 additions & 16 deletions ggml/src/ggml-cann/ggml-cann.cpp
@@ -2252,6 +2252,10 @@ static enum ggml_status ggml_backend_cann_graph_compute(
bool use_cann_graph = true;
bool cann_graph_update_required = false;

if (!cann_ctx->acl_graph_mode) {
use_cann_graph = false;
}

if (use_cann_graph) {
if (cann_ctx->cann_graph == nullptr) {
cann_ctx->cann_graph.reset(new ggml_cann_graph());
@@ -2401,16 +2405,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
}
case GGML_OP_ROPE: {
// TODO: with ops-test v == 1
float ext_factor = 0.0f;
memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
// TODO: n_dims <= ne0
if (op->src[0]->ne[0] != op->op_params[1]) {
return false;
}
// TODO: ext_factor != 0
if (ext_factor != 0) {
return false;
}

const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
Expand All @@ -2420,9 +2418,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
return false;
}

if(!ggml_is_contiguous(op->src[0])){
return false;
}
return true;
}
case GGML_OP_UPSCALE: {
@@ -2523,13 +2518,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// different head sizes of K and V are not supported yet
return false;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
if (op->src[0]->ne[0] == 576) {
// DeepSeek MLA
return false;
}
if (op->src[0]->ne[0] % 16 != 0) {
// TODO: padding to support
return false;
13 changes: 9 additions & 4 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -2776,10 +2776,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_FLASH_ATTN_EXT:
{
if (op->src[4]) {
return false;
}

const ggml_tensor * q = op->src[0];
const ggml_tensor * k = op->src[1];
const ggml_tensor * v = op->src[2];
@@ -5765,13 +5761,17 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
const ggml_tensor * v = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
GGML_ASSERT(q->extra);
GGML_ASSERT(k->extra);
GGML_ASSERT(v->extra);
GGML_ASSERT(dst->extra);
if (mask) {
GGML_ASSERT(mask->extra);
}
if (sinks) {
GGML_ASSERT(sinks->extra);
}

ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

@@ -5813,13 +5813,16 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL;

cl_ulong offset_q = extra_q->offset + q->view_offs;
cl_ulong offset_k = extra_k->offset + k->view_offs;
cl_ulong offset_v = extra_v->offset + v->view_offs;
cl_ulong offset_o = extra_o->offset + dst->view_offs;
cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
cl_mem sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL;
cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0;

const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
@@ -5874,6 +5877,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem), &sinks_buffer));
CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks));

if (n_q == 1) {
const size_t wg_size = 64;