@@ -250,7 +250,7 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
250250
251251static const int8_t kvalues_iq4nl[16 ] = {-127 , -104 , -83 , -65 , -49 , -35 , -22 , -10 , 1 , 13 , 25 , 38 , 53 , 69 , 89 , 113 };
252252
253- static void quantize_q8_0_4x4 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
253+ static void ggml_quantize_mat_q8_0_4x4 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
254254 assert (QK8_0 == 32 );
255255 assert (k % QK8_0 == 0 );
256256 const int nb = k / QK8_0;
@@ -344,7 +344,7 @@ static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRIC
344344#endif
345345}
346346
347- static void quantize_q8_0_4x8 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
347+ static void ggml_quantize_mat_q8_0_4x8 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
348348 assert (QK8_0 == 32 );
349349 assert (k % QK8_0 == 0 );
350350 const int nb = k / QK8_0;
@@ -559,7 +559,7 @@ static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
559559#endif
560560}
561561
562- static void quantize_q8_K_4x8 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
562+ static void ggml_quantize_mat_q8_K_4x8 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
563563 assert (QK_K == 256 );
564564 assert (k % QK_K == 0 );
565565 const int nb = k / QK_K;
@@ -811,7 +811,7 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
811811 // i.e. first four bsums from the first super block, followed by first four bsums from the second super block, and so on
812812 for (int j = 0 ; j < QK_K * 4 ; j++) {
813813 int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
814- int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
814+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
815815 src_offset += (j % blck_size_interleave);
816816 int index = (((j & 31 ) >> 3 ) << 2 ) + ((j >> 8 ) << 4 ) + ((j >> 6 ) & 3 );
817817
@@ -823,26 +823,25 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
823823#endif
824824}
825825
826- static void quantize_mat_q8_0 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
826+ template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
827+ void ggml_quantize_mat_t (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
828+
829+ template <> void ggml_quantize_mat_t <4 , GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
827830 assert (nrow == 4 );
828831 UNUSED (nrow);
829- if (blck_size_interleave == 4 ) {
830- quantize_q8_0_4x4 (x, vy, n_per_row);
831- } else if (blck_size_interleave == 8 ) {
832- quantize_q8_0_4x8 (x, vy, n_per_row);
833- } else {
834- assert (false );
835- }
832+ ggml_quantize_mat_q8_0_4x4 (x, vy, n_per_row);
836833}
837834
838- static void quantize_mat_q8_K (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave ) {
835+ template <> void ggml_quantize_mat_t < 8 , GGML_TYPE_Q8_0> (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
839836 assert (nrow == 4 );
840837 UNUSED (nrow);
841- if (blck_size_interleave == 8 ) {
842- quantize_q8_K_4x8 (x, vy, n_per_row);
843- } else {
844- assert (false );
845- }
838+ ggml_quantize_mat_q8_0_4x8 (x, vy, n_per_row);
839+ }
840+
841+ template <> void ggml_quantize_mat_t <8 , GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
842+ assert (nrow == 4 );
843+ UNUSED (nrow);
844+ ggml_quantize_mat_q8_K_4x8 (x, vy, n_per_row);
846845}
847846
848847static void ggml_gemv_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -5276,52 +5275,50 @@ template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void *
52765275// }
52775276
52785277// gemv
5279- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
5278+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE >
52805279void gemv (int , float *, size_t , const void *, const void *, int , int );
52815280
5282- template <> void gemv<block_q4_0, 4 , 4 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5281+ template <> void gemv<block_q4_0, 4 , 4 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52835282 ggml_gemv_q4_0_4x4_q8_0 (n, s, bs, vx, vy, nr, nc);
52845283}
52855284
5286- template <> void gemv<block_q4_0, 8 , 4 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5285+ template <> void gemv<block_q4_0, 8 , 4 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52875286 ggml_gemv_q4_0_4x8_q8_0 (n, s, bs, vx, vy, nr, nc);
52885287}
52895288
5290- template <> void gemv<block_q4_0, 8 , 8 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5289+ template <> void gemv<block_q4_0, 8 , 8 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52915290 ggml_gemv_q4_0_8x8_q8_0 (n, s, bs, vx, vy, nr, nc);
52925291}
52935292
5294- template <> void gemv<block_q4_K, 8 , 8 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5293+ template <> void gemv<block_q4_K, 8 , 8 , GGML_TYPE_Q8_K >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52955294 ggml_gemv_q4_K_8x8_q8_K (n, s, bs, vx, vy, nr, nc);
52965295}
52975296
5298- template <>
5299- void gemv<block_iq4_nl, 4 , 4 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5297+ template <> void gemv<block_iq4_nl, 4 , 4 , GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53005298 ggml_gemv_iq4_nl_4x4_q8_0 (n, s, bs, vx, vy, nr, nc);
53015299}
53025300
53035301// gemm
5304- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
5302+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE >
53055303void gemm (int , float *, size_t , const void *, const void *, int , int );
53065304
5307- template <> void gemm<block_q4_0, 4 , 4 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5305+ template <> void gemm<block_q4_0, 4 , 4 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53085306 ggml_gemm_q4_0_4x4_q8_0 (n, s, bs, vx, vy, nr, nc);
53095307}
53105308
5311- template <> void gemm<block_q4_0, 8 , 4 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5309+ template <> void gemm<block_q4_0, 8 , 4 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53125310 ggml_gemm_q4_0_4x8_q8_0 (n, s, bs, vx, vy, nr, nc);
53135311}
53145312
5315- template <> void gemm<block_q4_0, 8 , 8 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5313+ template <> void gemm<block_q4_0, 8 , 8 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53165314 ggml_gemm_q4_0_8x8_q8_0 (n, s, bs, vx, vy, nr, nc);
53175315}
53185316
5319- template <> void gemm<block_q4_K, 8 , 8 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5317+ template <> void gemm<block_q4_K, 8 , 8 , GGML_TYPE_Q8_K >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53205318 ggml_gemm_q4_K_8x8_q8_K (n, s, bs, vx, vy, nr, nc);
53215319}
53225320
5323- template <>
5324- void gemm<block_iq4_nl, 4 , 4 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5321+ template <> void gemm<block_iq4_nl, 4 , 4 , GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53255322 ggml_gemm_iq4_nl_4x4_q8_0 (n, s, bs, vx, vy, nr, nc);
53265323}
53275324
@@ -5335,32 +5332,32 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
53355332 bool work_size (int /* n_threads */ , const struct ggml_tensor * op, size_t & size) override {
53365333 // not really a GGML_TYPE_Q8_0 but same size.
53375334 switch (op->op ) {
5338- case GGML_OP_MUL_MAT:
5339- size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
5340- return true ;
5341- case GGML_OP_MUL_MAT_ID:
5342- size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
5343- size = GGML_PAD (size, sizeof (int64_t )); // + padding for next bloc.
5344- size += sizeof (int64_t ) * (1 +op->src [0 ]->ne [2 ]) * op->src [1 ]->ne [2 ];
5345- return true ;
5346- default :
5347- // GGML_ABORT("fatal error");
5348- break ;
5335+ case GGML_OP_MUL_MAT:
5336+ size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
5337+ return true ;
5338+ case GGML_OP_MUL_MAT_ID:
5339+ size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
5340+ size = GGML_PAD (size, sizeof (int64_t )); // + padding for next block.
5341+ size += sizeof (int64_t ) * (1 +op->src [0 ]->ne [2 ]) * op->src [1 ]->ne [2 ];
5342+ return true ;
5343+ default :
5344+ // GGML_ABORT("fatal error");
5345+ break ;
53495346 }
53505347 return false ;
53515348 }
53525349
53535350 bool compute_forward (struct ggml_compute_params * params, struct ggml_tensor * op) override {
53545351 switch (op->op ) {
5355- case GGML_OP_MUL_MAT:
5356- forward_mul_mat (params, op);
5357- return true ;
5358- case GGML_OP_MUL_MAT_ID:
5359- forward_mul_mat_id (params, op);
5360- return true ;
5361- default :
5362- // GGML_ABORT("fatal error");
5363- break ;
5352+ case GGML_OP_MUL_MAT:
5353+ forward_mul_mat (params, op);
5354+ return true ;
5355+ case GGML_OP_MUL_MAT_ID:
5356+ forward_mul_mat_id (params, op);
5357+ return true ;
5358+ default :
5359+ // GGML_ABORT("fatal error");
5360+ break ;
53645361 }
53655362 return false ;
53665363 }
@@ -5399,17 +5396,10 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
53995396 const ggml_from_float_t from_float = ggml_get_type_traits_cpu (PARAM_TYPE)->from_float ;
54005397
54015398 int64_t i11_processed = 0 ;
5402- if (PARAM_TYPE == GGML_TYPE_Q8_K) {
5403- for (int64_t i11 = ith * 4 ; i11 < ne11 - ne11 % 4 ; i11 += nth * 4 ) {
5404- quantize_mat_q8_K ((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4 , ne10,
5405- INTER_SIZE);
5406- }
5407- } else {
5408- for (int64_t i11 = ith * 4 ; i11 < ne11 - ne11 % 4 ; i11 += nth * 4 ) {
5409- quantize_mat_q8_0 ((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4 , ne10,
5410- INTER_SIZE);
5411- }
5399+ for (int64_t i11 = ith * 4 ; i11 < ne11 - ne11 % 4 ; i11 += nth * 4 ) {
5400+ ggml_quantize_mat_t <INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4 , ne10);
54125401 }
5402+
54135403 i11_processed = ne11 - ne11 % 4 ;
54145404 for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
54155405 from_float ((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
@@ -5422,22 +5412,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
54225412 int64_t src0_start = (ith * ne01) / nth;
54235413 int64_t src0_end = ((ith + 1 ) * ne01) / nth;
54245414 src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
5425- src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
5415+ src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
54265416 if (src0_start >= src0_end) {
54275417 return ;
54285418 }
54295419
54305420 // If there are more than three rows in src1, use gemm; otherwise, use gemv.
54315421 if (ne11 > 3 ) {
5432- gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data ) + src0_start, ne01,
5433- (const char *) src0->data + src0_start * nb01,
5434- (const char *) src1_wdata, ne11 - ne11 % 4 , src0_end - src0_start);
5422+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
5423+ (float *) ((char *) dst->data ) + src0_start, ne01,
5424+ (const char *) src0->data + src0_start * nb01,
5425+ (const char *) src1_wdata, ne11 - ne11 % 4 , src0_end - src0_start);
54355426 }
54365427 for (int iter = ne11 - ne11 % 4 ; iter < ne11; iter++) {
5437- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
5438- (const char *) src0->data + src0_start * nb01,
5439- (const char *) src1_wdata + (src1_col_stride * iter), 1 ,
5440- src0_end - src0_start);
5428+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
5429+ (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
5430+ (const char *) src0->data + src0_start * nb01,
5431+ (const char *) src1_wdata + (src1_col_stride * iter), 1 ,
5432+ src0_end - src0_start);
54415433 }
54425434 }
54435435
@@ -5452,7 +5444,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
54525444 const int ith = params->ith ;
54535445 const int nth = params->nth ;
54545446
5455- const ggml_from_float_t from_float = ggml_get_type_traits_cpu (GGML_TYPE_Q8_0 )->from_float ;
5447+ const ggml_from_float_t from_float = ggml_get_type_traits_cpu (PARAM_TYPE )->from_float ;
54565448
54575449 // we don't support permuted src0 or src1
54585450 GGML_ASSERT (nb00 == ggml_type_size (src0->type ));
@@ -5474,7 +5466,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
54745466 const int n_ids = ids->ne [0 ]; // n_expert_used
54755467 const int n_as = ne02; // n_expert
54765468
5477- const size_t nbw1 = ggml_row_size (GGML_TYPE_Q8_0 , ne10);
5469+ const size_t nbw1 = ggml_row_size (PARAM_TYPE , ne10);
54785470 const size_t nbw2 = nbw1*ne11;
54795471 const size_t nbw3 = nbw2*ne12;
54805472
@@ -5486,12 +5478,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
54865478 GGML_ASSERT (params->wsize >= (GGML_PAD (nbw3, sizeof (int64_t )) + n_as * sizeof (int64_t ) +
54875479 n_as * ne12 * sizeof (mmid_row_mapping)));
54885480
5489- auto wdata = (char *) params->wdata ;
5490- auto wdata_src1_end = (char *) wdata + GGML_PAD (nbw3, sizeof (int64_t ));
5491- int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
5481+ auto * wdata = (char *) params->wdata ;
5482+ auto * wdata_src1_end = (char *) wdata + GGML_PAD (nbw3, sizeof (int64_t ));
5483+ auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
5484+
54925485 struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
54935486
5494- // src1: float32 => block_q8_0
5487+ // src1: float32 => param type
54955488 for (int64_t i12 = 0 ; i12 < ne12; ++i12) {
54965489 for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
54975490 from_float ((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
@@ -5530,34 +5523,37 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
55305523 continue ;
55315524 }
55325525
5533- auto src0_cur = (const char *) src0->data + cur_a*nb02;
5526+ const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
55345527
55355528 // const int64_t nr0 = ne01; // src0 rows
55365529 const int64_t nr1 = cne1; // src1 rows
55375530
55385531 int64_t src0_cur_start = (ith * ne01) / nth;
55395532 int64_t src0_cur_end = ((ith + 1 ) * ne01) / nth;
5540- src0_cur_start =
5541- (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
5542- src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
55435533
5544- if (src0_cur_start >= src0_cur_end) return ;
5534+ src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
5535+ src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
5536+
5537+ if (src0_cur_start >= src0_cur_end) {
5538+ return ;
5539+ }
55455540
55465541 for (int ir1 = 0 ; ir1 < nr1; ir1++) {
55475542 struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW (cur_a, ir1);
5548- const int id = row_mapping.i1 ; // selected expert index
55495543
5550- const int64_t i11 = id % ne11;
5551- const int64_t i12 = row_mapping.i2 ; // row index in src1
5544+ const int id = row_mapping.i1 ; // selected expert index
5545+
5546+ const int64_t i11 = id % ne11;
5547+ const int64_t i12 = row_mapping.i2 ; // row index in src1
55525548
5553- const int64_t i1 = id; // selected expert index
5554- const int64_t i2 = i12; // row
5549+ const int64_t i1 = id; // selected expert index
5550+ const int64_t i2 = i12; // row
55555551
5556- auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
5552+ const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
55575553
5558- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
5559- ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
5560- ne01, src0_cur + src0_cur_start * nb01,
5554+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
5555+ (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01 ,
5556+ src0_cur + src0_cur_start * nb01,
55615557 src1_col, 1 , src0_cur_end - src0_cur_start);
55625558 }
55635559 }
@@ -5578,7 +5574,7 @@ static const tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
55785574static const tensor_traits<block_q4_K, 8 , 8 , GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
55795575
55805576// instance for IQ4
5581- static const tensor_traits<block_iq4_nl, 4 , 4 , GGML_TYPE_IQ4_NL > iq4_nl_4x4_q8_0;
5577+ static const tensor_traits<block_iq4_nl, 4 , 4 , GGML_TYPE_Q8_0 > iq4_nl_4x4_q8_0;
55825578
55835579} // namespace ggml::cpu::aarch64
55845580
0 commit comments