@@ -1270,15 +1270,18 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
             const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
             const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
 
-            int tmp, tmp2, sumi;
+            int tmp, tmp2;
+            float ftmp, ft2;
+
             __asm__ __volatile__(
-                "vsetivli zero, 12, e8, m1\n\t"
-                "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
-                "vsetivli zero, 4, e32, m1\n\t"
+                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
+                "vle32.v v1, (%[s6b])\n\t"
+                "vslide1down.vx v1, v1, zero\n\t"
+                "vmv.v.x v16, zero\n\t"
                 "vslidedown.vi v2, v1, 2\n\t"
                 "vmv1r.v v3, v2\n\t"
                 "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
-                "vsetivli zero, 2, e32, m1\n\t"
+                "vsetivli zero, 2, e32, m1, ta, ma \n\t"
                 "vmv.v.i v4, 4\n\t"
                 "vand.vx v8, v1, %[kmask1]\n\t"
                 "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
@@ -1292,81 +1295,167 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
                 "vor.vv v1, v6, v2\n\t"
                 "vsse32.v v8, (%[utmp]), %[t2]\n\t"
                 "vsse32.v v1, (%[t1]), %[t2]\n\t"
-                "vsetivli zero, 8, e16, m1\n\t"
+                "vsetivli zero, 8, e16, m1, ta, ma \n\t"
                 "vle32.v v2, (%[bsums])\n\t"
                 "vnsrl.wi v0, v2, 0\n\t"
                 "vnsrl.wi v1, v2, 16\n\t"
                 "vadd.vv v2, v0, v1\n\t"
                 "vle8.v v3, (%[mins])\n\t"
                 "vzext.vf2 v4, v3\n\t"
                 "vwmul.vv v6, v4, v2\n\t"
-                "vmv.v.x v0, zero\n\t"
-                "vsetivli zero, 8, e32, m2\n\t"
-                "vredsum.vs v0, v6, v0\n\t"
-                "vmv.x.s %[sumi], v0"
-                : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
+                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
+                "vredsum.vs v0, v6, v16\n\t"
+                "vredsum.vs v0, v7, v0\n\t"
+                "vfcvt.f.x.v v0, v0\n\t"
+                "vfmv.f.s %[ftmp], v0"
+                : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [ftmp] "=&f" (ftmp)
                 : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
-                , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
+                , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1)
                 , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
                 : "memory"
                 , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                 , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                 , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                 , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
             );
-            sumf -= dmin * sumi;
-
-            const uint8_t * restrict q4 = x[i].qs;
-            const int8_t * restrict q8 = y[i].qs;
-
-            sumi = 0;
+            sumf -= dmin * ftmp;
+
+            const uint8_t * restrict q40 = x[i].qs + 0;
+            const uint8_t * restrict q41 = x[i].qs + 16;
+            const uint8_t * restrict q42 = x[i].qs + 32;
+            const uint8_t * restrict q43 = x[i].qs + 48;
+            const int8_t * restrict q80;
+            const int8_t * restrict q81;
+            const int8_t * restrict q82;
+            const int8_t * restrict q83;
+
+            ftmp = 0;
             const uint8_t * scale = scales;
 
-            for (int j = 0; j < QK_K/128; ++j) {
-                int vl128 = 128, vl64 = 64, vl32 = 32;
-                __asm__ __volatile__(
-                    "vsetvli zero, %[vl128], e8, m8\n\t"
-                    "vle8.v v8, (%[q8])\n\t"
-                    "vsetvli zero, %[vl64], e8, m4\n\t"
-                    "vle8.v v0, (%[q4])\n\t"
-                    "vsrl.vi v4, v0, 4\n\t"
-                    "vand.vi v0, v0, 0xF\n\t"
-                    "vsetvli zero, %[vl32], e8, m2\n\t"
-                    "vwmul.vv v28, v6, v14\n\t"
-                    "vwmul.vv v20, v4, v10\n\t"
-                    "vwmul.vv v24, v2, v12\n\t"
-                    "vwmul.vv v16, v0, v8\n\t"
-                    "vsetivli zero, 4, e32, m1\n\t"
-                    "vle8.v v2, (%[scale])\n\t"
-                    "vmv.v.x v0, zero\n\t"
-                    "vzext.vf4 v1, v2\n\t"
-                    "vsetvli zero, %[vl32], e16, m4\n\t"
-                    "vwredsum.vs v6, v24, v0\n\t"
-                    "vwredsum.vs v7, v28, v0\n\t"
-                    "vwredsum.vs v4, v16, v0\n\t"
-                    "vwredsum.vs v5, v20, v0\n\t"
-                    "vsetivli zero, 4, e32, m1\n\t"
-                    "vslideup.vi v6, v7, 1\n\t"
-                    "vslideup.vi v4, v5, 1\n\t"
-                    "vslideup.vi v4, v6, 2\n\t"
-                    "vmul.vv v8, v4, v1\n\t"
-                    "vredsum.vs v0, v8, v0\n\t"
-                    "vmv.x.s %[tmp], v0\n\t"
-                    "add %[sumi], %[sumi], %[tmp]"
-                    : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
-                    : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
-                    , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
-                    : "memory"
-                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
-                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
-                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
-                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
-                );
-
-                q4 += 64; q8 += 128; scale += 4;
-            }
-
-            sumf += d * sumi;
+            int s0, s1, s2, s3;
+            __asm__ __volatile__(
+                "vsetivli zero, 16, e8, m1, ta, ma\n\t"
+                "vle8.v v0, (%[q40])\n\t"
+                "addi %[q80], %[ys], 0\n\t"
+                "addi %[q40], %[q40], 64\n\t"
+                "vle8.v v1, (%[q41])\n\t"
+                "addi %[q81], %[ys], 16\n\t"
+                "addi %[q41], %[q41], 64\n\t"
+                "vle8.v v2, (%[q42])\n\t"
+                "addi %[q82], %[ys], 32\n\t"
+                "addi %[q42], %[q42], 64\n\t"
+                "vle8.v v3, (%[q43])\n\t"
+                "addi %[q83], %[ys], 48\n\t"
+                "addi %[q43], %[q43], 64\n\t"
+                "vle8.v v8, (%[q80])\n\t"
+                "vsrl.vi v4, v0, 4\n\t"
+                "addi %[q80], %[q80], 64\n\t"
+                "vle8.v v9, (%[q81])\n\t"
+                "vand.vi v0, v0, 0xF\n\t"
+                "addi %[q81], %[q81], 64\n\t"
+                "vle8.v v10, (%[q82])\n\t"
+                "vsrl.vi v5, v1, 4\n\t"
+                "addi %[q82], %[q82], 64\n\t"
+                "vle8.v v11, (%[q83])\n\t"
+                "vand.vi v1, v1, 0xF\n\t"
+                "addi %[q83], %[q83], 64\n\t"
+                "vle8.v v12, (%[q80])\n\t"
+                "vsrl.vi v6, v2, 4\n\t"
+                "addi %[q80], %[q80], 64\n\t"
+                "vle8.v v13, (%[q81])\n\t"
+                "vand.vi v2, v2, 0xF\n\t"
+                "addi %[q81], %[q81], 64\n\t"
+                "vle8.v v14, (%[q82])\n\t"
+                "vsrl.vi v7, v3, 4\n\t"
+                "addi %[q82], %[q82], 64\n\t"
+                "vle8.v v15, (%[q83])\n\t"
+                "vand.vi v3, v3, 0xF\n\t"
+                "addi %[q83], %[q83], 64\n\t"
+                "vwmul.vv v16, v0, v8\n\t"
+                "vwmul.vv v24, v2, v12\n\t"
+                "vwmul.vv v20, v4, v10\n\t"
+                "vwmul.vv v28, v6, v14\n\t"
+                "vwmacc.vv v16, v1, v9\n\t"
+                "vwmacc.vv v24, v3, v13\n\t"
+                "vwmacc.vv v20, v5, v11\n\t"
+                "vwmacc.vv v28, v7, v15\n\t"
+                "vle8.v v0, (%[q40])\n\t"
+                "addi %[q40], %[q80], 64\n\t"
+                "vle8.v v1, (%[q41])\n\t"
+                "addi %[q41], %[q81], 64\n\t"
+                "vle8.v v2, (%[q42])\n\t"
+                "addi %[q42], %[q82], 64\n\t"
+                "vle8.v v3, (%[q43])\n\t"
+                "addi %[q43], %[q83], 64\n\t"
+                "vle8.v v8, (%[q80])\n\t"
+                "vsrl.vi v4, v0, 4\n\t"
+                "vle8.v v9, (%[q81])\n\t"
+                "vand.vi v0, v0, 0xF\n\t"
+                "vle8.v v10, (%[q82])\n\t"
+                "vsrl.vi v5, v1, 4\n\t"
+                "vle8.v v11, (%[q83])\n\t"
+                "vand.vi v1, v1, 0xF\n\t"
+                "vle8.v v12, (%[q40])\n\t"
+                "vsrl.vi v6, v2, 4\n\t"
+                "vle8.v v13, (%[q41])\n\t"
+                "vand.vi v2, v2, 0xF\n\t"
+                "vle8.v v14, (%[q42])\n\t"
+                "vsrl.vi v7, v3, 4\n\t"
+                "vle8.v v15, (%[q43])\n\t"
+                "vand.vi v3, v3, 0xF\n\t"
+                "vwmul.vv v18, v0, v8\n\t"
+                "vwmul.vv v26, v2, v12\n\t"
+                "vwmul.vv v22, v4, v10\n\t"
+                "vwmul.vv v30, v6, v14\n\t"
+                "vwmacc.vv v18, v1, v9\n\t"
+                "vwmacc.vv v26, v3, v13\n\t"
+                "vwmacc.vv v22, v5, v11\n\t"
+                "vwmacc.vv v30, v7, v15\n\t"
+                "vmv.v.x v0, zero\n\t"
+                "vsetivli zero, 16, e16, m2, ta, ma\n\t"
+                "vwredsum.vs v4, v16, v0\n\t"
+                "lbu %[s0], 0(%[scale])\n\t"
+                "vwredsum.vs v5, v20, v0\n\t"
+                "lbu %[s1], 1(%[scale])\n\t"
+                "vwredsum.vs v6, v24, v0\n\t"
+                "lbu %[s2], 2(%[scale])\n\t"
+                "vwredsum.vs v7, v28, v0\n\t"
+                "lbu %[s3], 3(%[scale])\n\t"
+                "vwredsum.vs v8, v18, v0\n\t"
+                "lbu %[q40], 4(%[scale])\n\t"
+                "vwredsum.vs v9, v22, v0\n\t"
+                "lbu %[q41], 5(%[scale])\n\t"
+                "vwredsum.vs v10, v26, v0\n\t"
+                "lbu %[q42], 6(%[scale])\n\t"
+                "vwredsum.vs v11, v30, v0\n\t"
+                "lbu %[q43], 7(%[scale])\n\t"
+                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
+                "vmul.vx v0, v4, %[s0]\n\t"
+                "vmul.vx v1, v8, %[q40]\n\t"
+                "vmacc.vx v0, %[s1], v5\n\t"
+                "vmacc.vx v1, %[q41], v9\n\t"
+                "vmacc.vx v0, %[s2], v6\n\t"
+                "vmacc.vx v1, %[q42], v10\n\t"
+                "vmacc.vx v0, %[s3], v7\n\t"
+                "vmacc.vx v1, %[q43], v11\n\t"
+                "vfcvt.f.x.v v0, v0\n\t"
+                "vfcvt.f.x.v v1, v1\n\t"
+                "vfmv.f.s %[ft2], v0\n\t"
+                "vfmv.f.s %[ftmp], v1\n\t"
+                "fadd.s %[ft2], %[ft2], %[ftmp]\n\t"
+                "fmadd.s %[sumf], %[d], %[ft2], %[sumf]"
+                : [tmp] "=&r" (tmp), [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2)
+                , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3)
+                , [q40] "+&r" (q40), [q41] "+&r" (q41), [q42] "+&r" (q42), [q43] "+&r" (q43)
+                , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83)
+                , [scale] "+&r" (scale)
+                : [d] "f" (d), [ys] "r" (y[i].qs)
+                : "memory"
+                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+            );
         }
         break;
     default:
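
Note (not part of the commit): for readers who do not want to trace the vector registers, the two asm blocks above compute the standard Q4_K x Q8_K dot product for one super-block. The following is a minimal scalar sketch of the same math, assuming QK_K == 256 and that the eight 6-bit sub-block scales and mins have already been unpacked into scales[8] and mins[8] (which is what the first asm block's utmp/kmask shuffling prepares); sumi_mins, s_lo and s_hi are illustrative local names only.

    // Scalar sketch of what the vectorized code accumulates per super-block.
    const uint8_t * q4 = x[i].qs;   // 128 bytes; each byte packs two 4-bit quants
    const int8_t  * q8 = y[i].qs;   // 256 int8 activations

    // mins correction: bsums[j] is the sum of 16 q8 values, each pair shares one min
    int sumi_mins = 0;
    for (int j = 0; j < QK_K/16; ++j) {
        sumi_mins += y[i].bsums[j] * mins[j/2];
    }
    sumf -= dmin * sumi_mins;       // corresponds to "sumf -= dmin * ftmp" above

    // main accumulation: 4 groups of 64 quants packed into 32 bytes
    // (low nibbles are values 0..31 of the group, high nibbles are values 32..63)
    int sumi = 0;
    for (int j = 0; j < QK_K/64; ++j) {
        int s_lo = 0, s_hi = 0;
        for (int k = 0; k < 32; ++k) {
            s_lo += (q4[k] & 0xF) * q8[k];
            s_hi += (q4[k] >> 4)  * q8[k + 32];
        }
        sumi += scales[2*j + 0] * s_lo + scales[2*j + 1] * s_hi;
        q4 += 32; q8 += 64;
    }
    sumf += d * sumi;               // corresponds to the final fmadd.s into %[sumf]

The change keeps this math intact: the old code looped twice over 128-value halves and reduced integer partial sums in C, while the new single asm block processes all 256 values, applies the eight scales with vmul.vx/vmacc.vx, and folds the result into sumf in floating point inside the asm.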