Skip to content

Commit 3c8a2a8

Browse files
committed
shmem experiments
1 parent dafedd3 commit 3c8a2a8

File tree

2 files changed

+123
-24
lines changed

2 files changed

+123
-24
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,14 +1961,15 @@ static void ggml_metal_encode_node(
19611961

19621962
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline;
19631963

1964-
const int nsg = 2;
1965-
const int r0pt = 1;
1964+
const int nsg = 4;
1965+
const int r0pt = 4;
19661966
const int r1pt = 1;
1967-
const int nxpsg = ne11 > 1 ? 8 : 32;
1967+
//const int nxpsg = ne11 > 1 ? 8 : 32;
1968+
const int nxpsg = 32;
19681969
const int nypsg = 32/nxpsg;
19691970
const int nr0ptg = nypsg*r0pt*nsg;
19701971

1971-
//GGML_ASSERT(ne00%1024 == 0);
1972+
//GGML_ASSERT(ne00%4096 == 0);
19721973
//GGML_ASSERT(ne01%nr0ptg == 0);
19731974
//printf("ne01 = %lld, nr0ptg = %d, ne00 = %lld\n", ne01, nr0ptg, ne00);
19741975

@@ -2003,6 +2004,11 @@ static void ggml_metal_encode_node(
20032004

20042005
//printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
20052006
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0ptg - 1)/nr0ptg, (ne11 + r1pt - 1)/r1pt, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2007+
2008+
[encoder setThreadgroupMemoryLength:2*8192 atIndex:0];
2009+
2010+
//printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
2011+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0ptg - 1)/nr0ptg, (ne11 + r1pt - 1)/r1pt, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
20062012
} else
20072013
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
20082014
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 113 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,27 @@ void dequantize_q8_0x(device const block_q8_0 *xb, short il, thread type4 & reg)
190190
}
191191
}
192192

193+
template <typename type4>
194+
void dequantize_q8_0s(threadgroup const block_q8_0 * xb, short il, thread type4 & reg) {
195+
threadgroup const int8_t * qs = ((threadgroup const int8_t *) xb->qs);
196+
const float d = xb->d;
197+
198+
for (int i = 0; i < 4; i++) {
199+
reg[i] = (qs[4*(il%4) + i + 16*(il/4)]*d);
200+
}
201+
}
202+
203+
//template <typename type4>
204+
//type4 dequantize_q8_0x(device const int8_t * qs, float d, short il) {
205+
// thread type4 reg;
206+
// for (int i = 0; i < 4; i++) {
207+
// reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
208+
// //reg[i] = qs[i/2];
209+
// }
210+
//
211+
// return reg;
212+
//}
213+
193214
template <typename type4x4>
194215
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
195216
const float d = xb->d;
@@ -1778,12 +1799,13 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
17781799
device const char * src0,
17791800
device const char * src1,
17801801
device char * dst,
1802+
threadgroup char * shmem [[threadgroup(0)]],
17811803
uint3 tgpig[[threadgroup_position_in_grid]],
17821804
ushort3 ntg[[threads_per_threadgroup]],
17831805
ushort tiisg[[thread_index_in_simdgroup]],
17841806
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1785-
const short chpt = 4;
1786-
const short r0pt = 1;
1807+
const short chpt = 8;
1808+
const short r0pt = 4;
17871809

17881810
//const short nxpsg = (32);
17891811
const short nypsg = (32/nxpsg)*r0pt;
@@ -1802,34 +1824,76 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
18021824
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
18031825

18041826
device const block_q8_0 * xq[r0pt];
1827+
device const block_q8_0 * xq0[r0pt];
18051828

18061829
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
18071830
//xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0;
1808-
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0;
1831+
//xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0;
1832+
xq0[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) : (device const block_q8_0 *) src0;
18091833
}
18101834

18111835
//device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
18121836
device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
18131837

18141838
float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
18151839

