Skip to content

Commit 064c90d

Browse files
authored
CANN: supports out_prod operator for F32 and F16 (#17406)
Co-authored-by: tianhao <[email protected]>
1 parent b1846f1 commit 064c90d

File tree

3 files changed

+95
-0
lines changed

3 files changed

+95
-0
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include <aclnnop/aclnn_exp.h>
4343
#include <aclnnop/aclnn_fill_scalar.h>
4444
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
45+
#include <aclnnop/aclnn_ger.h>
4546
#include <aclnnop/aclnn_group_norm.h>
4647
#include <aclnnop/aclnn_grouped_matmul_v3.h>
4748
#include <aclnnop/aclnn_gt_scalar.h>
@@ -3236,3 +3237,64 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
32363237
GGML_ABORT("Function is not implemented.");
32373238
}
32383239
}
3240+
3241+
/**
 * @brief Computes GGML_OP_OUT_PROD for floating-point tensors (F32/F16).
 *
 * dst[i,j,k,l] = sum_m src0[i,m,k,l] * src1[j,m,k,l]
 *
 * The destination is zeroed, then built by accumulating one rank-1 outer
 * product (aclnn Ger) per element of the shared inner dimension (ne11) for
 * every (i2, i3) plane. src0 may be broadcast along dims 2/3, i.e.
 * ne2 % ne02 == 0 and ne3 % ne03 == 0 (dps2/dps3 are the repeat factors).
 *
 * @param ctx CANN backend context (stream, memory pool).
 * @param dst Destination tensor; reads dst->src[0] (weight) and
 *            dst->src[1] (input), writes dst->data.
 */
static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    ggml_tensor * src0 = dst->src[0];  // weight
    ggml_tensor * src1 = dst->src[1];  // input
    GGML_TENSOR_BINARY_OP_LOCALS

    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());

    // Broadcast ratios for the batch dimensions of src0.
    const int64_t dps2 = ne2 / ne02;
    const int64_t dps3 = ne3 / ne03;

    // Loop-invariant resources, hoisted out of the loops: Ger fully
    // overwrites the scratch plane each call, so one buffer suffices, and a
    // single alpha scalar serves every accumulation. (Previously an aclScalar
    // was created per inner iteration and never destroyed — a leak.)
    ggml_cann_pool_alloc output_allocator(ctx.pool());
    void *      output_buffer = output_allocator.alloc(ggml_nbytes(dst));
    float       alpha_value   = 1.0f;
    aclScalar * alpha         = aclCreateScalar(&alpha_value, ACL_FLOAT);

    for (int64_t i3 = 0; i3 < ne3; i3++) {
        for (int64_t i2 = 0; i2 < ne2; i2++) {
            const int64_t i02 = i2 / dps2;
            const int64_t i03 = i3 / dps3;

            const int64_t i12 = i2;
            const int64_t i13 = i3;
            // 2-D view of the destination plane that accumulates the result.
            acl_tensor_ptr accumulator =
                ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type),
                                        ggml_type_size(dst->type), dst->ne, dst->nb, 2);

            // The outer product needs to be accumulated in this dimension.
            for (int64_t i1 = 0; i1 < ne11; i1++) {
                // 1-D column views of input (length ne10) and weight (length ne00).
                acl_tensor_ptr acl_input = ggml_cann_create_tensor(
                    (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type),
                    ggml_type_size(src0->type), src1->ne, src1->nb, 1);

                acl_tensor_ptr acl_weight = ggml_cann_create_tensor(
                    (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type),
                    ggml_type_size(src0->type), src0->ne, src0->nb, 1);

                acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type),
                                                                 ggml_type_size(dst->type), dst->ne, dst->nb, 2);

                // acl_out = input ⊗ weight; then accumulator += alpha * acl_out.
                GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get());
                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha);
            }
        }
    }
    // NOTE(review): aclnn ops consume the scalar during the synchronous
    // workspace-size phase, so destroying it after the loops should be safe —
    // confirm against this backend's resource-release conventions.
    aclDestroyScalar(alpha);
}
3285+
3286+
/**
 * @brief Dispatch GGML_OP_OUT_PROD to a type-specific implementation.
 *
 * Selects the implementation based on the weight tensor type
 * (dst->src[0]->type). Currently F32 and F16 are handled by the
 * floating-point path; any other type aborts.
 *
 * @param ctx The CANN backend context for operation execution.
 * @param dst Destination tensor carrying the operands in dst->src[0..1].
 */
void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    ggml_tensor * src0 = dst->src[0];

    const enum ggml_type type = src0->type;

    switch (type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
            ggml_cann_out_prod_fp(ctx, dst);
            break;
        default:
            // Fixed typo: "Unsupport" -> "Unsupported".
            GGML_ABORT("Unsupported type for GGML_OP_OUT_PROD");
            break;
    }
}

ggml/src/ggml-cann/aclnn_ops.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,3 +1125,23 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
11251125
} while (0)
11261126

11271127
#endif // CANN_ACLNN_OPS
1128+
1129+
/**
1130+
* @brief Performs outer product operation on two ggml tensors using the CANN backend.
1131+
*
1132+
* @details This function computes the outer product of two input tensors (src0 and src1)
1133+
* and stores the result in the destination tensor. The outer product operation is defined as:
1134+
* dst[i,j,k,l] = sum_m (src0[i,m,k,l] * src1[j,m,k,l])
1135+
*
1136+
* The function supports the F32 and F16 data types. For these floating-point
* types, the result is computed by accumulating rank-1 outer products (aclnn
* Ger) over the shared inner dimension.
1138+
*
1139+
* The implementation handles 4D tensor broadcasting and batch processing automatically.
1140+
*
1141+
* @param ctx The CANN backend context for operation execution and memory management.
1142+
* @param dst The destination ggml_tensor where the outer product result will be stored.
1143+
* The input tensors are assumed to be `dst->src[0]` and `dst->src[1]`.
1144+
*
1145+
* @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation
1146+
*/
1147+
void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1886,6 +1886,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
18861886
case GGML_OP_FLASH_ATTN_EXT:
18871887
ggml_cann_flash_attn_ext(ctx, dst);
18881888
break;
1889+
case GGML_OP_OUT_PROD:
1890+
ggml_cann_out_prod(ctx, dst);
1891+
break;
18891892
default:
18901893
return false;
18911894
}
@@ -2563,6 +2566,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
25632566
case GGML_OP_PAD_REFLECT_1D:
25642567
case GGML_OP_COUNT_EQUAL:
25652568
return true;
2569+
case GGML_OP_OUT_PROD:
2570+
{
2571+
switch (op->src[0]->type) {
2572+
case GGML_TYPE_F16:
2573+
case GGML_TYPE_F32:
2574+
return true;
2575+
default:
2576+
return false;
2577+
}
2578+
}
25662579
case GGML_OP_CONV_TRANSPOSE_1D:
25672580
// TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
25682581
return (op->src[0]->ne[0] - 1) <= 255;

0 commit comments

Comments
 (0)