@@ -908,14 +908,6 @@ static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
908908void ggml_cann_dup (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
909909 ggml_tensor* src0 = dst->src [0 ];
910910
911- if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) {
912- GGML_ABORT (" Only support src type in [F32, F16]" );
913- }
914-
915- if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
916- GGML_ABORT (" Only support dst type in [F32, F16]" );
917- }
918-
919911 aclTensor* acl_src = ggml_cann_create_tensor (src0);
920912 aclTensor* acl_dst = ggml_cann_create_tensor (dst);
921913 if (ggml_are_same_shape (src0, dst)) {
@@ -2388,6 +2380,68 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
23882380 ACL_CHECK (aclDestroyTensor (src_trans_tensor));
23892381 break ;
23902382 }
2383+ case GGML_TYPE_Q8_0: {
2384+ // add 1 dim for bcast mul.
2385+ size_t weight_nb[GGML_MAX_DIMS + 1 ], scale_nb[GGML_MAX_DIMS + 1 ],
2386+ dequant_nb[GGML_MAX_DIMS + 1 ];
2387+ int64_t weight_ne[GGML_MAX_DIMS + 1 ], scale_ne[GGML_MAX_DIMS + 1 ],
2388+ *dequant_ne;
2389+ int64_t scale_offset = 0 ;
2390+
2391+ // [3,4,5,64] -> [3,4,5,2,32]
2392+ weight_ne[0 ] = QK8_0;
2393+ weight_ne[1 ] = src0->ne [0 ] / QK8_0;
2394+ weight_nb[0 ] = sizeof (int8_t );
2395+ weight_nb[1 ] = weight_nb[0 ] * weight_ne[0 ];
2396+ for (int i = 2 ; i < GGML_MAX_DIMS + 1 ; i++) {
2397+ weight_ne[i] = src0->ne [i - 1 ];
2398+ weight_nb[i] = weight_nb[i - 1 ] * weight_ne[i - 1 ];
2399+ }
2400+
2401+ // [3,4,5,64] -> [3,4,5,2,1]
2402+ scale_ne[0 ] = 1 ;
2403+ scale_ne[1 ] = src0->ne [0 ] / QK8_0;
2404+ scale_nb[0 ] = sizeof (uint16_t );
2405+ scale_nb[1 ] = scale_nb[0 ] * scale_ne[0 ];
2406+ for (int i = 2 ; i < GGML_MAX_DIMS + 1 ; i++) {
2407+ scale_ne[i] = src0->ne [i - 1 ];
2408+ scale_nb[i] = scale_nb[i - 1 ] * scale_ne[i - 1 ];
2409+ }
2410+
2411+ // [3,4,5,64] -> [3,4,5,2,32]
2412+ dequant_ne = weight_ne;
2413+ dequant_nb[0 ] = sizeof (float_t );
2414+ for (int i = 1 ; i < GGML_MAX_DIMS + 1 ; i++) {
2415+ dequant_nb[i] = dequant_nb[i - 1 ] * dequant_ne[i - 1 ];
2416+ }
2417+
2418+ scale_offset = ggml_nelements (src0) * sizeof (int8_t );
2419+ ggml_cann_pool_alloc dequant_buffer_allocator (
2420+ ctx.pool (), ggml_nelements (src0) * sizeof (float_t ));
2421+
2422+ aclTensor* acl_weight_tensor = ggml_cann_create_tensor (
2423+ src0->data , ACL_INT8, sizeof (int8_t ), weight_ne, weight_nb,
2424+ GGML_MAX_DIMS + 1 );
2425+ aclTensor* acl_scale_tensor = ggml_cann_create_tensor (
2426+ src0->data , ACL_FLOAT16, sizeof (float16_t ), scale_ne, scale_nb,
2427+ GGML_MAX_DIMS + 1 , ACL_FORMAT_ND, scale_offset);
2428+ aclTensor* dequant_tensor = ggml_cann_create_tensor (
2429+ dequant_buffer_allocator.get (), ACL_FLOAT, sizeof (float_t ),
2430+ dequant_ne, dequant_nb, GGML_MAX_DIMS + 1 );
2431+
2432+ aclnn_mul (ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
2433+ dequant_nb[0 ] = sizeof (float_t );
2434+ dequant_ne = src0->ne ;
2435+ for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
2436+ dequant_nb[i] = dequant_nb[i - 1 ] * src0->ne [i - 1 ];
2437+ }
2438+
2439+ ggml_cann_embedding_4d (ctx, dequant_buffer_allocator.get (),
2440+ dequant_ne, dequant_nb, src1, dst);
2441+
2442+ ACL_CHECK (aclDestroyTensor (dequant_tensor));
2443+ break ;
2444+ }
23912445 default :
23922446 GGML_ABORT (" Unsupported tensor type for GGML_OP_GET_ROWS" );
23932447 break ;
0 commit comments