1840+
threadgroup block_q8_0 * shmem_q = (threadgroup block_q8_0 *) shmem + (((4*chpt)*nxpsg)/32)*r0pt*sgitg;
1841+
18161842
for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
1843+
//shmem_q[(4*chpt)*(tiisg/16 ) + tiisg%16] = xq0[tiisg/16 ][16*iib + tiisg%16];
1844+
//shmem_q[(4*chpt)*(tiisg/16 + 2) + tiisg%16] = xq0[tiisg/16 + 2][16*iib + tiisg%16];
1845+
//shmem_q[(4*chpt)*(tiisg/16 + 4) + tiisg%16] = xq0[tiisg/16 + 4][16*iib + tiisg%16];
1846+
//shmem_q[(4*chpt)*(tiisg/16 + 6) + tiisg%16] = xq0[tiisg/16 + 6][16*iib + tiisg%16];
1847+
//shmem_q[(4*chpt)*2 + tiisg] = xq0[2][32*iib + tiisg];
1848+
//shmem_q[(4*chpt)*3 + tiisg] = xq0[3][32*iib + tiisg];
1849+
1850+
shmem_q[((4*chpt))*(tiisg/32 ) + tiisg%32] = xq0[tiisg/32 ][32*iib + tiisg%32];
1851+
shmem_q[((4*chpt))*(tiisg/32 + 1) + tiisg%32] = xq0[tiisg/32 + 1][32*iib + tiisg%32];
1852+
shmem_q[((4*chpt))*(tiisg/32 + 2) + tiisg%32] = xq0[tiisg/32 + 2][32*iib + tiisg%32];
1853+
shmem_q[((4*chpt))*(tiisg/32 + 3) + tiisg%32] = xq0[tiisg/32 + 3][32*iib + tiisg%32];
1854+
1855+
//if (chpt == 2) {
1856+
// shmem_q[(4*chpt)*(tiisg/8 ) + tiisg%8] = xq0[tiisg/8 ][8*iib + tiisg%8];
1857+
//}
1858+
1859+
simdgroup_barrier(mem_flags::mem_threadgroup);
1860+
18171861
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
1818-
#pragma unroll(4)
1862+
//const float d = xq[ir0]->d;
1863+
//device const int8_t * qs = ((device const int8_t *) xq[ir0]->qs);
1864+
1865+
// float d[chpt];
1866+
// device const int8_t * qs[chpt];
1867+
//#pragma unroll(chpt)
1868+
// for (short ch = 0; ch < chpt; ++ch) {
1869+
// device const block_q8_0 * xc = xq[ir0] + (ch*nxpsg)/8;
1870+
// d[ch] = xc->d;
1871+
// qs[ch] = xc->qs;
1872+
// }
1873+
#pragma unroll(chpt)
18191874
for (short ch = 0; ch < chpt; ++ch) {
18201875
float4 lx;
18211876

1822-
dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx);
1877+
//float4 lx = dequantize_q8_0x<float4>(qs, d, (chpt*tx + ch)%8);
1878+
//dequantize_q8_0x(xq[ir0] + ch/8, (chpt*tx + ch)%8, lx);
1879+
//float4 lx = dequantize_q8_0x<float4>(qs, d, (tx)%8);
1880+
//float4 lx = dequantize_q8_0x<float4>(qs[ch], d[ch], (tx)%8);
1881+
//dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx);
1882+
1883+
//dequantize_q8_0x(xq0[ir0] + 8*iib + (ch*nxpsg)/8 + tx/8, (tx)%8, lx);
1884+
dequantize_q8_0s(shmem_q + (((4*chpt)*nxpsg)/32)*ir0 + (ch*nxpsg)/8 + tx/8, (tx)%8, lx);
1885+
//dequantize_q8_0s(shmem_q + 8*ir0 + (ch*nxpsg)/8 + tx/8, (tx)%8, lx);
18231886

1887+
//sumf[ir0] += dot(lx, y4[ch]);
18241888
sumf[ir0] += dot(lx, y4[ch*nxpsg]);
18251889
}
18261890
}
18271891

18281892
y4 += ((4*chpt)*nxpsg)/4;
18291893

1830-
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
1831-
xq[ir0] += ((4*chpt)*nxpsg)/32;
1832-
}
1894+
//for (short ir0 = 0; ir0 < r0pt; ++ir0) {
1895+
// xq[ir0] += ((4*chpt)*nxpsg)/32;
1896+
//}
18331897
}
18341898

