@@ -165,6 +165,26 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
165165    reg = (type4x4) reg_f;
166166}
167167
// Dequantize four consecutive weights of one Q4_0 block into `reg`.
//
//   xb  - block to read; its `qs` bytes are treated as packed 4-bit quants
//         and `d` as the per-block scale
//   il  - which group of four weights to produce; expected in [0, 8), so that
//         the weights 4*il .. 4*il + 3 of the block are written
//   reg - 4-component output, reg[i] = d * (q - 8)
//
// NOTE(review): assumes the usual Q4_0 layout - 16 bytes in `qs`, where the low
// nibble of qs[j] is weight j and the high nibble is weight j + 16, each stored
// with a +8 bias. Confirm against the block_q4_0 declaration.
template <typename type4>
void dequantize_q4_0x(device const block_q4_0 *xb, short il, thread type4 & reg) {
    // The quants are unsigned nibbles, so read the bytes as unsigned to keep
    // the shift/mask below free of sign extension.
    device const uint8_t * qs = ((device const uint8_t *)xb->qs);
    const half d = xb->d;

    const short byte0 = 4*(il%4); // first of the four bytes holding this group
    const short shift = 4*(il/4); // 0 -> low nibbles (il < 4), 4 -> high nibbles

    for (int i = 0; i < 4; i++) {
        const short q = (qs[byte0 + i] >> shift) & 0x0F;

        // remove the +8 storage bias, then apply the block scale
        reg[i] = (q - 8) * d;
    }
}
177+ 
// Dequantize four signed 8-bit quants of one Q8_0 block into `reg`.
//
//   xb  - block to read; `qs` is accessed as int8 values, `d` is the half scale
//   il  - selects the group of four quants: the first one read is at index
//         4*(il%4) + 16*(il/4), which is 4*il
//   reg - 4-component output, reg[k] = quant * d
template <typename type4>
void dequantize_q8_0x(device const block_q8_0 *xb, short il, thread type4 & reg) {
    const half scale = xb->d;

    // The group offset does not depend on the lane, so compute it once and
    // point directly at the first quant of the group.
    const int first = 4*(il%4) + 16*(il/4);
    device const int8_t * src = ((device const int8_t *)xb->qs) + first;

    for (int k = 0; k < 4; ++k) {
        reg[k] = (src[k] * scale);
    }
}
187+ 
168188template  <typename  type4x4>
169189void  dequantize_q2_K (device const  block_q2_K *xb, short  il, thread type4x4 & reg) {
170190    const  float  d = xb->d ;
@@ -1749,8 +1769,8 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
17491769        ushort3   ntg[[threads_per_threadgroup]],
17501770        ushort  tiisg[[thread_index_in_simdgroup]],
17511771        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
1752-     const  short  chpt = 1 ;
1753-     const  short  r0pt = 2 ;
1772+     const  short  chpt = 4 ;
1773+     const  short  r0pt = 1 ;
17541774
17551775  // const short nxpsg = (32);
17561776    const  short  nypsg = (32 /nxpsg)*r0pt;
@@ -1771,36 +1791,31 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
17711791    device const  block_q8_0 * xq[r0pt];
17721792
17731793    for  (short  ir0 = 0 ; ir0 < r0pt; ++ir0) {
1774-         xq[ir0] = (i01 + ir0 < args.ne01 ) ? (device const  block_q8_0 *) (src0 + offset0 + ir0*args.nb01 ) + (chpt*tx)/2  : (device const  block_q8_0 *) src0;
1794+         // xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0;
1795+         xq[ir0] = (i01 + ir0 < args.ne01 ) ? (device const  block_q8_0 *) (src0 + offset0 + ir0*args.nb01 ) + (tx)/8  : (device const  block_q8_0 *) src0;
17751796    }
17761797
1777-     device const  float4x4 * y4x4 = (device const  float4x4 *) (src1 + offset1) + chpt*tx;
1798+     // device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
1799+     device const  float4 * y4 = (device const  float4 *) (src1 + offset1) + tx;
17781800
17791801    float  sumf[r0pt] = { [0  ... r0pt - 1 ] = 0 .0f  };
17801802
1781-     for  (int  iib = 0 ; (16 *chpt)*(iib*nxpsg + tx) < args.ne00 ; ++iib) {
1782-         float4x4 lx;
1783- 
1784- #pragma  unroll(2)
1803+     for  (int  iib = 0 ; (4 *chpt)*(iib*nxpsg + tx) < args.ne00 ; ++iib) {
17851804        for  (short  ir0 = 0 ; ir0 < r0pt; ++ir0) {
1786- #pragma  unroll
1805+ #pragma  unroll(4) 
17871806            for  (short  ch = 0 ; ch < chpt; ++ch) {
1788-                 dequantize_q8_0 (xq[ir0] + ch/ 2 , (chpt*tx + ch)% 2 , lx) ;
1807+                 float4 lx ;
17891808
1790-                 const  float4x4 ly = y4x4[ch] ;
1809+                 dequantize_q8_0x (xq[ir0] + (ch*nxpsg)/ 8 , (tx)% 8 , lx) ;
17911810
1792-                 sumf[ir0] +=
1793-                     dot (lx[0 ], ly[0 ]) +
1794-                     dot (lx[1 ], ly[1 ]) +
1795-                     dot (lx[2 ], ly[2 ]) +
1796-                     dot (lx[3 ], ly[3 ]);
1811+                 sumf[ir0] += dot (lx, y4[ch*nxpsg]);
17971812            }
17981813        }
17991814
1800-         y4x4  += ((16 *chpt)*nxpsg)/16 ;
1815+         y4  += ((4 *chpt)*nxpsg)/4 ;
18011816
18021817        for  (short  ir0 = 0 ; ir0 < r0pt; ++ir0) {
1803-             xq[ir0] += ((16 *chpt)*nxpsg)/32 ;
1818+             xq[ir0] += ((4 *chpt)*nxpsg)/32 ;
18041819        }
18051820    }
18061821
0 commit comments