@@ -190,6 +190,27 @@ void dequantize_q8_0x(device const block_q8_0 *xb, short il, thread type4 & reg)
190190 }
191191}
192192
// Dequantize four q8_0 values from a block that lives in threadgroup memory.
//
//   xb  - source block; its `qs` payload is read as int8_t and scaled by its `d` field
//   il  - selects which run of four quants is read: the run starts at
//         element 4*(il%4) + 16*(il/4) of `qs`
//         (NOTE(review): callers in view pass tx%8 — confirm il stays within the block)
//   reg - output; components 0..3 receive the scaled quants
template <typename type4>
void dequantize_q8_0s(threadgroup const block_q8_0 * xb, short il, thread type4 & reg) {
    // per-block scale, converted to float once
    const float scale = xb->d;

    // point directly at the first of the four quants selected by `il`
    threadgroup const int8_t * quants = (threadgroup const int8_t *) xb->qs + 4*(il%4) + 16*(il/4);

    for (short k = 0; k < 4; ++k) {
        reg[k] = quants[k]*scale;
    }
}
202+
203+ // template <typename type4>
204+ // type4 dequantize_q8_0x(device const int8_t * qs, float d, short il) {
205+ // thread type4 reg;
206+ // for (int i = 0; i < 4; i++) {
207+ // reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
208+ // //reg[i] = qs[i/2];
209+ // }
210+ //
211+ // return reg;
212+ // }
213+
193214template <typename type4x4>
194215void dequantize_q2_K (device const block_q2_K *xb, short il, thread type4x4 & reg) {
195216 const float d = xb->d ;
@@ -1778,12 +1799,13 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
17781799 device const char * src0,
17791800 device const char * src1,
17801801 device char * dst,
1802+ threadgroup char * shmem [[threadgroup(0 )]],
17811803 uint3 tgpig[[threadgroup_position_in_grid]],
17821804 ushort3 ntg[[threads_per_threadgroup]],
17831805 ushort tiisg[[thread_index_in_simdgroup]],
17841806 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
1785- const short chpt = 4 ;
1786- const short r0pt = 1 ;
1807+ const short chpt = 8 ;
1808+ const short r0pt = 4 ;
17871809
17881810 // const short nxpsg = (32);
17891811 const short nypsg = (32 /nxpsg)*r0pt;
@@ -1802,34 +1824,76 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
18021824 const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13 ;
18031825
18041826 device const block_q8_0 * xq[r0pt];
1827+ device const block_q8_0 * xq0[r0pt];
18051828
18061829 for (short ir0 = 0 ; ir0 < r0pt; ++ir0) {
18071830 // xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0;
1808- xq[ir0] = (i01 + ir0 < args.ne01 ) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01 ) + (tx)/8 : (device const block_q8_0 *) src0;
1831+ // xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0;
1832+ xq0[ir0] = (i01 + ir0 < args.ne01 ) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01 ) : (device const block_q8_0 *) src0;
18091833 }
18101834
18111835 // device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
18121836 device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
18131837
18141838 float sumf[r0pt] = { [0 ... r0pt - 1 ] = 0 .0f };
18151839
1840+ threadgroup block_q8_0 * shmem_q = (threadgroup block_q8_0 *) shmem + (((4 *chpt)*nxpsg)/32 )*r0pt*sgitg;
1841+
18161842 for (int iib = 0 ; (4 *chpt)*(iib*nxpsg + tx) < args.ne00 ; ++iib) {
1843+ // shmem_q[(4*chpt)*(tiisg/16 ) + tiisg%16] = xq0[tiisg/16 ][16*iib + tiisg%16];
1844+ // shmem_q[(4*chpt)*(tiisg/16 + 2) + tiisg%16] = xq0[tiisg/16 + 2][16*iib + tiisg%16];
1845+ // shmem_q[(4*chpt)*(tiisg/16 + 4) + tiisg%16] = xq0[tiisg/16 + 4][16*iib + tiisg%16];
1846+ // shmem_q[(4*chpt)*(tiisg/16 + 6) + tiisg%16] = xq0[tiisg/16 + 6][16*iib + tiisg%16];
1847+ // shmem_q[(4*chpt)*2 + tiisg] = xq0[2][32*iib + tiisg];
1848+ // shmem_q[(4*chpt)*3 + tiisg] = xq0[3][32*iib + tiisg];
1849+
1850+ shmem_q[((4 *chpt))*(tiisg/32 ) + tiisg%32 ] = xq0[tiisg/32 ][32 *iib + tiisg%32 ];
1851+ shmem_q[((4 *chpt))*(tiisg/32 + 1 ) + tiisg%32 ] = xq0[tiisg/32 + 1 ][32 *iib + tiisg%32 ];
1852+ shmem_q[((4 *chpt))*(tiisg/32 + 2 ) + tiisg%32 ] = xq0[tiisg/32 + 2 ][32 *iib + tiisg%32 ];
1853+ shmem_q[((4 *chpt))*(tiisg/32 + 3 ) + tiisg%32 ] = xq0[tiisg/32 + 3 ][32 *iib + tiisg%32 ];
1854+
1855+ // if (chpt == 2) {
1856+ // shmem_q[(4*chpt)*(tiisg/8 ) + tiisg%8] = xq0[tiisg/8 ][8*iib + tiisg%8];
1857+ // }
1858+
1859+ simdgroup_barrier (mem_flags::mem_threadgroup);
1860+
18171861 for (short ir0 = 0 ; ir0 < r0pt; ++ir0) {
1818- #pragma unroll(4)
1862+ // const float d = xq[ir0]->d;
1863+ // device const int8_t * qs = ((device const int8_t *) xq[ir0]->qs);
1864+
1865+ // float d[chpt];
1866+ // device const int8_t * qs[chpt];
1867+ // #pragma unroll(chpt)
1868+ // for (short ch = 0; ch < chpt; ++ch) {
1869+ // device const block_q8_0 * xc = xq[ir0] + (ch*nxpsg)/8;
1870+ // d[ch] = xc->d;
1871+ // qs[ch] = xc->qs;
1872+ // }
1873+ #pragma unroll(chpt)
18191874 for (short ch = 0 ; ch < chpt; ++ch) {
18201875 float4 lx;
18211876
1822- dequantize_q8_0x (xq[ir0] + (ch*nxpsg)/8 , (tx)%8 , lx);
1877+ // float4 lx = dequantize_q8_0x<float4>(qs, d, (chpt*tx + ch)%8);
1878+ // dequantize_q8_0x(xq[ir0] + ch/8, (chpt*tx + ch)%8, lx);
1879+ // float4 lx = dequantize_q8_0x<float4>(qs, d, (tx)%8);
1880+ // float4 lx = dequantize_q8_0x<float4>(qs[ch], d[ch], (tx)%8);
1881+ // dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx);
1882+
1883+ // dequantize_q8_0x(xq0[ir0] + 8*iib + (ch*nxpsg)/8 + tx/8, (tx)%8, lx);
1884+ dequantize_q8_0s (shmem_q + (((4 *chpt)*nxpsg)/32 )*ir0 + (ch*nxpsg)/8 + tx/8 , (tx)%8 , lx);
1885+ // dequantize_q8_0s(shmem_q + 8*ir0 + (ch*nxpsg)/8 + tx/8, (tx)%8, lx);
18231886
1887+ // sumf[ir0] += dot(lx, y4[ch]);
18241888 sumf[ir0] += dot (lx, y4[ch*nxpsg]);
18251889 }
18261890 }
18271891
18281892 y4 += ((4 *chpt)*nxpsg)/4 ;
18291893
1830- for (short ir0 = 0 ; ir0 < r0pt; ++ir0) {
1831- xq[ir0] += ((4 *chpt)*nxpsg)/32 ;
1832- }
1894+ // for (short ir0 = 0; ir0 < r0pt; ++ir0) {
1895+ // xq[ir0] += ((4*chpt)*nxpsg)/32;
1896+ // }
18331897 }
18341898
18351899 for (short ir0 = 0 ; ir0 < r0pt; ++ir0) {
@@ -1867,31 +1931,60 @@ kernel void kernel_mul_mv_ext_q8_0_f32(
18671931 device const char * src0,
18681932 device const char * src1,
18691933 device char * dst,
1934+ threadgroup char * shmem [[threadgroup(0 )]],
18701935 uint3 tgpig[[threadgroup_position_in_grid]],
18711936 ushort3 ntg[[threads_per_threadgroup]],
18721937 ushort tiisg[[thread_index_in_simdgroup]],
18731938 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
18741939 switch (args.nsg ) {
18751940 case 1 :
18761941 switch (args.nxpsg ) {
1877- case 4 : kernel_mul_mv_ext_q8_0_f32_impl<1 , 4 > (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break ;
1878- case 8 : kernel_mul_mv_ext_q8_0_f32_impl<1 , 8 > (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break ;
1879- case 16 : kernel_mul_mv_ext_q8_0_f32_impl<1 , 16 >(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break ;
1880- case 32 : kernel_mul_mv_ext_q8_0_f32_impl<1 , 32 >(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break ;
1942+ case 4 : kernel_mul_mv_ext_q8_0_f32_impl<1 , 4 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1943+ case 8 : kernel_mul_mv_ext_q8_0_f32_impl<1 , 8 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1944+ case 16 : kernel_mul_mv_ext_q8_0_f32_impl<1 , 16 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1945+ case 32 : kernel_mul_mv_ext_q8_0_f32_impl<1 , 32 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
18811946 } break ;
18821947 case 2 :
18831948 switch (args.nxpsg ) {
1884- case 4 : kernel_mul_mv_ext_q8_0_f32_impl<2 , 4 > (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break ;
1885- case 8 : kernel_mul_mv_ext_q8_0_f32_impl<2 , 8 > (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break ;
1886- case 16 : kernel_mul_mv_ext_q8_0_f32_impl<2 , 16 >(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break ;
1887- case 32 : kernel_mul_mv_ext_q8_0_f32_impl<2 , 32 >(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break ;
1949+ case 4 : kernel_mul_mv_ext_q8_0_f32_impl<2 , 4 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1950+ case 8 : kernel_mul_mv_ext_q8_0_f32_impl<2 , 8 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1951+ case 16 : kernel_mul_mv_ext_q8_0_f32_impl<2 , 16 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1952+ case 32 : kernel_mul_mv_ext_q8_0_f32_impl<2 , 32 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
18881953 } break ;
18891954 case 4 :
18901955 switch (args.nxpsg ) {
1891- case 4 : kernel_mul_mv_ext_q8_0_f32_impl<4 , 4 > (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break ;
1892- case 8 : kernel_mul_mv_ext_q8_0_f32_impl<4 , 8 > (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break ;
1893- case 16 : kernel_mul_mv_ext_q8_0_f32_impl<4 , 16 >(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break ;
1894- case 32 : kernel_mul_mv_ext_q8_0_f32_impl<4 , 32 >(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break ;
1956+ case 4 : kernel_mul_mv_ext_q8_0_f32_impl<4 , 4 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1957+ case 8 : kernel_mul_mv_ext_q8_0_f32_impl<4 , 8 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1958+ case 16 : kernel_mul_mv_ext_q8_0_f32_impl<4 , 16 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1959+ case 32 : kernel_mul_mv_ext_q8_0_f32_impl<4 , 32 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1960+ } break ;
1961+ case 6 :
1962+ switch (args.nxpsg ) {
1963+ case 4 : kernel_mul_mv_ext_q8_0_f32_impl<6 , 4 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1964+ case 8 : kernel_mul_mv_ext_q8_0_f32_impl<6 , 8 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1965+ case 16 : kernel_mul_mv_ext_q8_0_f32_impl<6 , 16 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1966+ case 32 : kernel_mul_mv_ext_q8_0_f32_impl<6 , 32 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1967+ } break ;
1968+ case 8 :
1969+ switch (args.nxpsg ) {
1970+ case 4 : kernel_mul_mv_ext_q8_0_f32_impl<8 , 4 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1971+ case 8 : kernel_mul_mv_ext_q8_0_f32_impl<8 , 8 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1972+ case 16 : kernel_mul_mv_ext_q8_0_f32_impl<8 , 16 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1973+ case 32 : kernel_mul_mv_ext_q8_0_f32_impl<8 , 32 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1974+ } break ;
1975+ case 12 :
1976+ switch (args.nxpsg ) {
1977+ case 4 : kernel_mul_mv_ext_q8_0_f32_impl<12 , 4 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1978+ case 8 : kernel_mul_mv_ext_q8_0_f32_impl<12 , 8 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1979+ case 16 : kernel_mul_mv_ext_q8_0_f32_impl<12 , 16 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1980+ case 32 : kernel_mul_mv_ext_q8_0_f32_impl<12 , 32 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1981+ } break ;
1982+ case 16 :
1983+ switch (args.nxpsg ) {
1984+ case 4 : kernel_mul_mv_ext_q8_0_f32_impl<16 , 4 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1985+ case 8 : kernel_mul_mv_ext_q8_0_f32_impl<16 , 8 > (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1986+ case 16 : kernel_mul_mv_ext_q8_0_f32_impl<16 , 16 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
1987+ case 32 : kernel_mul_mv_ext_q8_0_f32_impl<16 , 32 >(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break ;
18951988 } break ;
18961989 }
18971990}
0 commit comments