Commit b9382c3

CANN: Optimize MUL_MAT_ID (ggml-org#15658)
1 parent 3dc7397 commit b9382c3

File tree

1 file changed: +30 −155 lines

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 30 additions & 155 deletions
@@ -2867,174 +2867,49 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
  */
 static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     //dst [M, K, N, 1]
-    ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
-    ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
+    ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] -> [D, M, K, 1]
+    ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 -> [D, 1, K, 1]
     ggml_tensor * ids = dst->src[2]; //ids [K, N]
 
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    // copy index from npu to cpu
-    int64_t n_as = ne02; // A
-    int64_t n_ids = ids->ne[0]; // K
-
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
-                           ACL_MEMCPY_DEVICE_TO_HOST);
-    ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src1->ne[3] == 1);
+    GGML_ASSERT(dst->ne[3] == 1);
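
For reference, the operation being optimized: for each token n and each of its K routing slots, MUL_MAT_ID multiplies the activation row by the expert matrix that ids selects. A minimal CPU sketch of those semantics, assuming F32 and contiguous tensors in ggml dimension order (index 0 is innermost); mul_mat_id_ref is a hypothetical name, not a ggml function:

    #include <cassert>
    #include <cstdint>

    // src0: [D, M, A]  A expert weight matrices
    // src1: [D, B, N]  activation rows, B = K or B = 1 (broadcast)
    // ids : [K, N]     expert index per (slot, token)
    // dst : [M, K, N]
    static void mul_mat_id_ref(const float* src0, const float* src1,
                               const int32_t* ids, float* dst,
                               int64_t D, int64_t M, int64_t A,
                               int64_t B, int64_t N, int64_t K) {
        for (int64_t n = 0; n < N; n++) {
            for (int64_t k = 0; k < K; k++) {
                int32_t e = ids[n * K + k];        // expert chosen for slot k
                assert(e >= 0 && e < A);
                const float* w = src0 + e * (D * M);                    // [D, M]
                const float* x = src1 + (n * B + (B == 1 ? 0 : k)) * D; // [D]
                float* y = dst + (n * K + k) * M;                       // [M]
                for (int64_t m = 0; m < M; m++) {
                    float acc = 0.0f;
                    for (int64_t d = 0; d < D; d++) {
                        // ggml convention: dst row = src0 column dot src1 row
                        acc += w[m * D + d] * x[d];
                    }
                    y[m] = acc;
                }
            }
        }
    }
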
 
-    char * src0_original = (char *) src0->data;
-    char * src1_original = (char *) src1->data;
-    char * dst_original = (char *) dst->data;
-    size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03};
+    int64_t batch = src1->ne[2];
+    GGML_ASSERT(batch == ids->ne[1]);
 
-    // src0 is F16, src1 is F32, dst is F32
-    ggml_cann_pool_alloc src0_cast_allocator;
-    if (src0->type == GGML_TYPE_F16) {
-        src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0));
-        void* src0_cast_buf = src0_cast_allocator.get();
+    ggml_cann_pool_alloc export_allocator(ctx.pool(), src0->ne[0] * src0->ne[1] * ids->ne[0] * ggml_element_size(src0));
+    void* export_ptr = export_allocator.get();
+    for (int64_t i = 0; i < batch; i++) {
+        aclTensor *select_index = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, i * ids->nb[1]);
+        aclTensor *export_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3);
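
Two things to note in the added setup: the scratch buffer is sized for the K gathered [D, M] expert matrices (src0->ne[0] * src0->ne[1] * ids->ne[0] elements), and it is taken from the pool once, outside the batch loop, then rewritten on every iteration. A host-memory sketch of that reuse pattern (illustrative dimensions, plain std::vector standing in for the CANN pool):

    #include <cstdint>
    #include <vector>

    int main() {
        const int64_t D = 64, M = 128, K = 4, batch = 8;
        // One buffer big enough for K gathered [D, M] expert matrices,
        // allocated once and reused; mirrors export_allocator above.
        std::vector<float> scratch(static_cast<size_t>(D) * M * K);
        for (int64_t i = 0; i < batch; i++) {
            // each iteration overwrites scratch with the experts selected
            // for batch row i -- no per-iteration allocation
            (void)scratch;
        }
        return 0;
    }
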
 
-        size_t cast_nb[GGML_MAX_DIMS];
-        cast_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1];
+        int64_t select_export_ne[] = {src0->ne[0], src0->ne[1], ids->ne[0]};
+        size_t select_export_nb[3];
+        select_export_nb[0] = src0->nb[0];
+        for (int k = 1;k < 3; k++) {
+            select_export_nb[k] = select_export_nb[k-1] * select_export_ne[k-1];
         }
 
-        aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0);
-        aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf,
-            ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4);
-        GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast);
-        ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16);
+        aclTensor *select_export = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_export_ne, select_export_nb, 3);
+        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, export_weight, 0, select_index, select_export);
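
The stride loop above gives select_export a dense [D, M, K] layout, and IndexSelect then gathers the K selected expert matrices out of src0 in a single device call (dim 0 here appears to address the expert dimension in the ACL-side, outermost-first ordering). A host-side analogue of that gather, assuming contiguous F32 data; the helper name is illustrative:

    #include <cstdint>
    #include <cstring>

    // Copy the K selected [D, M] expert matrices into a dense [D, M, K]
    // scratch buffer -- what one IndexSelect call does on the device.
    static void gather_experts_ref(const float* src0, const int32_t* ids_row,
                                   float* scratch,
                                   int64_t D, int64_t M, int64_t K) {
        for (int64_t k = 0; k < K; k++) {
            std::memcpy(scratch + k * D * M,        // slot k of the scratch
                        src0 + ids_row[k] * D * M,  // selected expert matrix
                        sizeof(float) * D * M);
        }
    }
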
 
-        src0_original = (char *) src0_cast_buf;
-        memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb));
-    }
+        int64_t select_transpose_ne[] = {select_export_ne[1], select_export_ne[0], select_export_ne[2]};
+        size_t select_transpose_nb[] = {select_export_nb[1], select_export_nb[0], select_export_nb[2]};
+        aclTensor *select_export_transpose = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_transpose_ne, select_transpose_nb, 3);
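
select_export_transpose reuses export_ptr: swapping the first two entries of ne and nb reinterprets the same [D, M, K] buffer as [M, D, K] without moving any data. A minimal strided-view sketch of why that works (element strides for brevity; the real code uses byte strides):

    #include <cstdint>
    #include <cstdio>

    // A strided view: swapping sizes and strides transposes without copying.
    struct View2D {
        const float* data;
        int64_t ne[2];   // sizes (ggml order: ne[0] innermost)
        int64_t nb[2];   // strides, in elements here
        float at(int64_t i0, int64_t i1) const {
            return data[i0 * nb[0] + i1 * nb[1]];
        }
    };

    int main() {
        float buf[6] = {1, 2, 3, 4, 5, 6};       // ne = {3, 2}, contiguous
        View2D a  = {buf, {3, 2}, {1, 3}};       // original view
        View2D at = {buf, {2, 3}, {3, 1}};       // swapped ne and nb
        std::printf("%g %g\n", a.at(1, 0), at.at(0, 1));  // both print 2
        return 0;
    }
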
 
-#ifdef ASCEND_310P
-    ggml_tensor src0_row = *src0;
-    ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;
+        int64_t active_tensor_ne[] = {src1->ne[0], 1, src1->ne[1]};
+        size_t active_tensor_nb[] = {src1->nb[0], src1->nb[1], src1->nb[1]};
+        aclTensor *active_tensor = ggml_cann_create_tensor(src1, active_tensor_ne, active_tensor_nb, 3, ACL_FORMAT_ND, i * src1->nb[2]);
 
-    if (src0->type == GGML_TYPE_F16) {
-        src0_row.type = GGML_TYPE_F32;
-    }
+        int64_t dst_ne[] = {dst->ne[0], 1, dst->ne[1]};
+        size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[1]};
+        aclTensor *acl_dst = ggml_cann_create_tensor(dst, dst_ne,dst_nb, 3, ACL_FORMAT_ND, i * dst->nb[2]);
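
active_tensor and acl_dst are 3-D views over 2-D slices of src1 and dst: a size-1 middle dimension is inserted so both operands present batched shapes to the matmul. Because that dimension has extent 1 it is only ever indexed at 0, so its stride is a don't-care and the code simply repeats nb[1]. A sketch of that property (element strides, illustrative sizes):

    #include <cstdint>
    #include <cstdio>

    // Element offset in a 3-D strided view (ggml order, strides in elements).
    static size_t offset3d(const int64_t i[3], const size_t nb[3]) {
        return i[0] * nb[0] + i[1] * nb[1] + i[2] * nb[2];
    }

    int main() {
        // View a [D, K] slice as [D, 1, K]: the size-1 middle dim is only
        // indexed with i[1] == 0, so its stride value never matters and may
        // repeat nb[1] -- exactly what active_tensor_nb/dst_nb do above.
        const int64_t D = 4;
        size_t nb[3] = {1, (size_t)D, (size_t)D};   // {nb0, nb1, nb1}
        int64_t idx[3] = {2, 0, 1};                 // (d=2, middle=0, k=1)
        std::printf("offset = %zu\n", offset3d(idx, nb)); // 2 + 0 + 4 = 6
        return 0;
    }
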
 
-    // src0_row [D, M, 1, 1] weight without permute
-    src0_row.ne[2] = 1;
-    src0_row.ne[3] = 1;
-    src0_row.nb[0] = ori_src0_nb[0];
-    src0_row.nb[1] = ori_src0_nb[1];
-    src0_row.nb[2] = ori_src0_nb[1];
-    src0_row.nb[3] = ori_src0_nb[1];
-
-    // src1_row [D, 1, 1, 1] -> input
-    src1_row.ne[1] = 1;
-    src1_row.ne[2] = 1;
-    src1_row.ne[3] = 1;
-    src1_row.nb[2] = nb11;
-    src1_row.nb[3] = nb11;
+        GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, active_tensor, select_export_transpose, acl_dst, 2);
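
With the views in place, one BatchMatMul finishes batch row i: in the ACL-side (outermost-first) ordering the shapes are roughly [K, 1, D] x [K, D, M] -> [K, 1, M], i.e. one [1, D] x [D, M] mat-vec per selected expert (the trailing 2 is the aclnn cubeMathType argument). Each batch row therefore costs one IndexSelect plus one BatchMatMul, where the removed paths below built a tensor view per (token, slot) pair. A CPU reference for the per-row computation (F32, contiguous, self-consistent illustrative layout):

    #include <cstdint>

    // For each of the K selected experts, multiply the [D] activation row
    // by that expert's [D, M] weight, producing a [M] output row.
    static void batch_matvec_ref(const float* x,    // [D] activation row
                                 const float* w_t,  // [K, D, M] gathered weights
                                 float* y,          // [K, M] output rows
                                 int64_t D, int64_t M, int64_t K) {
        for (int64_t k = 0; k < K; k++) {
            const float* w = w_t + k * D * M;       // k-th expert, [D, M]
            for (int64_t m = 0; m < M; m++) {
                float acc = 0.0f;
                for (int64_t d = 0; d < D; d++) {
                    acc += x[d] * w[d * M + m];
                }
                y[k * M + m] = acc;
            }
        }
    }
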
 
-    // dst_row [M, 1, 1, 1] -> out
-    dst_row.ne[1] = 1;
-    dst_row.ne[2] = 1;
-    dst_row.ne[3] = 1;
-    dst_row.nb[2] = nb1;
-    dst_row.nb[3] = nb1;
-
-    //create weight for one row
-    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-        for (int64_t id = 0; id < n_ids; id++) {
-            // expert index
-            int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-            GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
-            // If B = 1 (broadcast), always use 0; otherwise, use id.
-            int64_t i11 = (ne11 == 1 ? 0 : id);
-            int64_t i12 = iid1;
-
-            int64_t i1 = id;
-            int64_t i2 = i12;
-
-            void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
-            void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
-            void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
-
-            src0_row.data = src0_tmp_ptr;
-            src1_row.data = src1_tmp_ptr;
-            dst_row.data = dst_tmp_ptr;
-            dst_row.src[0] = &src0_row;
-            dst_row.src[1] = &src1_row;
-
-            ggml_cann_mul_mat(ctx, &dst_row);
-        }
-    }
-    return;
-#endif
-
-    std::vector<aclTensor*> src0_tensor_vec;
-    std::vector<aclTensor*> src1_tensor_vec;
-    std::vector<aclTensor*> dst_tensor_vec;
-    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-        for (int64_t id = 0; id < n_ids; id++) {
-            // src0_row [M, D] -> weight && permute
-            int64_t src0_ne[2] = {ne01, ne00};
-            size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]};
-            // src1_row [D, 1] -> input
-            int64_t src1_ne[2] = {ne10, 1};
-            size_t src1_nb[2] = {nb10, nb11};
-            // dst_row [M, 1] -> out
-            int64_t dst_ne[2] = {ne0, 1};
-            size_t dst_nb[2] = {nb0, nb1};
-
-            // expert index
-            int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-            GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
-            // If B = 1 (broadcast), always use 0; otherwise, use id.
-            int64_t i11 = (ne11 == 1 ? 0 : id);
-            int64_t i12 = iid1;
-
-            int64_t i1 = id;
-            int64_t i2 = i12;
-
-            void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
-            void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
-            void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
-
-            aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr,
-                ACL_FLOAT, sizeof(float),
-                src0_ne, src0_nb, 2);
-            aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr,
-                ACL_FLOAT, sizeof(float),
-                src1_ne, src1_nb, 2);
-            aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr,
-                ACL_FLOAT, sizeof(float),
-                dst_ne, dst_nb, 2);
-
-            src0_tensor_vec.push_back(acl_src0);
-            src1_tensor_vec.push_back(acl_src1);
-            dst_tensor_vec.push_back(acl_dst);
-        }
+        ggml_cann_release_resources(ctx, select_index, export_weight, select_export, active_tensor, acl_dst, select_export_transpose);
     }
-
-    size_t GROUP_SIZE = 128;
-    // GroupedMatmulV3 required tensor_list.size < 128
-    for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
-        // split and call GroupedMatmulV3
-        size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
-        std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
-        std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
-        std::vector<aclTensor*> dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end);
-
-        aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size());
-        aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size());
-        aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size());
-
-        GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV3, src1_tensor_list, src0_tensor_list,
-            nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list);
-
-        ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);
-    }
-    return;
 }
 
 /**
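
One pattern worth noting from the removed fallback: its comment says GroupedMatmulV3 required tensor_list.size < 128, so the per-row tensors were processed in fixed-size windows. The same bounded-batching idiom in isolation (a generic sketch; process_group stands in for the grouped call):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Process a list in windows of at most GROUP_SIZE elements, as the
    // removed GroupedMatmulV3 fallback did with its aclTensor lists.
    static void process_group(const std::vector<int>& group) {
        std::printf("group of %zu\n", group.size());
    }

    int main() {
        std::vector<int> items(300);     // e.g. one entry per (token, slot) pair
        const size_t GROUP_SIZE = 128;   // API limit cited in the removed code
        for (size_t i = 0; i < items.size(); i += GROUP_SIZE) {
            size_t end = std::min(i + GROUP_SIZE, items.size());
            std::vector<int> group(items.begin() + i, items.begin() + end);
            process_group(group);        // prints 128, 128, 44
        }
        return 0;
    }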
