@@ -1739,6 +1739,135 @@ kernel void kernel_mul_mv_q8_0_f32(
17391739    kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr , tgpig, tiisg, sgitg);
17401740}
17411741
// Multiplies q8_0-quantized rows of src0 with f32 data from src1 and writes f32 results to dst.
//
// Work layout, as derived from the index math below (the code assumes 32 lanes per simdgroup):
//   - the lanes of a simdgroup are split into groups of nxpsg lanes
//     (lane_x = position inside the group, lane_y = index of the group)
//   - every lane group accumulates rows_pt consecutive src0 rows, so one simdgroup covers
//     rows_psg = (32/nxpsg)*rows_pt rows and a threadgroup of nsg simdgroups covers rows_psg*nsg rows
//   - tgpig.x selects the block of src0 rows, tgpig.y the src1 row (r1),
//     tgpig.z the flattened batch index (ibatch = b3*ne12 + b2)
//   - per iteration a lane consumes one float4x4 (16 values) of src1 and the matching half of a
//     block_q8_0 (the /2 and %2 below select the block and the half inside it)
//
// template parameters:
//   nsg   - number of simdgroups per threadgroup used for the row indexing
//   nxpsg - number of lanes that cooperate on the same rows
//
// ntg is part of the signature but is not used by this implementation.
template<short nsg, short nxpsg>
void kernel_mul_mv_ext_q8_0_f32_impl(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3   ntg[[threads_per_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    const short chunks_pt = 1; // float4x4 chunks handled by one lane per iteration
    const short rows_pt   = 2; // src0 rows accumulated by one lane

    const short rows_psg = (32/nxpsg)*rows_pt;

    const short lane_x = tiisg%nxpsg;
    const short lane_y = tiisg/nxpsg;

    // first src0 row of this lane, src1 row and flattened batch index
    const int r0     = tgpig.x*(rows_psg*nsg) + rows_psg*sgitg + lane_y*rows_pt;
    const int r1     = tgpig.y;
    const int ibatch = tgpig.z;

    const int b2 = ibatch%args.ne12;
    const int b3 = ibatch/args.ne12;

    // byte offsets of the first src0 row / the src1 row handled by this lane
    const uint64_t off_src0 = r0*args.nb01 + (b2/args.r2)*args.nb02 + (b3/args.r3)*args.nb03;
    const uint64_t off_src1 = r1*args.nb11 + (b2        )*args.nb12 + (b3        )*args.nb13;

    device const block_q8_0 * px[rows_pt];
    float acc[rows_pt];

    for (short ir = 0; ir < rows_pt; ++ir) {
        acc[ir] = 0.0f;

        if (r0 + ir < args.ne01) {
            px[ir] = (device const block_q8_0 *) (src0 + off_src0 + ir*args.nb01) + (chunks_pt*lane_x)/2;
        } else {
            // rows past ne01 read from the start of src0 instead; their sums are never stored
            px[ir] = (device const block_q8_0 *) src0;
        }
    }

    device const float4x4 * py = (device const float4x4 *) (src1 + off_src1) + chunks_pt*lane_x;

    // i00 is the index of the first src0/src1 element this lane touches in the current iteration
    const int i00_step = (16*chunks_pt)*nxpsg;

    for (int i00 = (16*chunks_pt)*lane_x; i00 < args.ne00; i00 += i00_step) {
        float4x4 wx;

#pragma unroll(2)
        for (short ir = 0; ir < rows_pt; ++ir) {
#pragma unroll
            for (short ch = 0; ch < chunks_pt; ++ch) {
                dequantize_q8_0(px[ir] + ch/2, (chunks_pt*lane_x + ch)%2, wx);

                const float4x4 wy = py[ch];

                // keep this summation order: it determines the floating-point rounding
                const float part =
                    dot(wx[0], wy[0]) +
                    dot(wx[1], wy[1]) +
                    dot(wx[2], wy[2]) +
                    dot(wx[3], wy[3]);

                acc[ir] += part;
            }
        }

        // advance to the next slice handled by this lane group
        py += ((16*chunks_pt)*nxpsg)/16;

        for (short ir = 0; ir < rows_pt; ++ir) {
            px[ir] += ((16*chunks_pt)*nxpsg)/32;
        }
    }

    // Reduce the partial sums across the nxpsg lanes of a group with shuffles of decreasing distance.
    // A plain simd_sum is not used: for nxpsg < 32 different lane groups hold different rows.
    for (short ir = 0; ir < rows_pt; ++ir) {
#pragma unroll
        for (ushort delta = 16; delta >= 1; delta /= 2) {
            if (nxpsg >= 2*delta) {
                acc[ir] += simd_shuffle_down(acc[ir], delta);
            }
        }
    }

    // only the first lane of each group stores the results
    if (lane_x != 0) {
        return;
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)ibatch*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (short ir = 0; ir < rows_pt; ++ir) {
        if (r0 + ir >= args.ne01) {
            break;
        }

        dst_f32[r0 + ir] = acc[ir];
    }
}
1835+ 
// Kernel entry point for the q8_0 x f32 "ext" mat-vec product.
//
// Forwards all arguments to kernel_mul_mv_ext_q8_0_f32_impl<nsg, nxpsg>, selecting the template
// instantiation from the runtime values args.nsg (1, 2 or 4) and args.nxpsg (4, 8, 16 or 32).
// Any other combination matches no case, so the kernel returns without writing to dst.
//
// The exported name must not contain surrounding whitespace: the host_name string is the name
// under which the function is exported, so it is kept identical to the kernel identifier.
[[host_name("kernel_mul_mv_ext_q8_0_f32")]]
kernel void kernel_mul_mv_ext_q8_0_f32(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3   ntg[[threads_per_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    switch (args.nsg) {
        case 1:
            switch (args.nxpsg) {
                case 4:  kernel_mul_mv_ext_q8_0_f32_impl<1, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
                case 8:  kernel_mul_mv_ext_q8_0_f32_impl<1, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
                case 16: kernel_mul_mv_ext_q8_0_f32_impl<1, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
                case 32: kernel_mul_mv_ext_q8_0_f32_impl<1, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
            } break;
        case 2:
            switch (args.nxpsg) {
                case 4:  kernel_mul_mv_ext_q8_0_f32_impl<2, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
                case 8:  kernel_mul_mv_ext_q8_0_f32_impl<2, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
                case 16: kernel_mul_mv_ext_q8_0_f32_impl<2, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
                case 32: kernel_mul_mv_ext_q8_0_f32_impl<2, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
            } break;
        case 4:
            switch (args.nxpsg) {
                case 4:  kernel_mul_mv_ext_q8_0_f32_impl<4, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
                case 8:  kernel_mul_mv_ext_q8_0_f32_impl<4, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
                case 16: kernel_mul_mv_ext_q8_0_f32_impl<4, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
                case 32: kernel_mul_mv_ext_q8_0_f32_impl<4, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
            } break;
    }
}
1870+ 
17421871#define  N_MV_T_T  4 
17431872
17441873template <typename  T0, typename  T04, typename  T1, typename  T14, typename  args_t >
0 commit comments