@@ -250,7 +250,7 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
250250
251251static const int8_t kvalues_iq4nl[16 ] = {-127 , -104 , -83 , -65 , -49 , -35 , -22 , -10 , 1 , 13 , 25 , 38 , 53 , 69 , 89 , 113 };
252252
253- static void quantize_q8_0_4x4 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
253+ static void ggml_quantize_mat_q8_0_4x4 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
254254 assert (QK8_0 == 32 );
255255 assert (k % QK8_0 == 0 );
256256 const int nb = k / QK8_0;
@@ -344,7 +344,7 @@ static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRIC
344344#endif
345345}
346346
347- static void quantize_q8_0_4x8 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
347+ static void ggml_quantize_mat_q8_0_4x8 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
348348 assert (QK8_0 == 32 );
349349 assert (k % QK8_0 == 0 );
350350 const int nb = k / QK8_0;
@@ -559,7 +559,7 @@ static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
559559#endif
560560}
561561
562- static void quantize_q8_K_4x8 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
562+ static void ggml_quantize_mat_q8_K_4x8 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
563563 assert (QK_K == 256 );
564564 assert (k % QK_K == 0 );
565565 const int nb = k / QK_K;
@@ -823,26 +823,25 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
823823#endif
824824}
825825
826- static void quantize_mat_q8_0 (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
826+ template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
827+ void ggml_quantize_mat_t (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
828+
829+ template <> void ggml_quantize_mat_t <4 , GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
827830 assert (nrow == 4 );
828831 UNUSED (nrow);
829- if (blck_size_interleave == 4 ) {
830- quantize_q8_0_4x4 (x, vy, n_per_row);
831- } else if (blck_size_interleave == 8 ) {
832- quantize_q8_0_4x8 (x, vy, n_per_row);
833- } else {
834- assert (false );
835- }
832+ ggml_quantize_mat_q8_0_4x4 (x, vy, n_per_row);
836833}
837834
838- static void quantize_mat_q8_K (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave ) {
835+ template <> void ggml_quantize_mat_t < 8 , GGML_TYPE_Q8_0> (const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
839836 assert (nrow == 4 );
840837 UNUSED (nrow);
841- if (blck_size_interleave == 8 ) {
842- quantize_q8_K_4x8 (x, vy, n_per_row);
843- } else {
844- assert (false );
845- }
838+ ggml_quantize_mat_q8_0_4x8 (x, vy, n_per_row);
839+ }
840+
841+ template <> void ggml_quantize_mat_t <8 , GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
842+ assert (nrow == 4 );
843+ UNUSED (nrow);
844+ ggml_quantize_mat_q8_K_4x8 (x, vy, n_per_row);
846845}
847846
848847static void ggml_gemv_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -5276,50 +5275,50 @@ template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void *
52765275// }
52775276
52785277// gemv
5279- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
5278+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE >
52805279void gemv (int , float *, size_t , const void *, const void *, int , int );
52815280
5282- template <> void gemv<block_q4_0, 4 , 4 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5281+ template <> void gemv<block_q4_0, 4 , 4 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52835282 ggml_gemv_q4_0_4x4_q8_0 (n, s, bs, vx, vy, nr, nc);
52845283}
52855284
5286- template <> void gemv<block_q4_0, 8 , 4 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5285+ template <> void gemv<block_q4_0, 8 , 4 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52875286 ggml_gemv_q4_0_4x8_q8_0 (n, s, bs, vx, vy, nr, nc);
52885287}
52895288
5290- template <> void gemv<block_q4_0, 8 , 8 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5289+ template <> void gemv<block_q4_0, 8 , 8 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52915290 ggml_gemv_q4_0_8x8_q8_0 (n, s, bs, vx, vy, nr, nc);
52925291}
52935292
5294- template <> void gemv<block_q4_K, 8 , 8 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5293+ template <> void gemv<block_q4_K, 8 , 8 , GGML_TYPE_Q8_K >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52955294 ggml_gemv_q4_K_8x8_q8_K (n, s, bs, vx, vy, nr, nc);
52965295}
52975296
5298- template <> void gemv<block_iq4_nl, 4 , 4 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5297+ template <> void gemv<block_iq4_nl, 4 , 4 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
52995298 ggml_gemv_iq4_nl_4x4_q8_0 (n, s, bs, vx, vy, nr, nc);
53005299}
53015300
53025301// gemm
5303- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
5302+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE >
53045303void gemm (int , float *, size_t , const void *, const void *, int , int );
53055304
5306- template <> void gemm<block_q4_0, 4 , 4 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5305+ template <> void gemm<block_q4_0, 4 , 4 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53075306 ggml_gemm_q4_0_4x4_q8_0 (n, s, bs, vx, vy, nr, nc);
53085307}
53095308
5310- template <> void gemm<block_q4_0, 8 , 4 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5309+ template <> void gemm<block_q4_0, 8 , 4 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53115310 ggml_gemm_q4_0_4x8_q8_0 (n, s, bs, vx, vy, nr, nc);
53125311}
53135312
5314- template <> void gemm<block_q4_0, 8 , 8 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5313+ template <> void gemm<block_q4_0, 8 , 8 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53155314 ggml_gemm_q4_0_8x8_q8_0 (n, s, bs, vx, vy, nr, nc);
53165315}
53175316
5318- template <> void gemm<block_q4_K, 8 , 8 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5317+ template <> void gemm<block_q4_K, 8 , 8 , GGML_TYPE_Q8_K >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53195318 ggml_gemm_q4_K_8x8_q8_K (n, s, bs, vx, vy, nr, nc);
53205319}
53215320
5322- template <> void gemm<block_iq4_nl, 4 , 4 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5321+ template <> void gemm<block_iq4_nl, 4 , 4 , GGML_TYPE_Q8_0 >(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
53235322 ggml_gemm_iq4_nl_4x4_q8_0 (n, s, bs, vx, vy, nr, nc);
53245323}
53255324
@@ -5350,15 +5349,15 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
53505349
53515350 bool compute_forward (struct ggml_compute_params * params, struct ggml_tensor * op) override {
53525351 switch (op->op ) {
5353- case GGML_OP_MUL_MAT:
5354- forward_mul_mat (params, op);
5355- return true ;
5356- case GGML_OP_MUL_MAT_ID:
5357- forward_mul_mat_id (params, op);
5358- return true ;
5359- default :
5360- // GGML_ABORT("fatal error");
5361- break ;
5352+ case GGML_OP_MUL_MAT:
5353+ forward_mul_mat (params, op);
5354+ return true ;
5355+ case GGML_OP_MUL_MAT_ID:
5356+ forward_mul_mat_id (params, op);
5357+ return true ;
5358+ default :
5359+ // GGML_ABORT("fatal error");
5360+ break ;
53625361 }
53635362 return false ;
53645363 }
@@ -5397,18 +5396,10 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
53975396 const ggml_from_float_t from_float = ggml_get_type_traits_cpu (PARAM_TYPE)->from_float ;
53985397
53995398 int64_t i11_processed = 0 ;
5400- if (PARAM_TYPE == GGML_TYPE_Q8_K) {
5401- for (int64_t i11 = ith * 4 ; i11 < ne11 - ne11 % 4 ; i11 += nth * 4 ) {
5402- quantize_mat_q8_K ((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4 , ne10,
5403- INTER_SIZE);
5404- }
5405- } else {
5406- GGML_ASSERT (PARAM_TYPE == GGML_TYPE_Q8_0);
5407- for (int64_t i11 = ith * 4 ; i11 < ne11 - ne11 % 4 ; i11 += nth * 4 ) {
5408- quantize_mat_q8_0 ((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4 , ne10,
5409- INTER_SIZE);
5410- }
5399+ for (int64_t i11 = ith * 4 ; i11 < ne11 - ne11 % 4 ; i11 += nth * 4 ) {
5400+ ggml_quantize_mat_t <INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4 , ne10);
54115401 }
5402+
54125403 i11_processed = ne11 - ne11 % 4 ;
54135404 for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
54145405 from_float ((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
@@ -5428,15 +5419,17 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
54285419
54295420 // If there are more than three rows in src1, use gemm; otherwise, use gemv.
54305421 if (ne11 > 3 ) {
5431- gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data ) + src0_start, ne01,
5432- (const char *) src0->data + src0_start * nb01,
5433- (const char *) src1_wdata, ne11 - ne11 % 4 , src0_end - src0_start);
5422+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
5423+ (float *) ((char *) dst->data ) + src0_start, ne01,
5424+ (const char *) src0->data + src0_start * nb01,
5425+ (const char *) src1_wdata, ne11 - ne11 % 4 , src0_end - src0_start);
54345426 }
54355427 for (int iter = ne11 - ne11 % 4 ; iter < ne11; iter++) {
5436- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
5437- (const char *) src0->data + src0_start * nb01,
5438- (const char *) src1_wdata + (src1_col_stride * iter), 1 ,
5439- src0_end - src0_start);
5428+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
5429+ (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
5430+ (const char *) src0->data + src0_start * nb01,
5431+ (const char *) src1_wdata + (src1_col_stride * iter), 1 ,
5432+ src0_end - src0_start);
54405433 }
54415434 }
54425435
@@ -5485,9 +5478,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
54855478 GGML_ASSERT (params->wsize >= (GGML_PAD (nbw3, sizeof (int64_t )) + n_as * sizeof (int64_t ) +
54865479 n_as * ne12 * sizeof (mmid_row_mapping)));
54875480
5488- auto wdata = (char *) params->wdata ;
5489- auto wdata_src1_end = (char *) wdata + GGML_PAD (nbw3, sizeof (int64_t ));
5490- int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
5481+ auto * wdata = (char *) params->wdata ;
5482+ auto * wdata_src1_end = (char *) wdata + GGML_PAD (nbw3, sizeof (int64_t ));
5483+ auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
54915484
54925485 struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
54935486
@@ -5530,7 +5523,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
55305523 continue ;
55315524 }
55325525
5533- auto src0_cur = (const char *) src0->data + cur_a*nb02;
5526+ const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
55345527
55355528 // const int64_t nr0 = ne01; // src0 rows
55365529 const int64_t nr1 = cne1; // src1 rows
@@ -5541,7 +5534,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
55415534 src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
55425535 src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
55435536
5544- if (src0_cur_start >= src0_cur_end) return ;
5537+ if (src0_cur_start >= src0_cur_end) {
5538+ return ;
5539+ }
55455540
55465541 for (int ir1 = 0 ; ir1 < nr1; ir1++) {
55475542 struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW (cur_a, ir1);
@@ -5554,11 +5549,11 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
55545549 const int64_t i1 = id; // selected expert index
55555550 const int64_t i2 = i12; // row
55565551
5557- auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
5552+ const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
55585553
5559- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
5560- ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
5561- ne01, src0_cur + src0_cur_start * nb01,
5554+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
5555+ (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01 ,
5556+ src0_cur + src0_cur_start * nb01,
55625557 src1_col, 1 , src0_cur_end - src0_cur_start);
55635558 }
55645559 }
0 commit comments