Skip to content

Commit 16af48f

Browse files
committed
multi-thread across src1 rows
1 parent dafedd3 commit 16af48f

File tree

3 files changed: 51 additions and 28 deletions.

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1955,15 +1955,15 @@ static void ggml_metal_encode_node(
19551955
}
19561956
#endif
19571957

1958-
if (src0t == GGML_TYPE_Q8_0 && (ne00%16 == 0) && (ne11 >= 4 && ne11 < 32)) {
1958+
if (src0t == GGML_TYPE_Q8_0 && (ne00%16 == 0) && (ne11 >= 2 && ne11 < 32)) {
19591959
//if (false) {
19601960
id<MTLComputePipelineState> pipeline = nil;
19611961

19621962
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline;
19631963

19641964
const int nsg = 2;
19651965
const int r0pt = 1;
1966-
const int r1pt = 1;
1966+
const int r1pt = 4;
19671967
const int nxpsg = ne11 > 1 ? 8 : 32;
19681968
const int nypsg = 32/nxpsg;
19691969
const int nr0ptg = nypsg*r0pt*nsg;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 45 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1784,6 +1784,7 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
17841784
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
17851785
const short chpt = 4;
17861786
const short r0pt = 1;
1787+
const short r1pt = 4;
17871788

17881789
//const short nxpsg = (32);
17891790
const short nypsg = (32/nxpsg)*r0pt;
@@ -1792,7 +1793,7 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
17921793
const short ty = tiisg/nxpsg;
17931794

17941795
const int i01 = tgpig.x*(nypsg*nsg) + nypsg*sgitg + ty*r0pt;
1795-
const int i11 = tgpig.y;
1796+
const int i11 = tgpig.y*r1pt;
17961797
const int i1m = tgpig.z;
17971798

17981799
const int i12 = i1m%args.ne12;
@@ -1801,17 +1802,24 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
18011802
const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
18021803
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
18031804

1805+
//device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
1806+
//device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
1807+
18041808
device const block_q8_0 * xq[r0pt];
18051809

18061810
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
18071811
//xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0;
18081812
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0;
18091813
}
18101814

1811-
//device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
1812-
device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
1815+
device const float4 * y4[r1pt];
1816+
for (int ir1 = 0; ir1 < r1pt; ++ir1) {
1817+
//y4[ir1] = (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx;
1818+
y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
1819+
}
18131820

1814-
float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
1821+
//float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
1822+
float sumf[r1pt][r0pt] = { [ 0 ... r1pt - 1 ] = { [0 ... r0pt - 1] = 0.0f } };
18151823

18161824
for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
18171825
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
@@ -1821,42 +1829,53 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
18211829

18221830
dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx);
18231831

1824-
sumf[ir0] += dot(lx, y4[ch*nxpsg]);
1832+
#pragma unroll(4)
1833+
for (short ir1 = 0; ir1 < r1pt; ++ir1) {
1834+
sumf[ir1][ir0] += dot(lx, y4[ir1][ch*nxpsg]);
1835+
}
18251836
}
18261837
}
18271838

1828-
y4 += ((4*chpt)*nxpsg)/4;
1839+
for (short ir1 = 0; ir1 < r1pt; ++ir1) {
1840+
y4[ir1] += ((4*chpt)*nxpsg)/4;
1841+
}
18291842

18301843
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
18311844
xq[ir0] += ((4*chpt)*nxpsg)/32;
18321845
}
18331846
}
18341847

1835-
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
1836-
if (nxpsg >= 32) {
1837-
sumf[ir0] += simd_shuffle_down(sumf[ir0], 16);
1838-
}
1839-
if (nxpsg >= 16) {
1840-
sumf[ir0] += simd_shuffle_down(sumf[ir0], 8);
1841-
}
1842-
if (nxpsg >= 8) {
1843-
sumf[ir0] += simd_shuffle_down(sumf[ir0], 4);
1844-
}
1845-
if (nxpsg >= 4) {
1846-
sumf[ir0] += simd_shuffle_down(sumf[ir0], 2);
1847-
}
1848-
if (nxpsg >= 2) {
1849-
sumf[ir0] += simd_shuffle_down(sumf[ir0], 1);
1850-
}
1848+
for (short ir1 = 0; ir1 < r1pt; ++ir1) {
1849+
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
1850+
if (nxpsg >= 32) {
1851+
sumf[ir1][ir0] += simd_shuffle_down(sumf[ir1][ir0], 16);
1852+
}
1853+
if (nxpsg >= 16) {
1854+
sumf[ir1][ir0] += simd_shuffle_down(sumf[ir1][ir0], 8);
1855+
}
1856+
if (nxpsg >= 8) {
1857+
sumf[ir1][ir0] += simd_shuffle_down(sumf[ir1][ir0], 4);
1858+
}
1859+
if (nxpsg >= 4) {
1860+
sumf[ir1][ir0] += simd_shuffle_down(sumf[ir1][ir0], 2);
1861+
}
1862+
if (nxpsg >= 2) {
1863+
sumf[ir1][ir0] += simd_shuffle_down(sumf[ir1][ir0], 1);
1864+
}
18511865

1852-
//sumf[ir0] = simd_sum(sumf[ir0]);
1866+
//sumf[ir1][ir0] = simd_sum(sumf[ir1][ir0]);
1867+
}
18531868
}
18541869

1855-
device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)i11*args.ne0;
1870+
//device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)i11*args.ne0;
18561871

18571872
if (tx == 0) {
1858-
for (short ir0 = 0; ir0 < r0pt && i01 + ir0 < args.ne01; ++ir0) {
1859-
dst_f32[i01 + ir0] = sumf[ir0];
1873+
for (short ir1 = 0; ir1 < r1pt && i11 + ir1 < args.ne11; ++ir1) {
1874+
device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
1875+
1876+
for (short ir0 = 0; ir0 < r0pt && i01 + ir0 < args.ne01; ++ir0) {
1877+
dst_f32[i01 + ir0] = sumf[ir1][ir0];
1878+
}
18601879
}
18611880
}
18621881
}

tests/test-backend-ops.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3570,6 +3570,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
35703570
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
35713571
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
35723572

3573+
for (int i = 1; i < 64; ++i) {
3574+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 64, i, 256, { 1, 1}, {1, 1}));
3575+
}
3576+
35733577
#if 1
35743578
for (ggml_type type_a : base_types) {
35753579
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {

0 commit comments

Comments
 (0)