108 changes: 78 additions & 30 deletions ggml/src/ggml-cann/aclnn_ops.cpp
@@ -516,17 +516,69 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 * @param dim An array of dimension indices.
 * @param dim_size The number of dimensions.
 */
-static void aclnn_reduce_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+static void aclnn_reduce_sum(ggml_backend_cann_context& ctx,
+                             ggml_tensor* dst,
                              int64_t* dim, size_t dim_size) {
     GGML_ASSERT(dst->ne[0] == 1);
     ggml_tensor* src = dst->src[0];
     aclTensor* acl_src = ggml_cann_create_tensor(src);
-    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
     aclIntArray* reduce_dims = aclCreateIntArray(dim, dim_size);

-    GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_src, reduce_dims, true,
-                            ggml_cann_type_mapping(dst->type), acl_dst);
-    ggml_cann_release_resources(ctx, acl_src, acl_dst, reduce_dims);
+    bool use_fp32_accum = (dst->type != GGML_TYPE_F32);
+
+    aclTensor* acl_dst = nullptr;
+    aclTensor* acl_tmp = nullptr;
+
+    if (!use_fp32_accum) {
+        // write result directly into dst (original path)
+        acl_dst = ggml_cann_create_tensor(dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum,
+                                acl_src, reduce_dims, true,
+                                ggml_cann_type_mapping(dst->type), acl_dst);
+        ggml_cann_release_resources(ctx, acl_src, acl_dst, reduce_dims);
+    } else {
+        // accumulate in FP32 first, then cast to dst type
+        size_t nelems = ggml_nelements(dst);
+        size_t tmp_bytes = nelems * sizeof(float);
+
+        ggml_cann_pool_alloc tmp_buf(ctx.pool(), tmp_bytes);
+        void* tmp_data = tmp_buf.get();
+
+        // build temporary FP32 tensor
+        int64_t tmp_ne[GGML_MAX_DIMS];
+        size_t tmp_nb[GGML_MAX_DIMS];
+
+        for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+            tmp_ne[i] = dst->ne[i];
+        }
+        tmp_nb[0] = sizeof(float);
+        for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+            tmp_nb[i] = tmp_nb[i - 1] * tmp_ne[i - 1];
+        }
+
+        acl_tmp = ggml_cann_create_tensor(tmp_data,
+                                          ACL_FLOAT,
+                                          sizeof(float),
+                                          tmp_ne, tmp_nb,
+                                          GGML_MAX_DIMS,
+                                          ACL_FORMAT_ND);
+
+        acl_dst = ggml_cann_create_tensor(dst);
+
+        // ReduceSum → FP32 temp
+        GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum,
+                                acl_src, reduce_dims, true,
+                                aclDataType::ACL_FLOAT,
+                                acl_tmp);
+
+        // cast FP32 → dst dtype
+        GGML_CANN_CALL_ACLNN_OP(ctx, Cast,
+                                acl_tmp,
+                                ggml_cann_type_mapping(dst->type),
+                                acl_dst);
+
+        ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_tmp, reduce_dims);
+    }
 }

 void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
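
Why accumulate in FP32: once a half-precision running sum grows large enough, its spacing (ulp) exceeds the addends and the sum stops advancing; this is the failure mode the new path avoids by reducing into an FP32 scratch buffer and casting once at the end. Below is a minimal host-side sketch of that pattern, not CANN code; round_like_fp16 is a hypothetical helper that only emulates fp16's 10-bit mantissa (it ignores fp16's narrower exponent range, which is fine for the values used here).

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Crude fp16 stand-in: truncate a float's mantissa to 10 bits.
static float round_like_fp16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFFE000u;               // clear the 13 low mantissa bits
    std::memcpy(&x, &bits, sizeof(x));
    return x;
}

int main() {
    const int n = 100000;
    std::vector<float> vals(n, 0.01f); // true sum is ~1000

    // naive path: keep the running sum in (emulated) fp16
    float sum_fp16 = 0.0f;
    for (float v : vals) {
        sum_fp16 = round_like_fp16(sum_fp16 + round_like_fp16(v));
    }

    // new path: accumulate in fp32, convert once at the end
    float acc = 0.0f;
    for (float v : vals) {
        acc += v;
    }
    float sum_fp32 = round_like_fp16(acc);

    std::printf("fp16 accumulation: %f\n", sum_fp16); // stalls at 16.0
    std::printf("fp32 accumulation: %f\n", sum_fp32); // close to 1000
    return 0;
}

The stall happens because at 16.0 the emulated fp16 spacing is 0.015625, larger than the 0.01 addend, so every further addition truncates back to 16.0.
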
@@ -956,38 +1008,34 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));

-    // build gamma, one...
-    size_t acl_gamma_nb[GGML_MAX_DIMS];
-    acl_gamma_nb[0] = sizeof(float);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
-    }
-    aclTensor* acl_gamma = get_f32_cache_acl_tensor(
-        ctx,
-        &ctx.f32_one_cache,
-        ctx.f32_one_cache_element,
-        src->ne,
-        acl_gamma_nb,
-        1,    // dims
-        1.0f  // value
-    );
-
-    // build rstd, zero...
+    // gamma: same dtype as dst, filled with 1.0
+    const size_t gamma_elem_size = ggml_type_size(dst->type);
+    const aclDataType gamma_acl_dtype = ggml_cann_type_mapping(dst->type);
+
+    int64_t gamma_ne[1] = { src->ne[0] };
+    size_t gamma_nb[1] = { gamma_elem_size };
+
+    ggml_cann_pool_alloc gamma_allocator(ctx.pool(), gamma_ne[0] * gamma_elem_size);
+    void* gamma_buffer = gamma_allocator.get();
+
+    aclTensor* acl_gamma = ggml_cann_create_tensor(
+        gamma_buffer, gamma_acl_dtype, gamma_elem_size,
+        gamma_ne, gamma_nb, 1);
+
+    aclnn_fill_scalar(ctx, 1.0f, acl_gamma);
+
+    // rstd: keep FP32 as in original implementation
     size_t acl_rstd_nb[GGML_MAX_DIMS];
     acl_rstd_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         acl_rstd_nb[i] = acl_rstd_nb[i - 1] * src->ne[i - 1];
     }

     aclTensor* acl_rstd = get_f32_cache_acl_tensor(
-        ctx,
-        &ctx.f32_zero_cache,
-        ctx.f32_zero_cache_element,
-        src->ne,
-        acl_rstd_nb,
-        GGML_MAX_DIMS,
-        0.0f  // value
-    );
+        ctx, &ctx.f32_zero_cache, ctx.f32_zero_cache_element,
+        src->ne, acl_rstd_nb, GGML_MAX_DIMS, 0.0f);

     // RMSNorm
     GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
     ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
 }
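
For context, with the all-ones gamma built above the RmsNorm launch reduces to ggml's plain RMS normalization, y = x / sqrt(mean(x^2) + eps) * gamma. A minimal host-side reference of that per-row computation, for illustration only (rms_norm_row is not part of the PR):

#include <cmath>
#include <cstdio>
#include <vector>

// y[i] = x[i] / sqrt(mean(x^2) + eps) * gamma[i], over one row of n elements
static void rms_norm_row(const float* x, const float* gamma,
                         float* y, int n, float eps) {
    float sum_sq = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum_sq += x[i] * x[i];
    }
    const float rstd = 1.0f / std::sqrt(sum_sq / (float) n + eps);
    for (int i = 0; i < n; ++i) {
        y[i] = x[i] * rstd * gamma[i];
    }
}

int main() {
    const float eps = 1e-6f;
    std::vector<float> x = { 1.0f, 2.0f, 3.0f, 4.0f };
    std::vector<float> gamma(x.size(), 1.0f); // all ones, as in the kernel call
    std::vector<float> y(x.size());
    rms_norm_row(x.data(), gamma.data(), y.data(), (int) x.size(), eps);
    for (float v : y) {
        std::printf("%f ", v);                // ~0.365 0.730 1.095 1.461
    }
    std::printf("\n");
    return 0;
}
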