18351899
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
@@ -1867,31 +1931,60 @@ kernel void kernel_mul_mv_ext_q8_0_f32(
18671931
device const char * src0,
18681932
device const char * src1,
18691933
device char * dst,
1934+
threadgroup char * shmem [[threadgroup(0)]],
18701935
uint3 tgpig[[threadgroup_position_in_grid]],
18711936
ushort3 ntg[[threads_per_threadgroup]],
18721937
ushort tiisg[[thread_index_in_simdgroup]],
18731938
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
18741939
switch (args.nsg) {
18751940
case 1:
18761941
switch (args.nxpsg) {
1877-
case 4: kernel_mul_mv_ext_q8_0_f32_impl<1, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
1878-
case 8: kernel_mul_mv_ext_q8_0_f32_impl<1, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
1879-
case 16: kernel_mul_mv_ext_q8_0_f32_impl<1, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
1880-
case 32: kernel_mul_mv_ext_q8_0_f32_impl<1, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
1942+
case 4: kernel_mul_mv_ext_q8_0_f32_impl<1, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1943+
case 8: kernel_mul_mv_ext_q8_0_f32_impl<1, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1944+
case 16: kernel_mul_mv_ext_q8_0_f32_impl<1, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1945+
case 32: kernel_mul_mv_ext_q8_0_f32_impl<1, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
18811946
} break;
18821947
case 2:
18831948
switch (args.nxpsg) {
1884-
case 4: kernel_mul_mv_ext_q8_0_f32_impl<2, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
1885-
case 8: kernel_mul_mv_ext_q8_0_f32_impl<2, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
1886-
case 16: kernel_mul_mv_ext_q8_0_f32_impl<2, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
1887-
case 32: kernel_mul_mv_ext_q8_0_f32_impl<2, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
1949+
case 4: kernel_mul_mv_ext_q8_0_f32_impl<2, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1950+
case 8: kernel_mul_mv_ext_q8_0_f32_impl<2, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1951+
case 16: kernel_mul_mv_ext_q8_0_f32_impl<2, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1952+
case 32: kernel_mul_mv_ext_q8_0_f32_impl<2, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
18881953
} break;
18891954
case 4:
18901955
switch (args.nxpsg) {
1891-
case 4: kernel_mul_mv_ext_q8_0_f32_impl<4, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
1892-
case 8: kernel_mul_mv_ext_q8_0_f32_impl<4, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
1893-
case 16: kernel_mul_mv_ext_q8_0_f32_impl<4, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
1894-
case 32: kernel_mul_mv_ext_q8_0_f32_impl<4, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
1956+
case 4: kernel_mul_mv_ext_q8_0_f32_impl<4, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1957+
case 8: kernel_mul_mv_ext_q8_0_f32_impl<4, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1958+
case 16: kernel_mul_mv_ext_q8_0_f32_impl<4, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1959+
case 32: kernel_mul_mv_ext_q8_0_f32_impl<4, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1960+
} break;
1961+
case 6:
1962+
switch (args.nxpsg) {
1963+
case 4: kernel_mul_mv_ext_q8_0_f32_impl<6, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1964+
case 8: kernel_mul_mv_ext_q8_0_f32_impl<6, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1965+
case 16: kernel_mul_mv_ext_q8_0_f32_impl<6, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1966+
case 32: kernel_mul_mv_ext_q8_0_f32_impl<6, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1967+
} break;
1968+
case 8:
1969+
switch (args.nxpsg) {
1970+
case 4: kernel_mul_mv_ext_q8_0_f32_impl<8, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1971+
case 8: kernel_mul_mv_ext_q8_0_f32_impl<8, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1972+
case 16: kernel_mul_mv_ext_q8_0_f32_impl<8, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1973+
case 32: kernel_mul_mv_ext_q8_0_f32_impl<8, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1974+
} break;
1975+
case 12:
1976+
switch (args.nxpsg) {
1977+
case 4: kernel_mul_mv_ext_q8_0_f32_impl<12, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1978+
case 8: kernel_mul_mv_ext_q8_0_f32_impl<12, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1979+
case 16: kernel_mul_mv_ext_q8_0_f32_impl<12, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1980+
case 32: kernel_mul_mv_ext_q8_0_f32_impl<12, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1981+
} break;
1982+
case 16:
1983+
switch (args.nxpsg) {
1984+
case 4: kernel_mul_mv_ext_q8_0_f32_impl<16, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1985+
case 8: kernel_mul_mv_ext_q8_0_f32_impl<16, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1986+
case 16: kernel_mul_mv_ext_q8_0_f32_impl<16, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
1987+
case 32: kernel_mul_mv_ext_q8_0_f32_impl<16, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
18951988
} break;
18961989
}
18971990
}

0 commit comments

Comments
 (0)