Skip to content

Commit 8f2a5af

Browse files
committed
ggml-cpu : optimize rvv ggml_vec_dot_q6_K_q8_K
1 parent 624b291 commit 8f2a5af

File tree

1 file changed

+63
-28
lines changed

1 file changed

+63
-28
lines changed

ggml/src/ggml-cpu/arch/riscv/quants.c

Lines changed: 63 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,7 +1270,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
12701270
const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
12711271
const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
12721272

1273-
int tmp, tmp2;
12741273
float ftmp, ft2;
12751274
const uint8_t * restrict q40;
12761275
const uint8_t * restrict q41;
@@ -1778,23 +1777,59 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
17781777

17791778
const int8_t * restrict scale = x[i].scales;
17801779

1781-
int sum_t = 0;
1782-
int t0;
1780+
int q6h;
1781+
float ftmp;
17831782

17841783
for (int j = 0; j < QK_K/128; ++j) {
17851784
__asm__ __volatile__(
1785+
"addi %[q6h], %[q6], 32\n\t"
1786+
"ld t0, 0(%[scale])\n\t"
1787+
"addi %[scale], %[scale], 8\n\t"
1788+
"slli t6, t0, 1 * 8\n\t"
1789+
"lb zero, 0(%[q6])\n\t"
1790+
"slli t5, t0, 2 * 8\n\t"
1791+
"slli t4, t0, 3 * 8\n\t"
1792+
"lb zero, 0(%[q6h])\n\t"
1793+
"slli t3, t0, 4 * 8\n\t"
1794+
"slli t2, t0, 5 * 8\n\t"
1795+
"lb zero, 0(%[qh])\n\t"
1796+
"lb zero, 31(%[q6h])\n\t"
1797+
"slli t1, t0, 6 * 8\n\t"
1798+
"srai a7, t0, 56\n\t"
17861799
"vsetvli zero, %[vl32], e8, m2\n\t"
1800+
"vle8.v v8, (%[q6])\n\t"
1801+
"srai t6, t6, 56\n\t"
1802+
"srai t5, t5, 56\n\t"
1803+
"srai t4, t4, 56\n\t"
1804+
"srai t3, t3, 56\n\t"
1805+
"vle8.v v10, (%[q6h])\n\t"
1806+
"addi %[q6], %[q6], 64\n\t"
1807+
"slli t0, t0, 7 * 8\n\t"
1808+
"srai t2, t2, 56\n\t"
1809+
"srai t1, t1, 56\n\t"
1810+
"srai t0, t0, 56\n\t"
17871811
"vle8.v v4, (%[qh])\n\t"
1812+
"vsrl.vi v12, v8, 4\n\t"
1813+
"vsrl.vi v14, v10, 4\n\t"
1814+
"lb zero, 0(%[q8])\n\t"
1815+
"vand.vi v8, v8, 0xF\n\t"
1816+
"vand.vi v10, v10, 0xF\n\t"
1817+
"lb zero, 32(%[q8])\n\t"
17881818
"vsll.vi v0, v4, 4\n\t"
17891819
"vsll.vi v2, v4, 2\n\t"
1820+
"lb zero, 64(%[q8])\n\t"
17901821
"vsrl.vi v6, v4, 2\n\t"
1791-
"vsetvli zero, %[vl64], e8, m4\n\t"
1792-
"vle8.v v8, (%[q6])\n\t"
1793-
"vsrl.vi v12, v8, 4\n\t"
1794-
"vand.vi v8, v8, 0xF\n\t"
1795-
"vsetvli zero, %[vl128], e8, m8\n\t"
17961822
"vand.vx v0, v0, %[mask]\n\t"
1823+
"lb zero, 96(%[q8])\n\t"
1824+
"vand.vx v2, v2, %[mask]\n\t"
1825+
"vand.vx v4, v4, %[mask]\n\t"
1826+
"vand.vx v6, v6, %[mask]\n\t"
17971827
"vor.vv v8, v8, v0\n\t"
1828+
"lb zero, 127(%[q8])\n\t"
1829+
"vor.vv v10, v10, v2\n\t"
1830+
"vor.vv v12, v12, v4\n\t"
1831+
"vor.vv v14, v14, v6\n\t"
1832+
"vsetvli zero, %[vl128], e8, m8\n\t"
17981833
"vle8.v v0, (%[q8])\n\t"
17991834
"vsub.vx v8, v8, %[vl32]\n\t"
18001835
"vsetvli zero, %[vl64], e8, m4\n\t"
@@ -1811,34 +1846,34 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
18111846
"vwredsum.vs v13, v28, v0\n\t"
18121847
"vwredsum.vs v14, v30, v0\n\t"
18131848
"vsetivli zero, 4, e32, m1\n\t"
1814-
"vslideup.vi v10, v9, 1\n\t"
1815-
"vslideup.vi v8, v7, 1\n\t"
1816-
"vslideup.vi v11, v12, 1\n\t"
1817-
"vslideup.vi v13, v14, 1\n\t"
1818-
"vslideup.vi v10, v8, 2\n\t"
1819-
"vslideup.vi v11, v13, 2\n\t"
1820-
"vsetivli zero, 8, e32, m2\n\t"
1821-
"vle8.v v2, (%[scale])\n\t"
1822-
"vsext.vf4 v4, v2\n\t"
1823-
"vmul.vv v2, v4, v10\n\t"
1824-
"vredsum.vs v0, v2, v0\n\t"
1825-
"vmv.x.s %[t0], v0\n\t"
1826-
"add %[sumi], %[sumi], %[t0]"
1827-
: [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
1828-
: [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
1849+
"vmul.vx v0, v10, t0\n\t"
1850+
"vmul.vx v1, v9, t1\n\t"
1851+
"vmacc.vx v0, t2, v8\n\t"
1852+
"vmacc.vx v1, t3, v7\n\t"
1853+
"vmacc.vx v0, t4, v11\n\t"
1854+
"vmacc.vx v1, t5, v12\n\t"
1855+
"vmacc.vx v0, t6, v13\n\t"
1856+
"vmacc.vx v1, a7, v14\n\t"
1857+
"vadd.vv v0, v0, v1\n\t"
1858+
"vfcvt.f.x.v v0, v0\n\t"
1859+
"vfmv.f.s %[ftmp], v0\n\t"
1860+
"fmadd.s %[sumf], %[d], %[ftmp], %[sumf]"
1861+
: [q6] "+&r" (q6), [q6h] "=&r" (q6h)
1862+
, [scale] "+&r" (scale)
1863+
, [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp)
1864+
: [qh] "r" (qh), [q8] "r" (q8)
18291865
, [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
1830-
, [mask] "r" (0x30)
1866+
, [mask] "r" (0x30), [d] "f" (d)
18311867
: "memory"
18321868
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
18331869
, "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
18341870
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
18351871
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1872+
, "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7"
1873+
, "a6", "a5", "a4", "a3"
18361874
);
1837-
q6 += 64; qh += 32; q8 += 128; scale += 8;
1875+
qh += 32; q8 += 128;
18381876
}
1839-
1840-
sumf += d * sum_t;
1841-
18421877
}
18431878
break;
18441879
default:

0 commit comments

Comments
 (0)