Skip to content

Commit 477d439

Browse files
committed
repack : optimize mul_mat_id path
ggml-ci
1 parent e2661ed commit 477d439

File tree

1 file changed

+157
-74
lines changed

1 file changed

+157
-74
lines changed

ggml/src/ggml-cpu/repack.cpp

Lines changed: 157 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,15 +1312,39 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
13121312
}
13131313
case GGML_OP_MUL_MAT_ID:
13141314
{
1315-
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
1316-
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
1315+
const ggml_tensor * src0 = op->src[0];
1316+
const ggml_tensor * src1 = op->src[1];
1317+
const ggml_tensor * dst = op;
1318+
1319+
GGML_TENSOR_BINARY_OP_LOCALS
1320+
1321+
// src0 [n_embd, n_rows, n_expert]
1322+
// src1 [n_embd, n_expert_used, n_tokens]
1323+
// dst [n_rows, n_expert_used, n_tokens]
1324+
1325+
// htmp [n_embd, n_tokens, n_expert] F32
1326+
size_t size_htmp = ggml_row_size(GGML_TYPE_F32, ne00*ne12*ne02);
1327+
1328+
// hsrc1 [n_embd, n_tokens, n_expert]
1329+
size_t size_hsrc1 = ggml_row_size(PARAM_TYPE, ne00*ne12*ne02);
1330+
1331+
// hdst [n_rows, n_tokens, n_expert]
1332+
size_t size_hdst = ggml_row_size(GGML_TYPE_F32, ne01*ne12*ne02);
13171333

1318-
const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
1319-
const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
1334+
// htpe [n_expert]
1335+
size_t size_htpe = ggml_row_size(GGML_TYPE_I32, ne02);
13201336

1321-
const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
1337+
// hids [n_expert*n_tokens]
1338+
size_t size_hids = ggml_row_size(GGML_TYPE_I32, ne02*ne12);
13221339

1323-
size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);
1340+
// + padding
1341+
size_htmp = GGML_PAD(size_htmp, sizeof(int64_t));
1342+
size_hsrc1 = GGML_PAD(size_hsrc1, sizeof(int64_t));
1343+
size_hdst = GGML_PAD(size_hdst, sizeof(int64_t));
1344+
size_htpe = GGML_PAD(size_htpe, sizeof(int64_t));
1345+
size_hids = GGML_PAD(size_hids, sizeof(int64_t));
1346+
1347+
size = size_htmp + size_hsrc1 + size_hdst + size_htpe + size_hids;
13241348

13251349
return true;
13261350
}
@@ -1446,77 +1470,113 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
14461470

14471471
GGML_ASSERT(src1->type == GGML_TYPE_F32);
14481472

1449-
// row groups
1450-
const int n_ids = ids->ne[0]; // n_expert_used
1451-
const int n_as = ne02; // n_expert
1452-
1453-
const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
1454-
const size_t nbw2 = nbw1*ne11;
1455-
const size_t nbw3 = nbw2*ne12;
1456-
1457-
struct mmid_row_mapping {
1458-
int32_t i1;
1459-
int32_t i2;
1460-
};
1461-
1462-
GGML_ASSERT(params->wsize >=
1463-
(GGML_PAD(nbw3, sizeof(int64_t)) +
1464-
n_as*(ne12 + 1)*sizeof(mmid_row_mapping))
1465-
);
1466-
1467-
auto * wdata = (char *)params->wdata;
1468-
auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));
1469-
1470-
// total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t)
1471-
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
1472-
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
1473-
1474-
// src1: float32 => param type
1475-
for (int64_t i12 = 0; i12 < ne12; ++i12) {
1476-
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
1477-
from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
1478-
(void *) (wdata + i12 * nbw2 + i11 * nbw1),
1479-
ne10);
1480-
}
1481-
}
1473+
const int64_t ne20 = ids->ne[0];
1474+
1475+
// src0 [n_embd, n_rows, n_expert]
1476+
// src1 [n_embd, n_expert_used', n_tokens]
1477+
// src2 [n_expert_used, n_tokens]
1478+
// dst [n_rows, n_expert_used, n_tokens]
1479+
1480+
// htmp [n_embd, n_tokens, n_expert] F32
1481+
size_t size_htmp = ggml_row_size(GGML_TYPE_F32, ne00*ne12*ne02);
1482+
1483+
// hsrc1 [n_embd, n_tokens, n_expert]
1484+
size_t size_hsrc1 = ggml_row_size(PARAM_TYPE, ne00*ne12*ne02);
14821485

