|
42 | 42 | #include <aclnnop/aclnn_exp.h> |
43 | 43 | #include <aclnnop/aclnn_fill_scalar.h> |
44 | 44 | #include <aclnnop/aclnn_fused_infer_attention_score_v2.h> |
| 45 | +#include <aclnnop/aclnn_ger.h> |
45 | 46 | #include <aclnnop/aclnn_group_norm.h> |
46 | 47 | #include <aclnnop/aclnn_grouped_matmul_v3.h> |
47 | 48 | #include <aclnnop/aclnn_gt_scalar.h> |
@@ -3236,3 +3237,64 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst |
3236 | 3237 | GGML_ABORT("Function is not implemented."); |
3237 | 3238 | } |
3238 | 3239 | } |
| 3240 | + |
| 3241 | +static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) { |
| 3242 | + ggml_tensor * src0 = dst->src[0]; // weight |
| 3243 | + ggml_tensor * src1 = dst->src[1]; // input |
| 3244 | + GGML_TENSOR_BINARY_OP_LOCALS |
| 3245 | + |
| 3246 | + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); |
| 3247 | + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); |
| 3248 | + |
| 3249 | + const int64_t dps2 = ne2 / ne02; |
| 3250 | + const int64_t dps3 = ne3 / ne03; |
| 3251 | + for (int64_t i3 = 0; i3 < ne3; i3++) { |
| 3252 | + for (int64_t i2 = 0; i2 < ne2; i2++) { |
| 3253 | + const int64_t i02 = i2 / dps2; |
| 3254 | + const int64_t i03 = i3 / dps3; |
| 3255 | + |
| 3256 | + const int64_t i12 = i2; |
| 3257 | + const int64_t i13 = i3; |
| 3258 | + acl_tensor_ptr accumulator = |
| 3259 | + ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type), |
| 3260 | + ggml_type_size(dst->type), dst->ne, dst->nb, 2); |
| 3261 | + |
| 3262 | + // The outer product needs to be accumulated in this dimension. |
| 3263 | + for (int64_t i1 = 0; i1 < ne11; i1++) { |
| 3264 | + acl_tensor_ptr acl_input = ggml_cann_create_tensor( |
| 3265 | + (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type), |
| 3266 | + ggml_type_size(src0->type), src1->ne, src1->nb, 1); |
| 3267 | + |
| 3268 | + acl_tensor_ptr acl_weight = ggml_cann_create_tensor( |
| 3269 | + (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type), |
| 3270 | + ggml_type_size(src0->type), src0->ne, src0->nb, 1); |
| 3271 | + |
| 3272 | + ggml_cann_pool_alloc output_allocator(ctx.pool()); |
| 3273 | + void * output_buffer = output_allocator.alloc(ggml_nbytes(dst)); |
| 3274 | + acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type), |
| 3275 | + ggml_type_size(dst->type), dst->ne, dst->nb, 2); |
| 3276 | + |
| 3277 | + GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get()); |
| 3278 | + float alpha_value = 1.0f; |
| 3279 | + aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT); |
| 3280 | + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha); |
| 3281 | + } |
| 3282 | + } |
| 3283 | + } |
| 3284 | +} |
| 3285 | + |
| 3286 | +void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) { |
| 3287 | + ggml_tensor * src0 = dst->src[0]; |
| 3288 | + |
| 3289 | + const enum ggml_type type = src0->type; |
| 3290 | + |
| 3291 | + switch (type) { |
| 3292 | + case GGML_TYPE_F32: |
| 3293 | + case GGML_TYPE_F16: |
| 3294 | + ggml_cann_out_prod_fp(ctx, dst); |
| 3295 | + break; |
| 3296 | + default: |
| 3297 | + GGML_ABORT("Unsupport type for GGML_OP_OUT_PROD"); |
| 3298 | + break; |
| 3299 | + } |
| 3300 | +} |
0 commit comments