
Commit 97d5117

CANN: Add cross_entropy_loss op support (ggml-org#16886)
* update L2_NORM op support
* update L2_NORM op support
* remove extra whitespace
* cann: update cross_entropy_loss op support
* remove trailing whitespaces
* rebase the latest code in the main repository and remove the l2_norm operator that already exists in another pull request.
* undo the l2_norm operator deletion
1 parent a90eb94 commit 97d5117

File tree

3 files changed, +128 -0 lines changed


ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 86 additions & 0 deletions
@@ -477,6 +477,92 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_cann_release_resources(ctx, dims_array, p_scalar, acl_src, acl_dst, acl_div);
 }
 
+void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * src0 = dst->src[0];
+    ggml_tensor * src1 = dst->src[1];
+
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
+
+    int64_t logits_ne[] = {nc, nr};
+    size_t logits_nb[2];
+    logits_nb[0] = ggml_type_size(src0->type);
+    logits_nb[1] = logits_nb[0] * logits_ne[0];
+    aclTensor * acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
+
+    size_t log_softmax_type_size = sizeof(float);
+    int64_t log_softmax_n_bytes = nr * nc * log_softmax_type_size;
+    ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes);
+    void * log_softmax_buffer = log_softmax_allocator.get();
+
+    int64_t log_softmax_ne[] = {nc, nr};
+    size_t log_softmax_nb[2];
+    log_softmax_nb[0] = log_softmax_type_size;
+    log_softmax_nb[1] = log_softmax_nb[0] * log_softmax_ne[0];
+    aclTensor * acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size, log_softmax_ne, log_softmax_nb, 2);
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits, 1, acl_log_softmax);
+
+    int64_t labels_ne[] = {nc, nr};
+    size_t labels_nb[2];
+    labels_nb[0] = ggml_type_size(src1->type);
+    labels_nb[1] = labels_nb[0] * labels_ne[0];
+    aclTensor * acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2);
+
+    size_t mul_type_size = sizeof(float);
+    int64_t mul_n_bytes = nr * nc * mul_type_size;
+    ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes);
+    void * mul_buffer = mul_allocator.get();
+
+    int64_t mul_ne[] = {nc, nr};
+    size_t mul_nb[2];
+    mul_nb[0] = mul_type_size;
+    mul_nb[1] = mul_nb[0] * mul_ne[0];
+    aclTensor * acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2);
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax, acl_labels, acl_mul_result);
+
+    size_t sum_per_sample_type_size = sizeof(float);
+    int64_t sum_per_sample_n_bytes = nr * sum_per_sample_type_size;
+    ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes);
+    void * sum_per_sample_buffer = sum_per_sample_allocator.get();
+
+    int64_t sum_per_sample_ne[] = {nr};
+    size_t sum_per_sample_nb[1];
+    sum_per_sample_nb[0] = sum_per_sample_type_size;
+    aclTensor * acl_sum_per_sample = ggml_cann_create_tensor(sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1);
+
+    std::vector<int64_t> sum_dims = {1};
+    aclIntArray * dims_array = aclCreateIntArray(sum_dims.data(), sum_dims.size());
+    bool keep_dims = false;
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result, dims_array, keep_dims, ACL_FLOAT, acl_sum_per_sample);
+
+    size_t total_sum_type_size = sizeof(float);
+    int64_t total_sum_n_bytes = 1 * total_sum_type_size;
+    ggml_cann_pool_alloc total_sum_allocator(ctx.pool(), total_sum_n_bytes);
+    void * total_sum_buffer = total_sum_allocator.get();
+
+    int64_t total_sum_ne[] = {1};
+    size_t total_sum_nb[1];
+    total_sum_nb[0] = total_sum_type_size;
+
+    aclTensor * acl_total_sum = ggml_cann_create_tensor(total_sum_buffer, ACL_FLOAT, total_sum_type_size, total_sum_ne, total_sum_nb, 1);
+
+    std::vector<int64_t> total_sum_dims = {0};
+    aclIntArray * total_sum_dims_array = aclCreateIntArray(total_sum_dims.data(), total_sum_dims.size());
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample, total_sum_dims_array, keep_dims, ACL_FLOAT, acl_total_sum);
+
+    float value = -1.0f / static_cast<float>(nr);
+    aclScalar * scale_factor = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
+    aclTensor * acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1);
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_total_sum, scale_factor, acl_dst);
+
+    ggml_cann_release_resources(ctx, acl_logits, acl_log_softmax, acl_labels, acl_mul_result, acl_sum_per_sample, acl_total_sum, acl_dst, scale_factor, dims_array, total_sum_dims_array);
+}
+
 void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_tensor * src = dst->src[0];
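
For orientation only, and not part of this commit: the ACL call chain above (LogSoftmax, Mul, row-wise ReduceSum, global ReduceSum, then Muls by -1/nr) computes the same value as the plain C++ sketch below. The helper name and the flat row-major float layout are assumptions made for the illustration.

    // Reference sketch (illustrative): loss = -1/nr * sum_i sum_j y[i][j] * log_softmax(x[i])[j]
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    static float cross_entropy_loss_ref(const std::vector<float> & logits,
                                         const std::vector<float> & labels,
                                         int64_t nr, int64_t nc) {
        float total = 0.0f;
        for (int64_t i = 0; i < nr; ++i) {
            const float * x = logits.data() + i * nc;
            const float * y = labels.data() + i * nc;
            // log_softmax with max-subtraction for numerical stability
            float max_val = *std::max_element(x, x + nc);
            float sum_exp = 0.0f;
            for (int64_t j = 0; j < nc; ++j) {
                sum_exp += std::exp(x[j] - max_val);
            }
            const float log_sum = std::log(sum_exp) + max_val;
            // per-row sum of labels * log_softmax(logits), matching the Mul + first ReduceSum
            float row = 0.0f;
            for (int64_t j = 0; j < nc; ++j) {
                row += y[j] * (x[j] - log_sum);
            }
            total += row;  // second ReduceSum over rows
        }
        return -total / static_cast<float>(nr);  // final Muls by -1/nr
    }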

ggml/src/ggml-cann/aclnn_ops.h

Lines changed: 38 additions & 0 deletions
@@ -47,6 +47,7 @@
 #include <aclnnop/aclnn_log.h>
 #include <aclnnop/aclnn_sign.h>
 #include <aclnnop/aclnn_norm.h>
+#include <aclnnop/aclnn_logsoftmax.h>
 #include "acl_tensor.h"
 #include "common.h"

@@ -211,6 +212,43 @@ void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
  */
 void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
+/**
+ * @brief Computes the Cross Entropy Loss for a ggml tensor using the CANN
+ * backend.
+ *
+ * @details This function computes the cross entropy loss between the predicted
+ *          logits and target probability distributions. The operation follows
+ *          the same computation pattern as the CPU implementation:
+ *          1. Applies log_softmax to the logits along the class dimension
+ *          2. Element-wise multiplication with target distributions
+ *          3. Summation along the class dimension to get per-sample losses
+ *          4. Global summation and scaling by -1/nr to get final loss
+ *
+ *          The computation can be expressed as:
+ *          \f[
+ *              \text{loss} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{C} y_{ij} \cdot \log(\text{softmax}(x_{ij}))
+ *          \f]
+ *          where \f$N\f$ is the total number of samples, \f$C\f$ is the number
+ *          of classes, \f$x\f$ are the logits, and \f$y\f$ are the target
+ *          probability distributions.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the computed loss will be stored.
+ *            This should be a scalar tensor containing the final loss value.
+ *
+ * @note This implementation computes cross entropy between probability
+ *       distributions, not the typical classification cross entropy that
+ *       expects class indices as targets. Both input tensors (src0 and src1)
+ *       should have the same shape and represent probability distributions
+ *       over the class dimension.
+ * @note The function expects two source tensors:
+ *       - dst->src[0]: Logits tensor (before softmax)
+ *       - dst->src[1]: Target probability distributions tensor
+ * @note The computation is performed using CANN backend operators including
+ *       LogSoftmax, Mul, ReduceSum, and Muls for the final scaling.
+ */
+void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst);
+
 /**
  * @brief Computes the Group Normalization for a ggml tensor using the CANN
  * backend.
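
A usage sketch, not part of this change: at the graph level the kernel is reached through ggml's existing ggml_cross_entropy_loss() API, and the resulting GGML_OP_CROSS_ENTROPY_LOSS node is dispatched to the function declared above when the graph runs on the CANN backend. The tensor shapes and pool size below are arbitrary illustration values.

    // Illustrative only: builds a GGML_OP_CROSS_ENTROPY_LOSS node.
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // src0: logits (before softmax), src1: target probability distributions
        struct ggml_tensor * logits = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 4);
        struct ggml_tensor * labels = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 4);

        // 1-element result tensor holding the scaled negative log-likelihood
        struct ggml_tensor * loss = ggml_cross_entropy_loss(ctx, logits, labels);
        (void) loss;

        ggml_free(ctx);
        return 0;
    }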

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 4 additions & 0 deletions
@@ -1780,6 +1780,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
         case GGML_OP_L2_NORM:
             ggml_cann_l2_norm(ctx, dst);
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            ggml_cann_cross_entropy_loss(ctx, dst);
+            break;
         case GGML_OP_CONCAT:
             ggml_cann_concat(ctx, dst);
             break;
@@ -2519,6 +2522,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
             return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
         }
         case GGML_OP_L2_NORM:
+        case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_DUP:
         case GGML_OP_SUM:
         case GGML_OP_IM2COL:
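
With the dispatch case and the supports_op entry in place, the new path can be checked against the CPU reference with ggml's test-backend-ops harness on a CANN device, for example along the lines of:

    ./build/bin/test-backend-ops test -o CROSS_ENTROPY_LOSS

The binary path and flag spelling depend on the local build and the harness's usage text, so treat this invocation as illustrative rather than exact.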
