Skip to content

Commit 624b291

Browse files
committed
ggml-cpu : optimize rvv ggml_vec_dot_q4_K_q8_K
1 parent 20d2017 commit 624b291

File tree

1 file changed

+46
-58
lines changed

1 file changed

+46
-58
lines changed

ggml/src/ggml-cpu/arch/riscv/quants.c

Lines changed: 46 additions & 58 deletions
Original file line number · Diff line number · Diff line change
@@ -1272,8 +1272,18 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
12721272

12731273
int tmp, tmp2;
12741274
float ftmp, ft2;
1275+
const uint8_t * restrict q40;
1276+
const uint8_t * restrict q41;
1277+
const uint8_t * restrict q42;
1278+
const uint8_t * restrict q43;
1279+
const int8_t * restrict q80;
1280+
const int8_t * restrict q81;
1281+
const int8_t * restrict q82;
1282+
const int8_t * restrict q83;
1283+
int s0, s1, s2, s3;
12751284

12761285
__asm__ __volatile__(
1286+
"li %[s1], 8\n\t"
12771287
"vsetivli zero, 4, e32, m1, ta, ma\n\t"
12781288
"vle32.v v1, (%[s6b])\n\t"
12791289
"vslide1down.vx v1, v1, zero\n\t"
@@ -1287,14 +1297,13 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
12871297
"vslide1up.vx v5, v4, zero\n\t" // {0, 4}
12881298
"vsrl.vi v6, v1, 6\n\t"
12891299
"vsrl.vv v7, v2, v5\n\t"
1300+
"vsse32.v v8, (%[utmp]), %[s1]\n\t"
12901301
"vand.vx v0, v6, %[kmask3]\n\t"
12911302
"vand.vx v2, v7, %[kmask2]\n\t"
12921303
"vsll.vi v6, v0, 4\n\t"
1293-
"li %[t2], 8\n\t"
1294-
"addi %[t1], %[utmp], 4\n\t"
1304+
"addi %[s0], %[utmp], 4\n\t"
12951305
"vor.vv v1, v6, v2\n\t"
1296-
"vsse32.v v8, (%[utmp]), %[t2]\n\t"
1297-
"vsse32.v v1, (%[t1]), %[t2]\n\t"
1306+
"vsse32.v v1, (%[s0]), %[s1]\n\t"
12981307
"vsetivli zero, 8, e16, m1, ta, ma\n\t"
12991308
"vle32.v v2, (%[bsums])\n\t"
13001309
"vnsrl.wi v0, v2, 0\n\t"
@@ -1307,107 +1316,84 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
13071316
"vredsum.vs v0, v6, v16\n\t"
13081317
"vredsum.vs v0, v7, v0\n\t"
13091318
"vfcvt.f.x.v v0, v0\n\t"
1310-
"vfmv.f.s %[ftmp], v0"
1311-
: [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [ftmp] "=&f" (ftmp)
1312-
: [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
1313-
, [s6b] "r" (&x[i]), [kmask1] "r" (kmask1)
1314-
, [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
1315-
: "memory"
1316-
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1317-
, "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1318-
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1319-
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1320-
);
1321-
sumf -= dmin * ftmp;
1322-
1323-
const uint8_t * restrict q40 = x[i].qs + 0;
1324-
const uint8_t * restrict q41 = x[i].qs + 16;
1325-
const uint8_t * restrict q42 = x[i].qs + 32;
1326-
const uint8_t * restrict q43 = x[i].qs + 48;
1327-
const int8_t * restrict q80;
1328-
const int8_t * restrict q81;
1329-
const int8_t * restrict q82;
1330-
const int8_t * restrict q83;
1331-
1332-
ftmp = 0;
1333-
const uint8_t * scale = scales;
1334-
1335-
int s0, s1, s2, s3;
1336-
__asm__ __volatile__(
1319+
"vfmv.f.s %[ftmp], v0\n\t"
13371320
"vsetivli zero, 16, e8, m1, ta, ma\n\t"
1338-
"vle8.v v0, (%[q40])\n\t"
1339-
"addi %[q80], %[ys], 0\n\t"
1340-
"addi %[q40], %[q40], 64\n\t"
1321+
"vle8.v v0, (%[xs])\n\t"
1322+
"fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t"
1323+
"addi %[q40], %[xs], 64\n\t"
1324+
"addi %[q41], %[xs], 16\n\t"
1325+
"addi %[q42], %[xs], 32\n\t"
1326+
"addi %[q43], %[xs], 48\n\t"
1327+
"addi %[q80], %[ys], 64\n\t"
13411328
"vle8.v v1, (%[q41])\n\t"
1329+
"vle8.v v2, (%[q42])\n\t"
13421330
"addi %[q81], %[ys], 16\n\t"
13431331
"addi %[q41], %[q41], 64\n\t"
1344-
"vle8.v v2, (%[q42])\n\t"
13451332
"addi %[q82], %[ys], 32\n\t"
1346-
"addi %[q42], %[q42], 64\n\t"
13471333
"vle8.v v3, (%[q43])\n\t"
1334+
"vle8.v v8, (%[ys])\n\t"
1335+
"addi %[q42], %[q42], 64\n\t"
13481336
"addi %[q83], %[ys], 48\n\t"
13491337
"addi %[q43], %[q43], 64\n\t"
1350-
"vle8.v v8, (%[q80])\n\t"
13511338
"vsrl.vi v4, v0, 4\n\t"
1352-
"addi %[q80], %[q80], 64\n\t"
13531339
"vle8.v v9, (%[q81])\n\t"
1340+
"vle8.v v10, (%[q82])\n\t"
13541341
"vand.vi v0, v0, 0xF\n\t"
13551342
"addi %[q81], %[q81], 64\n\t"
1356-
"vle8.v v10, (%[q82])\n\t"
13571343
"vsrl.vi v5, v1, 4\n\t"
13581344
"addi %[q82], %[q82], 64\n\t"
13591345
"vle8.v v11, (%[q83])\n\t"
1346+
"vle8.v v12, (%[q80])\n\t"
13601347
"vand.vi v1, v1, 0xF\n\t"
13611348
"addi %[q83], %[q83], 64\n\t"
1362-
"vle8.v v12, (%[q80])\n\t"
13631349
"vsrl.vi v6, v2, 4\n\t"
13641350
"addi %[q80], %[q80], 64\n\t"
13651351
"vle8.v v13, (%[q81])\n\t"
1352+
"vle8.v v14, (%[q82])\n\t"
13661353
"vand.vi v2, v2, 0xF\n\t"
13671354
"addi %[q81], %[q81], 64\n\t"
1368-
"vle8.v v14, (%[q82])\n\t"
13691355
"vsrl.vi v7, v3, 4\n\t"
13701356
"addi %[q82], %[q82], 64\n\t"
1357+
"vwmul.vv v16, v0, v8\n\t"
13711358
"vle8.v v15, (%[q83])\n\t"
1359+
"vle8.v v0, (%[q40])\n\t"
13721360
"vand.vi v3, v3, 0xF\n\t"
13731361
"addi %[q83], %[q83], 64\n\t"
1374-
"vwmul.vv v16, v0, v8\n\t"
13751362
"vwmul.vv v24, v2, v12\n\t"
13761363
"vwmul.vv v20, v4, v10\n\t"
13771364
"vwmul.vv v28, v6, v14\n\t"
13781365
"vwmacc.vv v16, v1, v9\n\t"
1366+
"vle8.v v1, (%[q41])\n\t"
1367+
"vle8.v v2, (%[q42])\n\t"
13791368
"vwmacc.vv v24, v3, v13\n\t"
13801369
"vwmacc.vv v20, v5, v11\n\t"
13811370
"vwmacc.vv v28, v7, v15\n\t"
1382-
"vle8.v v0, (%[q40])\n\t"
13831371
"addi %[q40], %[q80], 64\n\t"
1384-
"vle8.v v1, (%[q41])\n\t"
13851372
"addi %[q41], %[q81], 64\n\t"
1386-
"vle8.v v2, (%[q42])\n\t"
1387-
"addi %[q42], %[q82], 64\n\t"
13881373
"vle8.v v3, (%[q43])\n\t"
1389-
"addi %[q43], %[q83], 64\n\t"
13901374
"vle8.v v8, (%[q80])\n\t"
1375+
"addi %[q42], %[q82], 64\n\t"
1376+
"addi %[q43], %[q83], 64\n\t"
13911377
"vsrl.vi v4, v0, 4\n\t"
13921378
"vle8.v v9, (%[q81])\n\t"
1393-
"vand.vi v0, v0, 0xF\n\t"
13941379
"vle8.v v10, (%[q82])\n\t"
1380+
"vand.vi v0, v0, 0xF\n\t"
13951381
"vsrl.vi v5, v1, 4\n\t"
1382+
"vsrl.vi v7, v3, 4\n\t"
1383+
"vand.vi v3, v3, 0xF\n\t"
13961384
"vle8.v v11, (%[q83])\n\t"
1397-
"vand.vi v1, v1, 0xF\n\t"
13981385
"vle8.v v12, (%[q40])\n\t"
1386+
"vand.vi v1, v1, 0xF\n\t"
13991387
"vsrl.vi v6, v2, 4\n\t"
1400-
"vle8.v v13, (%[q41])\n\t"
14011388
"vand.vi v2, v2, 0xF\n\t"
1402-
"vle8.v v14, (%[q42])\n\t"
1403-
"vsrl.vi v7, v3, 4\n\t"
1404-
"vle8.v v15, (%[q43])\n\t"
1405-
"vand.vi v3, v3, 0xF\n\t"
14061389
"vwmul.vv v18, v0, v8\n\t"
1390+
"vle8.v v13, (%[q41])\n\t"
1391+
"vle8.v v14, (%[q42])\n\t"
14071392
"vwmul.vv v26, v2, v12\n\t"
14081393
"vwmul.vv v22, v4, v10\n\t"
14091394
"vwmul.vv v30, v6, v14\n\t"
14101395
"vwmacc.vv v18, v1, v9\n\t"
1396+
"vle8.v v15, (%[q43])\n\t"
14111397
"vwmacc.vv v26, v3, v13\n\t"
14121398
"vwmacc.vv v22, v5, v11\n\t"
14131399
"vwmacc.vv v30, v7, v15\n\t"
@@ -1444,12 +1430,14 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
14441430
"vfmv.f.s %[ftmp], v1\n\t"
14451431
"fadd.s %[ft2], %[ft2], %[ftmp]\n\t"
14461432
"fmadd.s %[sumf], %[d], %[ft2], %[sumf]"
1447-
: [tmp] "=&r" (tmp), [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2)
1433+
: [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2)
14481434
, [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3)
1449-
, [q40] "+&r" (q40), [q41] "+&r" (q41), [q42] "+&r" (q42), [q43] "+&r" (q43)
1435+
, [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43)
14501436
, [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83)
1451-
, [scale] "+&r" (scale)
1452-
: [d] "f" (d), [ys] "r" (y[i].qs)
1437+
: [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales)
1438+
, [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
1439+
, [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin)
1440+
, [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
14531441
: "memory"
14541442
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
14551443
, "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"

0 commit comments

Comments (0)