
Commit c6c4de2

ggml-cpu : optimize 128-bit rvv ggml_vec_dot_q4_K_q8_K
1 parent c7786e7

File tree: 1 file changed

ggml/src/ggml-cpu/arch/riscv/quants.c (152 additions, 63 deletions)
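For context: ggml_vec_dot_q4_K_q8_K computes the dot product of a q4_K-quantized weight row against a q8_K-quantized activation vector. This commit reworks the 128-bit RVV path: the scalar integer accumulator sumi is replaced by accumulation finished in vector registers (vfcvt.f.x.v plus fmadd.s), the QK_K/128 loop is fully unrolled into a single asm block, and the q4/q8 data is read through four pointer streams so address bumps and scale loads interleave with the vector multiplies. A minimal scalar sketch of the per-super-block math the kernel implements follows; it is an illustration, not the code in quants.c, and it assumes scales/mins were already unpacked from the packed 12-byte scales field (the first asm block in the diff does that unpacking).

#include <stdint.h>

#define QK_K 256

/* Hedged sketch, not the actual ggml code: the per-super-block math that
 * the RVV kernel in this diff computes. scales[8]/mins[8] are assumed to
 * be already unpacked from the block's packed 12-byte scales field. */
static float q4K_q8K_superblock(float d, float dmin,
                                const uint8_t * q4,     /* 128 bytes = 256 4-bit weights */
                                const int8_t  * q8,     /* 256 int8 activations          */
                                const uint8_t * scales, /* 8 sub-block scales            */
                                const uint8_t * mins,   /* 8 sub-block mins              */
                                const int16_t * bsums)  /* 16 sums of 16 q8 values each  */
{
    /* mins correction: each 32-element sub-block contributes
     * min * (sum of its q8 values); bsums holds per-16 sums, so pair them
     * (the vnsrl.wi/vadd.vv sequence in the diff does the same pairing). */
    int sumi_mins = 0;
    for (int k = 0; k < 8; ++k) {
        sumi_mins += mins[k] * (bsums[2*k] + bsums[2*k + 1]);
    }

    /* main dot product: each 32-byte slice of q4 yields one sub-block from
     * its low nibbles and one from its high nibbles (the vand.vi/vsrl.vi
     * pairs in the asm). */
    int sumi = 0;
    for (int j = 0; j < QK_K/64; ++j) {
        int lo = 0, hi = 0;
        for (int l = 0; l < 32; ++l) {
            lo += q8[l]      * (q4[l] & 0xF);
            hi += q8[l + 32] * (q4[l] >> 4);
        }
        sumi += lo * scales[2*j] + hi * scales[2*j + 1];
        q4 += 32;
        q8 += 64;
    }

    /* the kernel folds this into sumf with fmadd.s / an fsub via dmin */
    return d * sumi - dmin * sumi_mins;
}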
@@ -1270,15 +1270,18 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
         const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
 
-        int tmp, tmp2, sumi;
+        int tmp, tmp2;
+        float ftmp, ft2;
+
         __asm__ __volatile__(
-            "vsetivli zero, 12, e8, m1\n\t"
-            "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
-            "vsetivli zero, 4, e32, m1\n\t"
+            "vsetivli zero, 4, e32, m1, ta, ma\n\t"
+            "vle32.v v1, (%[s6b])\n\t"
+            "vslide1down.vx v1, v1, zero\n\t"
+            "vmv.v.x v16, zero\n\t"
             "vslidedown.vi v2, v1, 2\n\t"
             "vmv1r.v v3, v2\n\t"
             "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
-            "vsetivli zero, 2, e32, m1\n\t"
+            "vsetivli zero, 2, e32, m1, ta, ma\n\t"
             "vmv.v.i v4, 4\n\t"
             "vand.vx v8, v1, %[kmask1]\n\t"
             "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
@@ -1292,81 +1295,167 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
             "vor.vv v1, v6, v2\n\t"
             "vsse32.v v8, (%[utmp]), %[t2]\n\t"
             "vsse32.v v1, (%[t1]), %[t2]\n\t"
-            "vsetivli zero, 8, e16, m1\n\t"
+            "vsetivli zero, 8, e16, m1, ta, ma\n\t"
             "vle32.v v2, (%[bsums])\n\t"
             "vnsrl.wi v0, v2, 0\n\t"
             "vnsrl.wi v1, v2, 16\n\t"
             "vadd.vv v2, v0, v1\n\t"
             "vle8.v v3, (%[mins])\n\t"
             "vzext.vf2 v4, v3\n\t"
             "vwmul.vv v6, v4, v2\n\t"
-            "vmv.v.x v0, zero\n\t"
-            "vsetivli zero, 8, e32, m2\n\t"
-            "vredsum.vs v0, v6, v0\n\t"
-            "vmv.x.s %[sumi], v0"
-            : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
+            "vsetivli zero, 4, e32, m1, ta, ma\n\t"
+            "vredsum.vs v0, v6, v16\n\t"
+            "vredsum.vs v0, v7, v0\n\t"
+            "vfcvt.f.x.v v0, v0\n\t"
+            "vfmv.f.s %[ftmp], v0"
+            : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [ftmp] "=&f" (ftmp)
             : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
-            , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
+            , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1)
             , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
             : "memory"
             , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
             , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
             , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
             , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
         );
-        sumf -= dmin * sumi;
-
-        const uint8_t * restrict q4 = x[i].qs;
-        const int8_t * restrict q8 = y[i].qs;
-
-        sumi = 0;
+        sumf -= dmin * ftmp;
+
+        const uint8_t * restrict q40 = x[i].qs + 0;
+        const uint8_t * restrict q41 = x[i].qs + 16;
+        const uint8_t * restrict q42 = x[i].qs + 32;
+        const uint8_t * restrict q43 = x[i].qs + 48;
+        const int8_t * restrict q80;
+        const int8_t * restrict q81;
+        const int8_t * restrict q82;
+        const int8_t * restrict q83;
+
+        ftmp = 0;
         const uint8_t * scale = scales;
 
-        for (int j = 0; j < QK_K/128; ++j) {
-            int vl128 = 128, vl64 = 64, vl32 = 32;
-            __asm__ __volatile__(
-                "vsetvli zero, %[vl128], e8, m8\n\t"
-                "vle8.v v8, (%[q8])\n\t"
-                "vsetvli zero, %[vl64], e8, m4\n\t"
-                "vle8.v v0, (%[q4])\n\t"
-                "vsrl.vi v4, v0, 4\n\t"
-                "vand.vi v0, v0, 0xF\n\t"
-                "vsetvli zero, %[vl32], e8, m2\n\t"
-                "vwmul.vv v28, v6, v14\n\t"
-                "vwmul.vv v20, v4, v10\n\t"
-                "vwmul.vv v24, v2, v12\n\t"
-                "vwmul.vv v16, v0, v8\n\t"
-                "vsetivli zero, 4, e32, m1\n\t"
-                "vle8.v v2, (%[scale])\n\t"
-                "vmv.v.x v0, zero\n\t"
-                "vzext.vf4 v1, v2\n\t"
-                "vsetvli zero, %[vl32], e16, m4\n\t"
-                "vwredsum.vs v6, v24, v0\n\t"
-                "vwredsum.vs v7, v28, v0\n\t"
-                "vwredsum.vs v4, v16, v0\n\t"
-                "vwredsum.vs v5, v20, v0\n\t"
-                "vsetivli zero, 4, e32, m1\n\t"
-                "vslideup.vi v6, v7, 1\n\t"
-                "vslideup.vi v4, v5, 1\n\t"
-                "vslideup.vi v4, v6, 2\n\t"
-                "vmul.vv v8, v4, v1\n\t"
-                "vredsum.vs v0, v8, v0\n\t"
-                "vmv.x.s %[tmp], v0\n\t"
-                "add %[sumi], %[sumi], %[tmp]"
-                : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
-                : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
-                , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
-                : "memory"
-                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
-                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
-                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
-                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
-            );
-
-            q4 += 64; q8 += 128; scale += 4;
-        }
-
-        sumf += d * sumi;
+        int s0, s1, s2, s3;
+        __asm__ __volatile__(
+            "vsetivli zero, 16, e8, m1, ta, ma\n\t"
+            "vle8.v v0, (%[q40])\n\t"
+            "addi %[q80], %[ys], 0\n\t"
+            "addi %[q40], %[q40], 64\n\t"
+            "vle8.v v1, (%[q41])\n\t"
+            "addi %[q81], %[ys], 16\n\t"
+            "addi %[q41], %[q41], 64\n\t"
+            "vle8.v v2, (%[q42])\n\t"
+            "addi %[q82], %[ys], 32\n\t"
+            "addi %[q42], %[q42], 64\n\t"
+            "vle8.v v3, (%[q43])\n\t"
+            "addi %[q83], %[ys], 48\n\t"
+            "addi %[q43], %[q43], 64\n\t"
+            "vle8.v v8, (%[q80])\n\t"
+            "vsrl.vi v4, v0, 4\n\t"
+            "addi %[q80], %[q80], 64\n\t"
+            "vle8.v v9, (%[q81])\n\t"
+            "vand.vi v0, v0, 0xF\n\t"
+            "addi %[q81], %[q81], 64\n\t"
+            "vle8.v v10, (%[q82])\n\t"
+            "vsrl.vi v5, v1, 4\n\t"
+            "addi %[q82], %[q82], 64\n\t"
+            "vle8.v v11, (%[q83])\n\t"
+            "vand.vi v1, v1, 0xF\n\t"
+            "addi %[q83], %[q83], 64\n\t"
+            "vle8.v v12, (%[q80])\n\t"
+            "vsrl.vi v6, v2, 4\n\t"
+            "addi %[q80], %[q80], 64\n\t"
+            "vle8.v v13, (%[q81])\n\t"
+            "vand.vi v2, v2, 0xF\n\t"
+            "addi %[q81], %[q81], 64\n\t"
+            "vle8.v v14, (%[q82])\n\t"
+            "vsrl.vi v7, v3, 4\n\t"
+            "addi %[q82], %[q82], 64\n\t"
+            "vle8.v v15, (%[q83])\n\t"
+            "vand.vi v3, v3, 0xF\n\t"
+            "addi %[q83], %[q83], 64\n\t"
+            "vwmul.vv v16, v0, v8\n\t"
+            "vwmul.vv v24, v2, v12\n\t"
+            "vwmul.vv v20, v4, v10\n\t"
+            "vwmul.vv v28, v6, v14\n\t"
+            "vwmacc.vv v16, v1, v9\n\t"
+            "vwmacc.vv v24, v3, v13\n\t"
+            "vwmacc.vv v20, v5, v11\n\t"
+            "vwmacc.vv v28, v7, v15\n\t"
+            "vle8.v v0, (%[q40])\n\t"
+            "addi %[q40], %[q80], 64\n\t"
+            "vle8.v v1, (%[q41])\n\t"
+            "addi %[q41], %[q81], 64\n\t"
+            "vle8.v v2, (%[q42])\n\t"
+            "addi %[q42], %[q82], 64\n\t"
+            "vle8.v v3, (%[q43])\n\t"
+            "addi %[q43], %[q83], 64\n\t"
+            "vle8.v v8, (%[q80])\n\t"
+            "vsrl.vi v4, v0, 4\n\t"
+            "vle8.v v9, (%[q81])\n\t"
+            "vand.vi v0, v0, 0xF\n\t"
+            "vle8.v v10, (%[q82])\n\t"
+            "vsrl.vi v5, v1, 4\n\t"
+            "vle8.v v11, (%[q83])\n\t"
+            "vand.vi v1, v1, 0xF\n\t"
+            "vle8.v v12, (%[q40])\n\t"
+            "vsrl.vi v6, v2, 4\n\t"
+            "vle8.v v13, (%[q41])\n\t"
+            "vand.vi v2, v2, 0xF\n\t"
+            "vle8.v v14, (%[q42])\n\t"
+            "vsrl.vi v7, v3, 4\n\t"
+            "vle8.v v15, (%[q43])\n\t"
+            "vand.vi v3, v3, 0xF\n\t"
+            "vwmul.vv v18, v0, v8\n\t"
+            "vwmul.vv v26, v2, v12\n\t"
+            "vwmul.vv v22, v4, v10\n\t"
+            "vwmul.vv v30, v6, v14\n\t"
+            "vwmacc.vv v18, v1, v9\n\t"
+            "vwmacc.vv v26, v3, v13\n\t"
+            "vwmacc.vv v22, v5, v11\n\t"
+            "vwmacc.vv v30, v7, v15\n\t"
+            "vmv.v.x v0, zero\n\t"
+            "vsetivli zero, 16, e16, m2, ta, ma\n\t"
+            "vwredsum.vs v4, v16, v0\n\t"
+            "lbu %[s0], 0(%[scale])\n\t"
+            "vwredsum.vs v5, v20, v0\n\t"
+            "lbu %[s1], 1(%[scale])\n\t"
+            "vwredsum.vs v6, v24, v0\n\t"
+            "lbu %[s2], 2(%[scale])\n\t"
+            "vwredsum.vs v7, v28, v0\n\t"
+            "lbu %[s3], 3(%[scale])\n\t"
+            "vwredsum.vs v8, v18, v0\n\t"
+            "lbu %[q40], 4(%[scale])\n\t"
+            "vwredsum.vs v9, v22, v0\n\t"
+            "lbu %[q41], 5(%[scale])\n\t"
+            "vwredsum.vs v10, v26, v0\n\t"
+            "lbu %[q42], 6(%[scale])\n\t"
+            "vwredsum.vs v11, v30, v0\n\t"
+            "lbu %[q43], 7(%[scale])\n\t"
+            "vsetivli zero, 4, e32, m1, ta, ma\n\t"
+            "vmul.vx v0, v4, %[s0]\n\t"
+            "vmul.vx v1, v8, %[q40]\n\t"
+            "vmacc.vx v0, %[s1], v5\n\t"
+            "vmacc.vx v1, %[q41], v9\n\t"
+            "vmacc.vx v0, %[s2], v6\n\t"
+            "vmacc.vx v1, %[q42], v10\n\t"
+            "vmacc.vx v0, %[s3], v7\n\t"
+            "vmacc.vx v1, %[q43], v11\n\t"
+            "vfcvt.f.x.v v0, v0\n\t"
+            "vfcvt.f.x.v v1, v1\n\t"
+            "vfmv.f.s %[ft2], v0\n\t"
+            "vfmv.f.s %[ftmp], v1\n\t"
+            "fadd.s %[ft2], %[ft2], %[ftmp]\n\t"
+            "fmadd.s %[sumf], %[d], %[ft2], %[sumf]"
+            : [tmp] "=&r" (tmp), [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2)
+            , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3)
+            , [q40] "+&r" (q40), [q41] "+&r" (q41), [q42] "+&r" (q42), [q43] "+&r" (q43)
+            , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83)
+            , [scale] "+&r" (scale)
+            : [d] "f" (d), [ys] "r" (y[i].qs)
+            : "memory"
+            , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+            , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+            , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+            , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+        );
     }
     break;
 default:
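A note on the first asm block in this hunk: it vectorizes the q4_K scale/min unpacking with the kmask constants and two strided vsse32.v stores into utmp, and the [s6b] operand now points at &x[i] rather than x[i].scales because the new code loads the whole block header with vle32.v and uses vslide1down.vx to skip past the d/dmin halfwords. For reference, the generic scalar path does the same unpacking with a helper along these lines (reproduced from memory from ggml-quants.c; treat the exact bit layout as an assumption, not a verbatim copy):

#include <stdint.h>

/* Sketch of get_scale_min_k4: q points at the packed 12-byte scales field.
 * 6-bit scales/mins for sub-blocks 0..3 sit in the low bits of bytes 0..7;
 * sub-blocks 4..7 are split between the nibbles of bytes 8..11 and the top
 * two bits of bytes 0..7. */
static inline void get_scale_min_k4(int j, const uint8_t * q,
                                    uint8_t * d, uint8_t * m) {
    if (j < 4) {
        *d = q[j] & 63;
        *m = q[j + 4] & 63;
    } else {
        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *m = (q[j + 4] >>  4) | ((q[j    ] >> 6) << 4);
    }
}

The rest of the new block is the old QK_K/128 loop (two iterations) fully unrolled into one asm statement, with scalar addi/lbu instructions slotted between vector loads and multiplies so address generation and scale fetches overlap vector work, and with the final integer sums converted in-register (vfcvt.f.x.v) and folded into sumf by fmadd.s instead of round-tripping through a scalar sumi.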
