@@ -1312,15 +1312,39 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
13121312 }
13131313 case GGML_OP_MUL_MAT_ID:
13141314 {
1315- size = ggml_row_size (PARAM_TYPE, ggml_nelements (op->src [1 ]));
1316- size = GGML_PAD (size, sizeof (int64_t )); // + padding for next bloc.
1315+ const ggml_tensor * src0 = op->src [0 ];
1316+ const ggml_tensor * src1 = op->src [1 ];
1317+ const ggml_tensor * dst = op;
1318+
1319+ GGML_TENSOR_BINARY_OP_LOCALS
1320+
1321+ // src0 [n_embd, n_rows, n_expert]
1322+ // src1 [n_embd, n_expert_used, n_tokens]
1323+ // dst [n_rows, n_expert_used, n_tokens]
1324+
1325+ // htmp [n_embd, n_tokens, n_expert] F32
1326+ size_t size_htmp = ggml_row_size (GGML_TYPE_F32, ne00*ne12*ne02);
1327+
1328+ // hsrc1 [n_embd, n_tokens, n_expert]
1329+ size_t size_hsrc1 = ggml_row_size (PARAM_TYPE, ne00*ne12*ne02);
1330+
1331+ // hdst [n_rows, n_tokens, n_expert]
1332+ size_t size_hdst = ggml_row_size (GGML_TYPE_F32, ne01*ne12*ne02);
13171333
1318- const int64_t ne02 = op-> src [ 0 ]-> ne [ 2 ]; // n_as, n_expert
1319- const int64_t ne12 = op-> src [ 1 ]-> ne [ 2 ]; // n_tokens
1334+ // htpe [ n_expert]
1335+ size_t size_htpe = ggml_row_size (GGML_TYPE_I32, ne02);
13201336
1321- const size_t sizeof_mmid_row_mapping = sizeof (int64_t );
1337+ // hids [n_expert*n_tokens]
1338+ size_t size_hids = ggml_row_size (GGML_TYPE_I32, ne02*ne12);
13221339
1323- size += sizeof_mmid_row_mapping*ne02*(ne12 + 1 );
1340+ // + padding
1341+ size_htmp = GGML_PAD (size_htmp, sizeof (int64_t ));
1342+ size_hsrc1 = GGML_PAD (size_hsrc1, sizeof (int64_t ));
1343+ size_hdst = GGML_PAD (size_hdst, sizeof (int64_t ));
1344+ size_htpe = GGML_PAD (size_htpe, sizeof (int64_t ));
1345+ size_hids = GGML_PAD (size_hids, sizeof (int64_t ));
1346+
1347+ size = size_htmp + size_hsrc1 + size_hdst + size_htpe + size_hids;
13241348
13251349 return true ;
13261350 }
@@ -1446,77 +1470,113 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
14461470
14471471 GGML_ASSERT (src1->type == GGML_TYPE_F32);
14481472
1449- // row groups
1450- const int n_ids = ids->ne [0 ]; // n_expert_used
1451- const int n_as = ne02; // n_expert
1452-
1453- const size_t nbw1 = ggml_row_size (PARAM_TYPE, ne10);
1454- const size_t nbw2 = nbw1*ne11;
1455- const size_t nbw3 = nbw2*ne12;
1456-
1457- struct mmid_row_mapping {
1458- int32_t i1;
1459- int32_t i2;
1460- };
1461-
1462- GGML_ASSERT (params->wsize >=
1463- (GGML_PAD (nbw3, sizeof (int64_t )) +
1464- n_as*(ne12 + 1 )*sizeof (mmid_row_mapping))
1465- );
1466-
1467- auto * wdata = (char *)params->wdata ;
1468- auto * wdata_src1_end = (char *)wdata + GGML_PAD (nbw3, sizeof (int64_t ));
1469-
1470- // total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t)
1471- auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
1472- struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
1473-
1474- // src1: float32 => param type
1475- for (int64_t i12 = 0 ; i12 < ne12; ++i12) {
1476- for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
1477- from_float ((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
1478- (void *) (wdata + i12 * nbw2 + i11 * nbw1),
1479- ne10);
1480- }
1481- }
1473+ const int64_t ne20 = ids->ne [0 ];
1474+
1475+ // src0 [n_embd, n_rows, n_expert]
1476+ // src1 [n_embd, n_expert_used', n_tokens]
1477+ // src2 [n_expert_used, n_tokens]
1478+ // dst [n_rows, n_expert_used, n_tokens]
1479+
1480+ // htmp [n_embd, n_tokens, n_expert] F32
1481+ size_t size_htmp = ggml_row_size (GGML_TYPE_F32, ne00*ne12*ne02);
1482+
1483+ // hsrc1 [n_embd, n_tokens, n_expert]
1484+ size_t size_hsrc1 = ggml_row_size (PARAM_TYPE, ne00*ne12*ne02);
14821485
1483- #define MMID_MATRIX_ROW (row_id, i1 ) matrix_rows[(row_id) * ne12 + (i1)]
1486+ // hdst [n_rows, n_tokens, n_expert]
1487+ size_t size_hdst = ggml_row_size (GGML_TYPE_F32, ne01*ne12*ne02);
14841488
1485- if (ith == 0 ) {
1486- // initialize matrix_row_counts
1487- memset (matrix_row_counts, 0 , n_as * sizeof (int64_t ));
1489+ // htpe [n_expert]
1490+ size_t size_htpe = ggml_row_size (GGML_TYPE_I32, ne02);
14881491
1489- // group rows by src0 matrix
1490- for (int32_t iid1 = 0 ; iid1 < ids->ne [1 ]; ++iid1) {
1491- for (int32_t id = 0 ; id < n_ids; ++id) {
1492- const int32_t i02 =
1493- *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb [1 ] + id * ids->nb [0 ]);
1492+ // hids [n_expert*n_tokens]
1493+ size_t size_hids = ggml_row_size (GGML_TYPE_I32, ne02*ne12);
14941494
1495- GGML_ASSERT (i02 >= 0 && i02 < n_as);
1495+ // + padding
1496+ size_htmp = GGML_PAD (size_htmp, sizeof (int64_t ));
1497+ size_hsrc1 = GGML_PAD (size_hsrc1, sizeof (int64_t ));
1498+ size_hdst = GGML_PAD (size_hdst, sizeof (int64_t ));
1499+ size_htpe = GGML_PAD (size_htpe, sizeof (int64_t ));
1500+ size_hids = GGML_PAD (size_hids, sizeof (int64_t ));
14961501
1497- MMID_MATRIX_ROW (i02, matrix_row_counts[i02]) = { id, iid1 };
1498- matrix_row_counts[i02] += 1 ;
1502+ char * wdata_htmp = (char *) params->wdata ;
1503+ char * wdata_hsrc1 = (char *) params->wdata + size_htmp;
1504+ char * wdata_hdst = (char *) params->wdata + size_htmp + size_hsrc1;
1505+ char * wdata_htpe = (char *) params->wdata + size_htmp + size_hsrc1 + size_hdst;
1506+ char * wdata_hids = (char *) params->wdata + size_htmp + size_hsrc1 + size_hdst + size_htpe;
1507+
1508+ const size_t nbht1 = ggml_row_size (GGML_TYPE_F32, ne00);
1509+ const size_t nbht2 = nbht1*ne12;
1510+
1511+ const size_t nbh11 = ggml_row_size (PARAM_TYPE, ne00);
1512+ const size_t nbh12 = nbh11*ne12;
1513+
1514+ const size_t nbh1 = ggml_row_size (GGML_TYPE_F32, ne01);
1515+ const size_t nbh2 = nbh1*ne12;
1516+
1517+ char * htmp = (char *)(wdata_htmp);
1518+ char * hsrc1 = (char *)(wdata_hsrc1);
1519+ char * hdst = (char *)(wdata_hdst);
1520+ int32_t * htpe = (int32_t *)(wdata_htpe);
1521+ int32_t * hids = (int32_t *)(wdata_hids);
1522+
1523+ for (int64_t i02 = ith; i02 < ne02; i02 += nth) {
1524+ htpe[i02] = 0 ;
1525+ }
1526+
1527+ // src1 (float32) => htmp (float32)
1528+ for (int64_t i12 = 0 ; i12 < ne12; ++i12) { // n_tokens
1529+ for (int64_t i20 = 0 ; i20 < ne20; ++i20) { // n_expert_used
1530+ // the selected expert
1531+ const int32_t i02 = *(const int32_t *) ((const char *) ids->data + i12*ids->nb [1 ] + i20*ids->nb [0 ]);
1532+
1533+ if (i02 % nth != ith) {
1534+ continue ;
14991535 }
1536+
1537+ memcpy ( htmp + i02*nbht2 + htpe[i02]*nbht1,
1538+ (char *) src1->data + i12*nb12 + (i20%ne11)*nb11,
1539+ ggml_row_size (GGML_TYPE_F32, ne10));
1540+
1541+ hids[i12*ne20 + i20] = i02*ne12 + htpe[i02];
1542+ htpe[i02]++;
1543+ }
1544+ }
1545+
1546+ // htmp (float32) => hsrc1 (param type)
1547+ for (int64_t i02 = 0 ; i02 < ne02; ++i02) { // n_expert
1548+ if (i02 % nth != ith) {
1549+ continue ;
1550+ }
1551+
1552+ const int64_t neh11 = htpe[i02];
1553+
1554+ for (int64_t i11 = 0 ; i11 < neh11 - neh11 % 4 ; i11 += 4 ) {
1555+ ggml_quantize_mat_t <INTER_SIZE, PARAM_TYPE>(
1556+ (float *) (htmp + i11*nbht1 + i02*nbht2),
1557+ (void *) (hsrc1 + i11*nbh11 + i02*nbh12), 4 , ne10);
1558+ }
1559+
1560+ for (int64_t i11 = neh11 - neh11 % 4 ; i11 < neh11; i11 += 1 ) {
1561+ from_float (
1562+ (float *) (htmp + i11*nbht1 + i02*nbht2),
1563+ (void *) (hsrc1 + i11*nbh11 + i02*nbh12), ne10);
15001564 }
15011565 }
15021566
15031567 ggml_barrier (params->threadpool );
15041568
1505- // compute each matrix multiplication in sequence
1506- for (int cur_a = 0 ; cur_a < n_as; ++cur_a) {
1507- const int64_t cne1 = matrix_row_counts[cur_a];
1569+ for (int64_t i02 = 0 ; i02 < ne02; ++i02) { // n_expert
1570+ const int64_t neh11 = htpe[i02];
15081571
1509- if (cne1 == 0 ) {
1572+ if (neh11 == 0 ) {
15101573 continue ;
15111574 }
15121575
1513- const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
1514-
1515- // const int64_t nr0 = ne01; // src0 rows
1516- const int64_t nr1 = cne1; // src1 rows
1576+ const auto * src0_cur = (const char *) src0->data + i02*nb02;
15171577
1518- int64_t src0_cur_start = (ith * ne01) / nth;
1519- int64_t src0_cur_end = ((ith + 1 ) * ne01) / nth;
1578+ int64_t src0_cur_start = (( ith )* ne01)/ nth;
1579+ int64_t src0_cur_end = ((ith + 1 )* ne01)/ nth;
15201580
15211581 src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
15221582 src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
@@ -1525,26 +1585,49 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
15251585 return ;
15261586 }
15271587
1528- for (int ir1 = 0 ; ir1 < nr1; ir1++) {
1529- struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW (cur_a, ir1);
1588+ #if 1
1589+ if (neh11 > 3 ) {
1590+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1591+ (float *)(hdst + 0 *nbh1 + i02*nbh2) + src0_cur_start, ne01,
1592+ src0_cur + src0_cur_start*nb01,
1593+ hsrc1 + 0 *nbh11 + i02*nbh12, neh11 - neh11 % 4 , src0_cur_end - src0_cur_start);
1594+ }
1595+ for (int64_t i11 = neh11 - neh11 % 4 ; i11 < neh11; ++i11) {
1596+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1597+ (float *)(hdst + i11*nbh1 + i02*nbh2) + src0_cur_start, ne01,
1598+ src0_cur + src0_cur_start*nb01,
1599+ hsrc1 + i11*nbh11 + i02*nbh12, 1 , src0_cur_end - src0_cur_start);
1600+ }
1601+ #else
1602+ for (int64_t i11 = 0; i11 < neh11; ++i11) {
1603+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1604+ (float *)(hdst + i11*nbh1 + i02*nbh2) + src0_cur_start, ne01,
1605+ src0_cur + src0_cur_start * nb01,
1606+ hsrc1 + i11*nbh11 + i02*nbh12, 1, src0_cur_end - src0_cur_start);
1607+ }
1608+ #endif
1609+ }
1610+
1611+ ggml_barrier (params->threadpool );
15301612
1531- const int id = row_mapping.i1 ; // selected expert index
1613+ for (int64_t i21 = 0 ; i21 < ne12; ++i21) { // n_tokens
1614+ for (int64_t i20 = 0 ; i20 < ne20; ++i20) { // n_expert_used
1615+ const int32_t idx = i21*ne20 + i20;
15321616
1533- const int64_t i11 = id % ne11;
1534- const int64_t i12 = row_mapping.i2 ; // row index in src1
1617+ if (idx % nth != ith) {
1618+ continue ;
1619+ }
15351620
1536- const int64_t i1 = id; // selected expert index
1537- const int64_t i2 = i12; // row
1621+ const int32_t id = hids[idx];
15381622
1539- const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
1623+ const int ide = id/ne12;
1624+ const int idt = id%ne12;
15401625
1541- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1542- (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
1543- src0_cur + src0_cur_start * nb01,
1544- src1_col, 1 , src0_cur_end - src0_cur_start);
1626+ memcpy (
1627+ (char *) dst->data + i20*nb1 + i21*nb2,
1628+ hdst + idt*nbh1 + ide*nbh2, ggml_row_size (GGML_TYPE_F32, ne01));
15451629 }
15461630 }
1547- #undef MMID_MATRIX_ROW
15481631 }
15491632
15501633 int repack (struct ggml_tensor * t, const void * data, size_t data_size) override {
0 commit comments