@@ -1180,13 +1180,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
11801180 // not really a GGML_TYPE_Q8_0 but same size.
11811181 switch (op->op ) {
11821182 case GGML_OP_MUL_MAT:
1183- size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1184- return true ;
1183+ {
1184+ size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1185+ return true ;
1186+ }
11851187 case GGML_OP_MUL_MAT_ID:
1186- size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1187- size = GGML_PAD (size, sizeof (int64_t )); // + padding for next bloc.
1188- size += sizeof (int64_t ) * (1 +op->src [0 ]->ne [2 ]) * op->src [1 ]->ne [2 ];
1189- return true ;
1188+ {
1189+ size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1190+ size = GGML_PAD (size, sizeof (int64_t )); // + padding for next bloc.
1191+
1192+ const int64_t ne02 = op->src [0 ]->ne [2 ]; // n_as, n_expert
1193+ const int64_t ne12 = op->src [1 ]->ne [2 ]; // n_tokens
1194+
1195+ const size_t sizeof_mmid_row_mapping = sizeof (int64_t );
1196+
1197+ size += sizeof_mmid_row_mapping*ne02*(ne12 + 1 );
1198+
1199+ return true ;
1200+ }
11901201 default :
11911202 // GGML_ABORT("fatal error");
11921203 break ;
@@ -1322,14 +1333,17 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
13221333 int32_t i2;
13231334 };
13241335
1325- GGML_ASSERT (params->wsize >= (GGML_PAD (nbw3, sizeof (int64_t )) + n_as * sizeof (int64_t ) +
1326- n_as * ne12 * sizeof (mmid_row_mapping)));
1336+ GGML_ASSERT (params->wsize >=
1337+ (GGML_PAD (nbw3, sizeof (int64_t )) +
1338+ n_as*(ne12 + 1 )*sizeof (mmid_row_mapping))
1339+ );
13271340
1328- auto * wdata = (char *) params->wdata ;
1329- auto * wdata_src1_end = (char *) wdata + GGML_PAD (nbw3, sizeof (int64_t ));
1330- auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
1341+ auto * wdata = (char *)params->wdata ;
1342+ auto * wdata_src1_end = (char *)wdata + GGML_PAD (nbw3, sizeof (int64_t ));
13311343
1332- struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
1344+ // total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t)
1345+ auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
1346+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
13331347
13341348 // src1: float32 => param type
13351349 for (int64_t i12 = 0 ; i12 < ne12; ++i12) {
@@ -1414,15 +1428,6 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
14141428 }
14151429};
14161430
1417- // instance for Q4
1418- static const tensor_traits<block_q4_0, 4 , 4 , GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
1419- static const tensor_traits<block_q4_0, 8 , 4 , GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
1420- static const tensor_traits<block_q4_0, 8 , 8 , GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
1421- static const tensor_traits<block_q4_K, 8 , 8 , GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
1422-
1423- // instance for IQ4
1424- static const tensor_traits<block_iq4_nl, 4 , 4 , GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
1425-
14261431} // namespace ggml::cpu::repack
14271432
14281433static void flag_aarch_prepacked_quant (int type)
@@ -1435,55 +1440,65 @@ static void flag_aarch_prepacked_quant(int type)
14351440}
14361441
14371442static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type (const struct ggml_tensor * cur) {
1443+
1444+ // instance for Q4
1445+ static const ggml::cpu::repack::tensor_traits<block_q4_0, 4 , 4 , GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
1446+ static const ggml::cpu::repack::tensor_traits<block_q4_0, 8 , 4 , GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
1447+ static const ggml::cpu::repack::tensor_traits<block_q4_0, 8 , 8 , GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
1448+ static const ggml::cpu::repack::tensor_traits<block_q4_K, 8 , 8 , GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
1449+
1450+ // instance for IQ4
1451+ static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4 , 4 , GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
1452+
14381453 if (cur->type == GGML_TYPE_Q4_0) {
14391454 // we shall just use the regular avx2 handling, no repacking
14401455 if (/* ggml_cpu_has_avx2() ||*/ (ggml_cpu_has_sve () && ggml_cpu_has_matmul_int8 () && ggml_cpu_get_sve_cnt () == QK8_0)) {
14411456 if (cur->ne [1 ] % 8 == 0 ) {
1442- return &ggml::cpu::repack:: q4_0_8x8_q8_0;
1457+ return &q4_0_8x8_q8_0;
14431458 }
14441459 }
14451460 if (ggml_cpu_has_neon () && ggml_cpu_has_matmul_int8 ()) {
14461461 if (cur->ne [1 ] % 4 == 0 ) {
1447- return &ggml::cpu::repack:: q4_0_4x8_q8_0;
1462+ return &q4_0_4x8_q8_0;
14481463 }
14491464 }
14501465 if (ggml_cpu_has_neon () && ggml_cpu_has_dotprod ()) {
14511466 if (cur->ne [1 ] % 4 == 0 ) {
1452- return &ggml::cpu::repack:: q4_0_4x4_q8_0;
1467+ return &q4_0_4x4_q8_0;
14531468 }
14541469 }
14551470 } else if (cur->type == GGML_TYPE_Q4_K) {
1456- // if (ggml_cpu_has_avx2()) { //we shall just use the regular avx2 handling, no repacking otherwise massive slowdown with gpu
1457- // if (cur->ne[1] % 8 == 0) {
1458- // return &ggml::cpu::aarch64:: q4_K_8x8_q8_K;
1459- // }
1460- // }
1471+ // if (ggml_cpu_has_avx2()) {
1472+ // if (cur->ne[1] % 8 == 0) {
1473+ // return &q4_K_8x8_q8_K;
1474+ // }
1475+ // }
14611476 } else if (cur->type == GGML_TYPE_IQ4_NL) {
14621477 if (ggml_cpu_has_neon () && ggml_cpu_has_dotprod ()) {
14631478 if (cur->ne [1 ] % 4 == 0 ) {
1464- return &ggml::cpu::repack:: iq4_nl_4x4_q8_0;
1479+ return &iq4_nl_4x4_q8_0;
14651480 }
14661481 }
14671482 }
14681483 else if (cur->type == GGML_TYPE_Q4_0_4_4) // kcpp backport old quant support
14691484 {
14701485 flag_aarch_prepacked_quant (cur->type );
1471- return &ggml::cpu::repack:: q4_0_4x4_q8_0;
1486+ return &q4_0_4x4_q8_0;
14721487 }
14731488 else if (cur->type == GGML_TYPE_Q4_0_4_8)
14741489 {
14751490 flag_aarch_prepacked_quant (cur->type );
1476- return &ggml::cpu::repack:: q4_0_4x8_q8_0;
1491+ return &q4_0_4x8_q8_0;
14771492 }
14781493 else if (cur->type == GGML_TYPE_Q4_0_8_8)
14791494 {
14801495 flag_aarch_prepacked_quant (cur->type );
1481- return &ggml::cpu::repack:: q4_0_8x8_q8_0;
1496+ return &q4_0_8x8_q8_0;
14821497 }
14831498 else if (cur->type == GGML_TYPE_IQ4_NL)
14841499 {
14851500 flag_aarch_prepacked_quant (cur->type );
1486- return &ggml::cpu::repack:: iq4_nl_4x4_q8_0;
1501+ return &iq4_nl_4x4_q8_0;
14871502 }
14881503
14891504 return nullptr ;
0 commit comments