1483-
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
1486+
// hdst [n_rows, n_tokens, n_expert]
1487+
size_t size_hdst = ggml_row_size(GGML_TYPE_F32, ne01*ne12*ne02);
14841488

1485-
if (ith == 0) {
1486-
// initialize matrix_row_counts
1487-
memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
1489+
// htpe [n_expert]
1490+
size_t size_htpe = ggml_row_size(GGML_TYPE_I32, ne02);
14881491

1489-
// group rows by src0 matrix
1490-
for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
1491-
for (int32_t id = 0; id < n_ids; ++id) {
1492-
const int32_t i02 =
1493-
*(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
1492+
// hids [n_expert*n_tokens]
1493+
size_t size_hids = ggml_row_size(GGML_TYPE_I32, ne02*ne12);
14941494

1495-
GGML_ASSERT(i02 >= 0 && i02 < n_as);
1495+
// + padding
1496+
size_htmp = GGML_PAD(size_htmp, sizeof(int64_t));
1497+
size_hsrc1 = GGML_PAD(size_hsrc1, sizeof(int64_t));
1498+
size_hdst = GGML_PAD(size_hdst, sizeof(int64_t));
1499+
size_htpe = GGML_PAD(size_htpe, sizeof(int64_t));
1500+
size_hids = GGML_PAD(size_hids, sizeof(int64_t));
14961501

1497-
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
1498-
matrix_row_counts[i02] += 1;
1502+
char * wdata_htmp = (char *) params->wdata;
1503+
char * wdata_hsrc1 = (char *) params->wdata + size_htmp;
1504+
char * wdata_hdst = (char *) params->wdata + size_htmp + size_hsrc1;
1505+
char * wdata_htpe = (char *) params->wdata + size_htmp + size_hsrc1 + size_hdst;
1506+
char * wdata_hids = (char *) params->wdata + size_htmp + size_hsrc1 + size_hdst + size_htpe;
1507+
1508+
const size_t nbht1 = ggml_row_size(GGML_TYPE_F32, ne00);
1509+
const size_t nbht2 = nbht1*ne12;
1510+
1511+
const size_t nbh11 = ggml_row_size(PARAM_TYPE, ne00);
1512+
const size_t nbh12 = nbh11*ne12;
1513+
1514+
const size_t nbh1 = ggml_row_size(GGML_TYPE_F32, ne01);
1515+
const size_t nbh2 = nbh1*ne12;
1516+
1517+
char * htmp = (char *)(wdata_htmp);
1518+
char * hsrc1 = (char *)(wdata_hsrc1);
1519+
char * hdst = (char *)(wdata_hdst);
1520+
int32_t * htpe = (int32_t *)(wdata_htpe);
1521+
int32_t * hids = (int32_t *)(wdata_hids);
1522+
1523+
for (int64_t i02 = ith; i02 < ne02; i02 += nth) {
1524+
htpe[i02] = 0;
1525+
}
1526+
1527+
// src1 (float32) => htmp (float32)
1528+
for (int64_t i12 = 0; i12 < ne12; ++i12) { // n_tokens
1529+
for (int64_t i20 = 0; i20 < ne20; ++i20) { // n_expert_used
1530+
// the selected expert
1531+
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + i12*ids->nb[1] + i20*ids->nb[0]);
1532+
1533+
if (i02 % nth != ith) {
1534+
continue;
14991535
}
1536+
1537+
memcpy( htmp + i02*nbht2 + htpe[i02]*nbht1,
1538+
(char *) src1->data + i12*nb12 + (i20%ne11)*nb11,
1539+
ggml_row_size(GGML_TYPE_F32, ne10));
1540+
1541+
hids[i12*ne20 + i20] = i02*ne12 + htpe[i02];
1542+
htpe[i02]++;
1543+
}
1544+
}
1545+
1546+
// htmp (float32) => hsrc1 (param type)
1547+
for (int64_t i02 = 0; i02 < ne02; ++i02) { // n_expert
1548+
if (i02 % nth != ith) {
1549+
continue;
1550+
}
1551+
1552+
const int64_t neh11 = htpe[i02];
1553+
1554+
for (int64_t i11 = 0; i11 < neh11 - neh11 % 4; i11 += 4) {
1555+
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>(
1556+
(float *) (htmp + i11*nbht1 + i02*nbht2),
1557+
(void *) (hsrc1 + i11*nbh11 + i02*nbh12), 4, ne10);
1558+
}
1559+
1560+
for (int64_t i11 = neh11 - neh11 % 4; i11 < neh11; i11 += 1) {
1561+
from_float(
1562+
(float *) (htmp + i11*nbht1 + i02*nbht2),
1563+
(void *) (hsrc1 + i11*nbh11 + i02*nbh12), ne10);
15001564
}
15011565
}
15021566

15031567
ggml_barrier(params->threadpool);
15041568

1505-
// compute each matrix multiplication in sequence
1506-
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
1507-
const int64_t cne1 = matrix_row_counts[cur_a];
1569+
for (int64_t i02 = 0; i02 < ne02; ++i02) { // n_expert
1570+
const int64_t neh11 = htpe[i02];
15081571

1509-
if (cne1 == 0) {
1572+
if (neh11 == 0) {
15101573
continue;
15111574
}
15121575

1513-
const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
1514-
1515-
//const int64_t nr0 = ne01; // src0 rows
1516-
const int64_t nr1 = cne1; // src1 rows
1576+
const auto * src0_cur = (const char *) src0->data + i02*nb02;
15171577

1518-
int64_t src0_cur_start = (ith * ne01) / nth;
1519-
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
1578+
int64_t src0_cur_start = ((ith )*ne01)/nth;
1579+
int64_t src0_cur_end = ((ith + 1)*ne01)/nth;
15201580

15211581
src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
15221582
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
@@ -1525,26 +1585,49 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
15251585
return;
15261586
}
15271587

1528-
for (int ir1 = 0; ir1 < nr1; ir1++) {
1529-
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
1588+
#if 1
1589+
if (neh11 > 3) {
1590+
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1591+
(float *)(hdst + 0*nbh1 + i02*nbh2) + src0_cur_start, ne01,
1592+
src0_cur + src0_cur_start*nb01,
1593+
hsrc1 + 0*nbh11 + i02*nbh12, neh11 - neh11 % 4, src0_cur_end - src0_cur_start);
1594+
}
1595+
for (int64_t i11 = neh11 - neh11 % 4; i11 < neh11; ++i11) {
1596+
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1597+
(float *)(hdst + i11*nbh1 + i02*nbh2) + src0_cur_start, ne01,
1598+
src0_cur + src0_cur_start*nb01,
1599+
hsrc1 + i11*nbh11 + i02*nbh12, 1, src0_cur_end - src0_cur_start);
1600+
}
1601+
#else
1602+
for (int64_t i11 = 0; i11 < neh11; ++i11) {
1603+
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1604+
(float *)(hdst + i11*nbh1 + i02*nbh2) + src0_cur_start, ne01,
1605+
src0_cur + src0_cur_start * nb01,
1606+
hsrc1 + i11*nbh11 + i02*nbh12, 1, src0_cur_end - src0_cur_start);
1607+
}
1608+
#endif
1609+
}
1610+
1611+
ggml_barrier(params->threadpool);
15301612

