@@ -1270,179 +1270,103 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
             const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
             const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
 
-            float ftmp, ft2;
-            const uint8_t * restrict q40;
-            const uint8_t * restrict q41;
-            const uint8_t * restrict q42;
-            const uint8_t * restrict q43;
-            const int8_t * restrict q80;
-            const int8_t * restrict q81;
-            const int8_t * restrict q82;
-            const int8_t * restrict q83;
-            int s0, s1, s2, s3;
-
+            int tmp, tmp2, sumi;
             __asm__ __volatile__(
-                "li %[s1], 8\n\t"
-                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
-                "vle32.v v1, (%[s6b])\n\t"
-                "vslide1down.vx v1, v1, zero\n\t"
-                "vmv.v.x v16, zero\n\t"
+                "vsetivli zero, 12, e8, m1\n\t"
+                "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
+                "vsetivli zero, 4, e32, m1\n\t"
                 "vslidedown.vi v2, v1, 2\n\t"
                 "vmv1r.v v3, v2\n\t"
                 "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
-                "vsetivli zero, 2, e32, m1, ta, ma\n\t"
+                "vsetivli zero, 2, e32, m1\n\t"
                 "vmv.v.i v4, 4\n\t"
                 "vand.vx v8, v1, %[kmask1]\n\t"
                 "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
                 "vsrl.vi v6, v1, 6\n\t"
                 "vsrl.vv v7, v2, v5\n\t"
-                "vsse32.v v8, (%[utmp]), %[s1]\n\t"
                 "vand.vx v0, v6, %[kmask3]\n\t"
                 "vand.vx v2, v7, %[kmask2]\n\t"
                 "vsll.vi v6, v0, 4\n\t"
-                "addi %[s0], %[utmp], 4\n\t"
+                "li %[t2], 8\n\t"
+                "addi %[t1], %[utmp], 4\n\t"
                 "vor.vv v1, v6, v2\n\t"
-                "vsse32.v v1, (%[s0]), %[s1]\n\t"
-                "vsetivli zero, 8, e16, m1, ta, ma\n\t"
+                "vsse32.v v8, (%[utmp]), %[t2]\n\t"
+                "vsse32.v v1, (%[t1]), %[t2]\n\t"
+                "vsetivli zero, 8, e16, m1\n\t"
                 "vle32.v v2, (%[bsums])\n\t"
                 "vnsrl.wi v0, v2, 0\n\t"
                 "vnsrl.wi v1, v2, 16\n\t"
                 "vadd.vv v2, v0, v1\n\t"
                 "vle8.v v3, (%[mins])\n\t"
                 "vzext.vf2 v4, v3\n\t"
                 "vwmul.vv v6, v4, v2\n\t"
-                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
-                "vredsum.vs v0, v6, v16\n\t"
-                "vredsum.vs v0, v7, v0\n\t"
-                "vfcvt.f.x.v v0, v0\n\t"
-                "vfmv.f.s %[ftmp], v0\n\t"
-                "vsetivli zero, 16, e8, m1, ta, ma\n\t"
-                "vle8.v v0, (%[xs])\n\t"
-                "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t"
-                "addi %[q40], %[xs], 64\n\t"
-                "addi %[q41], %[xs], 16\n\t"
-                "addi %[q42], %[xs], 32\n\t"
-                "addi %[q43], %[xs], 48\n\t"
-                "addi %[q80], %[ys], 64\n\t"
-                "vle8.v v1, (%[q41])\n\t"
-                "vle8.v v2, (%[q42])\n\t"
-                "addi %[q81], %[ys], 16\n\t"
-                "addi %[q41], %[q41], 64\n\t"
-                "addi %[q82], %[ys], 32\n\t"
-                "vle8.v v3, (%[q43])\n\t"
-                "vle8.v v8, (%[ys])\n\t"
-                "addi %[q42], %[q42], 64\n\t"
-                "addi %[q83], %[ys], 48\n\t"
-                "addi %[q43], %[q43], 64\n\t"
-                "vsrl.vi v4, v0, 4\n\t"
-                "vle8.v v9, (%[q81])\n\t"
-                "vle8.v v10, (%[q82])\n\t"
-                "vand.vi v0, v0, 0xF\n\t"
-                "addi %[q81], %[q81], 64\n\t"
-                "vsrl.vi v5, v1, 4\n\t"
-                "addi %[q82], %[q82], 64\n\t"
-                "vle8.v v11, (%[q83])\n\t"
-                "vle8.v v12, (%[q80])\n\t"
-                "vand.vi v1, v1, 0xF\n\t"
-                "addi %[q83], %[q83], 64\n\t"
-                "vsrl.vi v6, v2, 4\n\t"
-                "addi %[q80], %[q80], 64\n\t"
-                "vle8.v v13, (%[q81])\n\t"
-                "vle8.v v14, (%[q82])\n\t"
-                "vand.vi v2, v2, 0xF\n\t"
-                "addi %[q81], %[q81], 64\n\t"
-                "vsrl.vi v7, v3, 4\n\t"
-                "addi %[q82], %[q82], 64\n\t"
-                "vwmul.vv v16, v0, v8\n\t"
-                "vle8.v v15, (%[q83])\n\t"
-                "vle8.v v0, (%[q40])\n\t"
-                "vand.vi v3, v3, 0xF\n\t"
-                "addi %[q83], %[q83], 64\n\t"
-                "vwmul.vv v24, v2, v12\n\t"
-                "vwmul.vv v20, v4, v10\n\t"
-                "vwmul.vv v28, v6, v14\n\t"
-                "vwmacc.vv v16, v1, v9\n\t"
-                "vle8.v v1, (%[q41])\n\t"
-                "vle8.v v2, (%[q42])\n\t"
-                "vwmacc.vv v24, v3, v13\n\t"
-                "vwmacc.vv v20, v5, v11\n\t"
-                "vwmacc.vv v28, v7, v15\n\t"
-                "addi %[q40], %[q80], 64\n\t"
-                "addi %[q41], %[q81], 64\n\t"
-                "vle8.v v3, (%[q43])\n\t"
-                "vle8.v v8, (%[q80])\n\t"
-                "addi %[q42], %[q82], 64\n\t"
-                "addi %[q43], %[q83], 64\n\t"
-                "vsrl.vi v4, v0, 4\n\t"
-                "vle8.v v9, (%[q81])\n\t"
-                "vle8.v v10, (%[q82])\n\t"
-                "vand.vi v0, v0, 0xF\n\t"
-                "vsrl.vi v5, v1, 4\n\t"
-                "vsrl.vi v7, v3, 4\n\t"
-                "vand.vi v3, v3, 0xF\n\t"
-                "vle8.v v11, (%[q83])\n\t"
-                "vle8.v v12, (%[q40])\n\t"
-                "vand.vi v1, v1, 0xF\n\t"
-                "vsrl.vi v6, v2, 4\n\t"
-                "vand.vi v2, v2, 0xF\n\t"
-                "vwmul.vv v18, v0, v8\n\t"
-                "vle8.v v13, (%[q41])\n\t"
-                "vle8.v v14, (%[q42])\n\t"
-                "vwmul.vv v26, v2, v12\n\t"
-                "vwmul.vv v22, v4, v10\n\t"
-                "vwmul.vv v30, v6, v14\n\t"
-                "vwmacc.vv v18, v1, v9\n\t"
-                "vle8.v v15, (%[q43])\n\t"
-                "vwmacc.vv v26, v3, v13\n\t"
-                "vwmacc.vv v22, v5, v11\n\t"
-                "vwmacc.vv v30, v7, v15\n\t"
                 "vmv.v.x v0, zero\n\t"
-                "vsetivli zero, 16, e16, m2, ta, ma\n\t"
-                "vwredsum.vs v4, v16, v0\n\t"
-                "lbu %[s0], 0(%[scale])\n\t"
-                "vwredsum.vs v5, v20, v0\n\t"
-                "lbu %[s1], 1(%[scale])\n\t"
-                "vwredsum.vs v6, v24, v0\n\t"
-                "lbu %[s2], 2(%[scale])\n\t"
-                "vwredsum.vs v7, v28, v0\n\t"
-                "lbu %[s3], 3(%[scale])\n\t"
-                "vwredsum.vs v8, v18, v0\n\t"
-                "lbu %[q40], 4(%[scale])\n\t"
-                "vwredsum.vs v9, v22, v0\n\t"
-                "lbu %[q41], 5(%[scale])\n\t"
-                "vwredsum.vs v10, v26, v0\n\t"
-                "lbu %[q42], 6(%[scale])\n\t"
-                "vwredsum.vs v11, v30, v0\n\t"
-                "lbu %[q43], 7(%[scale])\n\t"
-                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
-                "vmul.vx v0, v4, %[s0]\n\t"
-                "vmul.vx v1, v8, %[q40]\n\t"
-                "vmacc.vx v0, %[s1], v5\n\t"
-                "vmacc.vx v1, %[q41], v9\n\t"
-                "vmacc.vx v0, %[s2], v6\n\t"
-                "vmacc.vx v1, %[q42], v10\n\t"
-                "vmacc.vx v0, %[s3], v7\n\t"
-                "vmacc.vx v1, %[q43], v11\n\t"
-                "vfcvt.f.x.v v0, v0\n\t"
-                "vfcvt.f.x.v v1, v1\n\t"
-                "vfmv.f.s %[ft2], v0\n\t"
-                "vfmv.f.s %[ftmp], v1\n\t"
-                "fadd.s %[ft2], %[ft2], %[ftmp]\n\t"
-                "fmadd.s %[sumf], %[d], %[ft2], %[sumf]"
-                : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2)
-                , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3)
-                , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43)
-                , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83)
-                : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales)
-                , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
-                , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin)
+                "vsetivli zero, 8, e32, m2\n\t"
+                "vredsum.vs v0, v6, v0\n\t"
+                "vmv.x.s %[sumi], v0"
+                : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
+                : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
+                , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
                 , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
                 : "memory"
                 , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                 , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                 , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                 , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
             );
+            sumf -= dmin * sumi;
+
+            const uint8_t * restrict q4 = x[i].qs;
+            const int8_t * restrict q8 = y[i].qs;
+
+            sumi = 0;
+            const uint8_t * scale = scales;
+
+            for (int j = 0; j < QK_K/128; ++j) {
+                int vl128 = 128, vl64 = 64, vl32 = 32;
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vle8.v v8, (%[q8])\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vle8.v v0, (%[q4])\n\t"
+                    "vsrl.vi v4, v0, 4\n\t"
+                    "vand.vi v0, v0, 0xF\n\t"
+                    "vsetvli zero, %[vl32], e8, m2\n\t"
+                    "vwmul.vv v28, v6, v14\n\t"
+                    "vwmul.vv v20, v4, v10\n\t"
+                    "vwmul.vv v24, v2, v12\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vle8.v v2, (%[scale])\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vzext.vf4 v1, v2\n\t"
+                    "vsetvli zero, %[vl32], e16, m4\n\t"
+                    "vwredsum.vs v6, v24, v0\n\t"
+                    "vwredsum.vs v7, v28, v0\n\t"
+                    "vwredsum.vs v4, v16, v0\n\t"
+                    "vwredsum.vs v5, v20, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v6, v7, 1\n\t"
+                    "vslideup.vi v4, v5, 1\n\t"
+                    "vslideup.vi v4, v6, 2\n\t"
+                    "vmul.vv v8, v4, v1\n\t"
+                    "vredsum.vs v0, v8, v0\n\t"
+                    "vmv.x.s %[tmp], v0\n\t"
+                    "add %[sumi], %[sumi], %[tmp]"
+                    : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
+                    : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
+                    , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );
+
+                q4 += 64; q8 += 128; scale += 4;
+            }
+
+            sumf += d * sumi;
         }
         break;
     default:
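
For reference, the rewritten q4_K path accumulates the same per-superblock value as the scalar sketch below (a sketch only, not part of the patch; it assumes the standard Q4_K/Q8_K layouts with QK_K == 256, and `d`, `dmin`, `sumf`, plus the unpacked `scales`/`mins` pointers into `utmp`, come from the surrounding function):

    // mins term: bsums[k] is the sum of 16 consecutive q8 values, so each
    // 32-wide sub-block j pairs mins[j] with bsums[2j] + bsums[2j+1]
    int sumi_mins = 0;
    for (int j = 0; j < 8; ++j) {
        sumi_mins += mins[j] * (y[i].bsums[2*j + 0] + y[i].bsums[2*j + 1]);
    }
    sumf -= dmin * sumi_mins;

    // quant term: low nibbles of each 32-byte group belong to sub-block 2j,
    // high nibbles to sub-block 2j+1, each weighted by its unpacked scale
    const uint8_t * q4 = x[i].qs;
    const int8_t  * q8 = y[i].qs;
    int sumi = 0;
    for (int j = 0; j < QK_K/64; ++j) {
        int s0 = 0, s1 = 0;
        for (int l = 0; l < 32; ++l) {
            s0 += (q4[l] & 0xF) * q8[l];
            s1 += (q4[l] >>  4) * q8[l + 32];
        }
        sumi += scales[2*j + 0] * s0 + scales[2*j + 1] * s1;
        q4 += 32; q8 += 64;
    }
    sumf += d * sumi;

The first asm block above covers the mins/bsums term and the 6-bit scale unpacking into `utmp`; the per-128-quant asm loop covers the nibble dot products, keeping the accumulation in integers and applying `d` once per superblock.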
@@ -1769,8 +1693,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     case 128:
         for (int i = 0; i < nb; ++i) {
 
-            __builtin_prefetch(&x[i + 1].d, 0, 1);
-
             const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
 
             const uint8_t * restrict q6 = x[i].ql;
@@ -1779,59 +1701,23 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
             const int8_t * restrict scale = x[i].scales;
 
-            int q6h;
-            float ftmp;
+            int sum_t = 0;
+            int t0;
 
             for (int j = 0; j < QK_K/128; ++j) {
                 __asm__ __volatile__(
-                    "addi %[q6h], %[q6], 32\n\t"
-                    "ld t0, 0(%[scale])\n\t"
-                    "addi %[scale], %[scale], 8\n\t"
-                    "slli t6, t0, 1 * 8\n\t"
-                    "lb zero, 0(%[q6])\n\t"
-                    "slli t5, t0, 2 * 8\n\t"
-                    "slli t4, t0, 3 * 8\n\t"
-                    "lb zero, 0(%[q6h])\n\t"
-                    "slli t3, t0, 4 * 8\n\t"
-                    "slli t2, t0, 5 * 8\n\t"
-                    "lb zero, 0(%[qh])\n\t"
-                    "lb zero, 31(%[q6h])\n\t"
-                    "slli t1, t0, 6 * 8\n\t"
-                    "srai a7, t0, 56\n\t"
                     "vsetvli zero, %[vl32], e8, m2\n\t"
-                    "vle8.v v8, (%[q6])\n\t"
-                    "srai t6, t6, 56\n\t"
-                    "srai t5, t5, 56\n\t"
-                    "srai t4, t4, 56\n\t"
-                    "srai t3, t3, 56\n\t"
-                    "vle8.v v10, (%[q6h])\n\t"
-                    "addi %[q6], %[q6], 64\n\t"
-                    "slli t0, t0, 7 * 8\n\t"
-                    "srai t2, t2, 56\n\t"
-                    "srai t1, t1, 56\n\t"
-                    "srai t0, t0, 56\n\t"
                     "vle8.v v4, (%[qh])\n\t"
-                    "vsrl.vi v12, v8, 4\n\t"
-                    "vsrl.vi v14, v10, 4\n\t"
-                    "lb zero, 0(%[q8])\n\t"
-                    "vand.vi v8, v8, 0xF\n\t"
-                    "vand.vi v10, v10, 0xF\n\t"
-                    "lb zero, 32(%[q8])\n\t"
                     "vsll.vi v0, v4, 4\n\t"
                     "vsll.vi v2, v4, 2\n\t"
-                    "lb zero, 64(%[q8])\n\t"
                     "vsrl.vi v6, v4, 2\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vle8.v v8, (%[q6])\n\t"
+                    "vsrl.vi v12, v8, 4\n\t"
+                    "vand.vi v8, v8, 0xF\n\t"
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
                     "vand.vx v0, v0, %[mask]\n\t"
-                    "lb zero, 96(%[q8])\n\t"
-                    "vand.vx v2, v2, %[mask]\n\t"
-                    "vand.vx v4, v4, %[mask]\n\t"
-                    "vand.vx v6, v6, %[mask]\n\t"
                     "vor.vv v8, v8, v0\n\t"
-                    "lb zero, 127(%[q8])\n\t"
-                    "vor.vv v10, v10, v2\n\t"
-                    "vor.vv v12, v12, v4\n\t"
-                    "vor.vv v14, v14, v6\n\t"
-                    "vsetvli zero, %[vl128], e8, m8\n\t"
                     "vle8.v v0, (%[q8])\n\t"
                     "vsub.vx v8, v8, %[vl32]\n\t"
                     "vsetvli zero, %[vl64], e8, m4\n\t"
@@ -1848,34 +1734,34 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
                     "vwredsum.vs v13, v28, v0\n\t"
                     "vwredsum.vs v14, v30, v0\n\t"
                     "vsetivli zero, 4, e32, m1\n\t"
-                    "vmul.vx v0, v10, t0\n\t"
-                    "vmul.vx v1, v9, t1\n\t"
-                    "vmacc.vx v0, t2, v8\n\t"
-                    "vmacc.vx v1, t3, v7\n\t"
-                    "vmacc.vx v0, t4, v11\n\t"
-                    "vmacc.vx v1, t5, v12\n\t"
-                    "vmacc.vx v0, t6, v13\n\t"
-                    "vmacc.vx v1, a7, v14\n\t"
-                    "vadd.vv v0, v0, v1\n\t"
-                    "vfcvt.f.x.v v0, v0\n\t"
-                    "vfmv.f.s %[ftmp], v0\n\t"
-                    "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]"
-                    : [q6] "+&r" (q6), [q6h] "=&r" (q6h)
-                    , [scale] "+&r" (scale)
-                    , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp)
-                    : [qh] "r" (qh), [q8] "r" (q8)
+                    "vslideup.vi v10, v9, 1\n\t"
+                    "vslideup.vi v8, v7, 1\n\t"
+                    "vslideup.vi v11, v12, 1\n\t"
+                    "vslideup.vi v13, v14, 1\n\t"
+                    "vslideup.vi v10, v8, 2\n\t"
+                    "vslideup.vi v11, v13, 2\n\t"
+                    "vsetivli zero, 8, e32, m2\n\t"
+                    "vle8.v v2, (%[scale])\n\t"
+                    "vsext.vf4 v4, v2\n\t"
+                    "vmul.vv v2, v4, v10\n\t"
+                    "vredsum.vs v0, v2, v0\n\t"
+                    "vmv.x.s %[t0], v0\n\t"
+                    "add %[sumi], %[sumi], %[t0]"
+                    : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
+                    : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
                     , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
-                    , [mask] "r" (0x30), [d] "f" (d)
+                    , [mask] "r" (0x30)
                     : "memory"
                     , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                     , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                     , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                     , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
-                    , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7"
-                    , "a6", "a5", "a4", "a3"
                 );
-                qh += 32; q8 += 128;
+                q6 += 64; qh += 32; q8 += 128; scale += 8;
             }
+
+            sumf += d * sum_t;
+
         }
         break;
     default:
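
Similarly, the reworked q6_K loop accumulates into the integer `sum_t` and applies `d` once per superblock instead of converting to float inside the asm. A scalar sketch of the per-chunk arithmetic (illustrative only, not part of the patch; it assumes the standard Q6_K layout with QK_K == 256, uses the surrounding function's `q6`, `qh`, `q8`, `scale`, `d`, `sumf`, and the temporary `q6s` array is purely for exposition):

    int sum_t = 0;
    int8_t q6s[128];                       // decoded 6-bit values for one 128-quant chunk
    for (int j = 0; j < QK_K/128; ++j) {
        for (int l = 0; l < 32; ++l) {     // q6 = (low nibble | 2-bit high field << 4) - 32
            q6s[l +  0] = (int8_t)((q6[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
            q6s[l + 32] = (int8_t)((q6[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
            q6s[l + 64] = (int8_t)((q6[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
            q6s[l + 96] = (int8_t)((q6[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
        }
        for (int g = 0; g < 8; ++g) {      // one signed 8-bit scale per 16 quants
            int s = 0;
            for (int l = 0; l < 16; ++l) s += q6s[16*g + l] * q8[16*g + l];
            sum_t += scale[g] * s;
        }
        q6 += 64; qh += 32; q8 += 128; scale += 8;
    }
    sumf += d * sum_t;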