@@ -1784,6 +1784,7 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
17841784 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
17851785 const short chpt = 4 ;
17861786 const short r0pt = 1 ;
1787+ const short r1pt = 4 ;
17871788
17881789 // const short nxpsg = (32);
17891790 const short nypsg = (32 /nxpsg)*r0pt;
@@ -1792,7 +1793,7 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
17921793 const short ty = tiisg/nxpsg;
17931794
17941795 const int i01 = tgpig.x *(nypsg*nsg) + nypsg*sgitg + ty*r0pt;
1795- const int i11 = tgpig.y ;
1796+ const int i11 = tgpig.y *r1pt ;
17961797 const int i1m = tgpig.z ;
17971798
17981799 const int i12 = i1m%args.ne12 ;
@@ -1801,17 +1802,24 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
18011802 const uint64_t offset0 = i01*args.nb01 + (i12/args.r2 )*args.nb02 + (i13/args.r3 )*args.nb03 ;
18021803 const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13 ;
18031804
1805+ // device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
1806+ // device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
1807+
18041808 device const block_q8_0 * xq[r0pt];
18051809
18061810 for (short ir0 = 0 ; ir0 < r0pt; ++ir0) {
18071811 // xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0;
18081812 xq[ir0] = (i01 + ir0 < args.ne01 ) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01 ) + (tx)/8 : (device const block_q8_0 *) src0;
18091813 }
18101814
1811- // device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
1812- device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
1815+ device const float4 * y4[r1pt];
1816+ for (int ir1 = 0 ; ir1 < r1pt; ++ir1) {
1817+ // y4[ir1] = (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx;
1818+ y4[ir1] = (i11 + ir1 < args.ne11 ) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11 ) + tx : (device const float4 *) src1;
1819+ }
18131820
1814- float sumf[r0pt] = { [0 ... r0pt - 1 ] = 0 .0f };
1821+ // float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
1822+ float sumf[r1pt][r0pt] = { [ 0 ... r1pt - 1 ] = { [0 ... r0pt - 1 ] = 0 .0f } };
18151823
18161824 for (int iib = 0 ; (4 *chpt)*(iib*nxpsg + tx) < args.ne00 ; ++iib) {
18171825 for (short ir0 = 0 ; ir0 < r0pt; ++ir0) {
@@ -1821,42 +1829,53 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
18211829
18221830 dequantize_q8_0x (xq[ir0] + (ch*nxpsg)/8 , (tx)%8 , lx);
18231831
1824- sumf[ir0] += dot (lx, y4[ch*nxpsg]);
1832+ #pragma unroll(4)
1833+ for (short ir1 = 0 ; ir1 < r1pt; ++ir1) {
1834+ sumf[ir1][ir0] += dot (lx, y4[ir1][ch*nxpsg]);
1835+ }
18251836 }
18261837 }
18271838
1828- y4 += ((4 *chpt)*nxpsg)/4 ;
1839+ for (short ir1 = 0 ; ir1 < r1pt; ++ir1) {
1840+ y4[ir1] += ((4 *chpt)*nxpsg)/4 ;
1841+ }
18291842
18301843 for (short ir0 = 0 ; ir0 < r0pt; ++ir0) {
18311844 xq[ir0] += ((4 *chpt)*nxpsg)/32 ;
18321845 }
18331846 }
18341847
1835- for (short ir0 = 0 ; ir0 < r0pt; ++ir0) {
1836- if (nxpsg >= 32 ) {
1837- sumf[ir0] += simd_shuffle_down (sumf[ir0], 16 );
1838- }
1839- if (nxpsg >= 16 ) {
1840- sumf[ir0] += simd_shuffle_down (sumf[ir0], 8 );
1841- }
1842- if (nxpsg >= 8 ) {
1843- sumf[ir0] += simd_shuffle_down (sumf[ir0], 4 );
1844- }
1845- if (nxpsg >= 4 ) {
1846- sumf[ir0] += simd_shuffle_down (sumf[ir0], 2 );
1847- }
1848- if (nxpsg >= 2 ) {
1849- sumf[ir0] += simd_shuffle_down (sumf[ir0], 1 );
1850- }
1848+ for (short ir1 = 0 ; ir1 < r1pt; ++ir1) {
1849+ for (short ir0 = 0 ; ir0 < r0pt; ++ir0) {
1850+ if (nxpsg >= 32 ) {
1851+ sumf[ir1][ir0] += simd_shuffle_down (sumf[ir1][ir0], 16 );
1852+ }
1853+ if (nxpsg >= 16 ) {
1854+ sumf[ir1][ir0] += simd_shuffle_down (sumf[ir1][ir0], 8 );
1855+ }
1856+ if (nxpsg >= 8 ) {
1857+ sumf[ir1][ir0] += simd_shuffle_down (sumf[ir1][ir0], 4 );
1858+ }
1859+ if (nxpsg >= 4 ) {
1860+ sumf[ir1][ir0] += simd_shuffle_down (sumf[ir1][ir0], 2 );
1861+ }
1862+ if (nxpsg >= 2 ) {
1863+ sumf[ir1][ir0] += simd_shuffle_down (sumf[ir1][ir0], 1 );
1864+ }
18511865
1852- // sumf[ir0] = simd_sum(sumf[ir0]);
1866+ // sumf[ir1][ir0] = simd_sum(sumf[ir1][ir0]);
1867+ }
18531868 }
18541869
1855- device float * dst_f32 = (device float *) dst + (uint64_t )i1m*args.ne0 *args.ne1 + (uint64_t )i11*args.ne0 ;
1870+ // device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)i11*args.ne0;
18561871
18571872 if (tx == 0 ) {
1858- for (short ir0 = 0 ; ir0 < r0pt && i01 + ir0 < args.ne01 ; ++ir0) {
1859- dst_f32[i01 + ir0] = sumf[ir0];
1873+ for (short ir1 = 0 ; ir1 < r1pt && i11 + ir1 < args.ne11 ; ++ir1) {
1874+ device float * dst_f32 = (device float *) dst + (uint64_t )i1m*args.ne0 *args.ne1 + (uint64_t )(i11 + ir1)*args.ne0 ;
1875+
1876+ for (short ir0 = 0 ; ir0 < r0pt && i01 + ir0 < args.ne01 ; ++ir0) {
1877+ dst_f32[i01 + ir0] = sumf[ir1][ir0];
1878+ }
18601879 }
18611880 }
18621881}
0 commit comments