1531-
const int id = row_mapping.i1; // selected expert index
1613+
for (int64_t i21 = 0; i21 < ne12; ++i21) { // n_tokens
1614+
for (int64_t i20 = 0; i20 < ne20; ++i20) { // n_expert_used
1615+
const int32_t idx = i21*ne20 + i20;
15321616

1533-
const int64_t i11 = id % ne11;
1534-
const int64_t i12 = row_mapping.i2; // row index in src1
1617+
if (idx % nth != ith) {
1618+
continue;
1619+
}
15351620

1536-
const int64_t i1 = id; // selected expert index
1537-
const int64_t i2 = i12; // row
1621+
const int32_t id = hids[idx];
15381622

1539-
const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
1623+
const int ide = id/ne12;
1624+
const int idt = id%ne12;
15401625

1541-
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1542-
(float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
1543-
src0_cur + src0_cur_start * nb01,
1544-
src1_col, 1, src0_cur_end - src0_cur_start);
1626+
memcpy(
1627+
(char *) dst->data + i20*nb1 + i21*nb2,
1628+
hdst + idt*nbh1 + ide*nbh2, ggml_row_size(GGML_TYPE_F32, ne01));
15451629
}
15461630
}
1547-
#undef MMID_MATRIX_ROW
15481631
}
15491632

15501633
int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {

0 commit comments

Comments
 (0)