@@ -8204,12 +8204,12 @@ kernel void kernel_mul_mm(
82048204        mc[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
82058205    }
82068206#else 
8207-     auto  tA = tensor<threadgroup S0,     dextents<int32_t , 2 >, tensor_inline>(sa, dextents<int32_t , 2 >(NK,  NR0));
8208-     auto  tB = tensor<threadgroup S1,     dextents<int32_t , 2 >, tensor_inline>(sb, dextents<int32_t , 2 >(NR1, NK ));
8207+     auto  tA = tensor<threadgroup S0, dextents<int32_t , 2 >, tensor_inline>(sa, dextents<int32_t , 2 >(NK,  NR0));
8208+     auto  tB = tensor<threadgroup S1, dextents<int32_t , 2 >, tensor_inline>(sb, dextents<int32_t , 2 >(NR1, NK ));
82098209
8210-     constexpr   auto  desc =  mpp::tensor_ops::matmul2d_descriptor  (NR1, NR0, NK,  false ,  true ,  false , mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); 
8211- 
8212-     mpp::tensor_ops::matmul2d<desc,  execution_simdgroups<4 >> mm;
8210+     mpp::tensor_ops::matmul2d< 
8211+          mpp::tensor_ops::matmul2d_descriptor (NR1, NR0, NK,  false ,  true ,  false , mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), 
8212+          execution_simdgroups<4 >> mm;
82138213
82148214    auto  cT = mm.get_destination_cooperative_tensor <decltype (tA), decltype (tB), float >();
82158215#endif 
@@ -8522,72 +8522,169 @@ kernel void kernel_mul_mm_id(
85228522        ushort tiitg[[thread_index_in_threadgroup]],
85238523        ushort tiisg[[thread_index_in_simdgroup]],
85248524        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
8525- 
85268525    threadgroup S0 * sa = (threadgroup S0 *)(shmem);
85278526    threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096 );
85288527
8529-     const  int  r0 = tgpig.y ;
8530-     const  int  r1 = tgpig.x ;
8528+     threadgroup float  * sc = (threadgroup float  *)(shmem);
8529+ 
8530+     constexpr  int  NR0 = 64 ;
8531+     constexpr  int  NR1 = 32 ;
8532+ 
8533+     constexpr  int  NK  = 32 ;
8534+     constexpr  int  NL0 = NK/16 ;
8535+     constexpr  int  NL1 = NK/8 ;
8536+ 
85318537    const  int  im = tgpig.z ; //  expert
8538+     const  int  r0 = tgpig.y *NR0;
8539+     const  int  r1 = tgpig.x *NR1;
85328540
85338541    device const  uint32_t  * tpe_u32 = (device const  uint32_t  *) (htpe);
85348542    device const  int32_t   * ids_i32 = (device const  int32_t   *) (hids);
85358543
85368544    const  int32_t  neh1 = tpe_u32[im];
85378545
8538-     if  (r1*BLOCK_SIZE_N  >= neh1) {
8546+     if  (r1 >= neh1) {
85398547        return ;
85408548    }
85418549
85428550    //  if this block is of 64x32 shape or smaller
8543-     const  short  n_rows  = (args.ne0  - r0*BLOCK_SIZE_M  < BLOCK_SIZE_M ) ? (args.ne0  - r0*BLOCK_SIZE_M ) : BLOCK_SIZE_M ;
8544-     const  short  n_cols  = (    neh1 - r1*BLOCK_SIZE_N  < BLOCK_SIZE_N ) ? (    neh1 - r1*BLOCK_SIZE_N ) : BLOCK_SIZE_N ;
8551+     const  short  nr0  = (args.ne0  - r0 < NR0 ) ? (args.ne0  - r0) : NR0 ;
8552+     const  short  nr1  = (    neh1 - r1 < NR1 ) ? (    neh1 - r1) : NR1 ;
85458553
85468554    //  a thread shouldn't load data outside of the matrix
8547-     const  short  thread_row = ((short )tiitg/THREAD_PER_ROW) < n_rows ? ((short )tiitg/THREAD_PER_ROW) : n_rows - 1 ;
8548-     const  short  thread_col = ((short )tiitg/THREAD_PER_COL) < n_cols ? ((short )tiitg/THREAD_PER_COL) : n_cols - 1 ;
8549- 
8550-     S0_8x8 ma[4 ];
8551-     S1_8x8 mb[2 ];
8552- 
8553-     simdgroup_float8x8 mc[8 ];
8555+     const  short  lr0 = ((short )tiitg/NL0) < nr0 ? ((short )tiitg/NL0) : nr0 - 1 ; //  0 .. 63
8556+     const  short  lr1 = ((short )tiitg/NL1) < nr1 ? ((short )tiitg/NL1) : nr1 - 1 ; //  0 .. 31
85548557
8555-     for  (short  i = 0 ; i < 8 ; i++){
8556-         mc[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
8557-     }
8558+     const  short  il0 = (tiitg % NL0);
85588559
8559-     short  il = (tiitg % THREAD_PER_ROW) ;
8560+     short  il = il0 ;
85608561
8561-     const  int  id = ids_i32[im*args.ne21  + r1*BLOCK_SIZE_N  + thread_col ];
8562+     const  int  id = ids_i32[im*args.ne21  + r1 + lr1 ];
85628563
85638564    const  short  i11 = (id % args.ne20 ) % args.ne11 ;
85648565    const  short  i12 = (id / args.ne20 );
85658566    const  short  i13 = 0 ;
85668567
85678568    const  uint64_t  offset0 = im*args.nb02  + i13*args.nb03 ;
8568-     const  short     offset1 = il /nl;
8569+     const  short     offset1 = il0 /nl;
85698570
8570-     device const  block_q * x = (device const  block_q *)(src0
8571-         + args.nb01 *(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
8571+     device const  block_q * x = (device const  block_q *)(src0 + args.nb01 *(r0 + lr0) + offset0) + offset1;
85728572
8573-     const  short  iy = (BLOCK_SIZE_K / THREAD_PER_COL *  (tiitg % THREAD_PER_COL) );
8573+     const  short  iy = 8 * (tiitg % NL1 );
85748574
85758575    device const  T1 * y = (device const  T1 *)(src1
85768576        + args.nb13 *i13
85778577        + args.nb12 *i12
85788578        + args.nb11 *i11
85798579        + args.nb10 *iy);
85808580
8581-     for  (int  loop_k = 0 ; loop_k < args.ne00 ; loop_k += BLOCK_SIZE_K) {
8581+ #ifndef  GGML_METAL_HAS_TENSOR
8582+     S0_8x8 ma[4 ];
8583+     S1_8x8 mb[2 ];
8584+ 
8585+     simdgroup_float8x8 mc[8 ];
8586+ 
8587+     for  (short  i = 0 ; i < 8 ; i++){
8588+         mc[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
8589+     }
8590+ #else 
8591+     auto  tA = tensor<threadgroup S0, dextents<int32_t , 2 >, tensor_inline>(sa, dextents<int32_t , 2 >(NK,  NR0));
8592+     auto  tB = tensor<threadgroup S1, dextents<int32_t , 2 >, tensor_inline>(sb, dextents<int32_t , 2 >(NR1, NK ));
8593+ 
8594+     mpp::tensor_ops::matmul2d<
8595+         mpp::tensor_ops::matmul2d_descriptor (NR1, NR0, NK, false , true , false , mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
8596+         execution_simdgroups<4 >> mm;
8597+ 
8598+     auto  cT = mm.get_destination_cooperative_tensor <decltype (tA), decltype (tB), float >();
8599+ #endif 
8600+ 
8601+     for  (int  loop_k = 0 ; loop_k < args.ne00 ; loop_k += NK) {
8602+ #ifndef  GGML_METAL_HAS_TENSOR
8603+         //  load data and store to threadgroup memory
8604+         if  (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
8605+             threadgroup_barrier (mem_flags::mem_threadgroup);
8606+ 
8607+             //  no need for dequantization
8608+             for  (short  i = 0 ; i < 16 ; i++) {
8609+                 const  short  sx = 2 *il0 + i/8 ;
8610+                 const  short  sy = (tiitg/NL0)/8 ;
8611+ 
8612+               // const short lx = i%8;
8613+               // const short ly = (tiitg/NL0)%8;
8614+                 const  short  lx = (tiitg/NL0)%8 ;
8615+                 const  short  ly = i%8 ;
8616+ 
8617+                 const  short  ib = 8 *sx + sy;
8618+ 
8619+                 *(sa + 64 *ib + 8 *ly + lx) = loop_k + 16 *il + i < args.ne00  ? *((device T0 *) x + i) : 0 ;
8620+             }
8621+         } else  {
8622+             S0_4x4 temp_a;
8623+             dequantize_func (x, il, temp_a);
8624+ 
8625+             threadgroup_barrier (mem_flags::mem_threadgroup);
8626+ 
8627+             FOR_UNROLL  (short  i = 0 ; i < 16 ; i++) {
8628+                 const  short  sx = 2 *il0 + i/8 ;
8629+                 const  short  sy = (tiitg/NL0)/8 ;
8630+ 
8631+               // const short lx = i%8;
8632+               // const short ly = (tiitg/NL0)%8;
8633+                 const  short  lx = (tiitg/NL0)%8 ;
8634+                 const  short  ly = i%8 ;
8635+ 
8636+                 const  short  ib = 8 *sx + sy;
8637+ 
8638+                 //  NOTE: this is massively slower.. WTF?
8639+                 // sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
8640+ 
8641+                 *(sa + 64 *ib + 8 *ly + lx) = temp_a[i/4 ][i%4 ];
8642+             }
8643+         }
8644+ 
8645+         if  (FC_mul_mm_bc_inp) {
8646+             for  (short  i = 0 ; i < 8 ; ++i) {
8647+                 const  short  sx = (tiitg%NL1);
8648+                 const  short  sy = (tiitg/NL1)/8 ;
8649+ 
8650+                 const  short  lx = i;
8651+                 const  short  ly = (tiitg/NL1)%8 ;
8652+               // const short lx = (tiitg/NL1)%8;
8653+               // const short ly = i;
8654+ 
8655+                 const  short  ib = 4 *sx + sy;
8656+ 
8657+                 *(sb + 64 *ib + 8 *ly + lx) = loop_k + iy + i < args.ne00  ? (S1) *((device T1 *) y + i) : 0 ;
8658+             }
8659+         } else  {
8660+             const  short  sx = (tiitg%NL1);
8661+             const  short  sy = (tiitg/NL1)/8 ;
8662+ 
8663+             const  short  dx = sx;
8664+             const  short  dy = sy;
8665+ 
8666+             const  short  ly = (tiitg/NL1)%8 ;
8667+ 
8668+             const  short  ib = 4 *sx + sy;
8669+ 
8670+             *(threadgroup S1_2x4 *)(sb + 64 *ib + 8 *ly) = (S1_2x4)(*((device T1_2x4 *) y));
8671+         }
8672+ #else 
85828673        //  load data and store to threadgroup memory
85838674        if  (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
85848675            threadgroup_barrier (mem_flags::mem_threadgroup);
85858676
85868677            //  no need for dequantization
85878678            for  (short  i = 0 ; i < 16 ; i++) {
8588-                 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
8589-                 +                     (tiitg%THREAD_PER_ROW)*16  + (i/8 )*8 ) \
8590-                 +                     (tiitg/THREAD_PER_ROW)%8   + (i&7 )*8 ) = loop_k + 16 *il + i < args.ne00  ? ((device T0 *) x)[i] : 0 ;
8679+                 const  short  sx = 2 *il0 + i/8 ;
8680+                 const  short  sy = (tiitg/NL0)/8 ;
8681+ 
8682+                 const  short  lx = i%8 ;
8683+                 const  short  ly = (tiitg/NL0)%8 ;
8684+                 // const short lx = (tiitg/NL0)%8;
8685+                 // const short ly = i%8;
8686+ 
8687+                 *(sa + NK*(8 *sy + ly) + 8 *sx + lx) = loop_k + 16 *il + i < args.ne00  ? *((device T0 *) x + i) : 0 ;
85918688            }
85928689        } else  {
85938690            S0_4x4 temp_a;
@@ -8596,85 +8693,120 @@ kernel void kernel_mul_mm_id(
85968693            threadgroup_barrier (mem_flags::mem_threadgroup);
85978694
85988695            FOR_UNROLL  (short  i = 0 ; i < 16 ; i++) {
8599-                 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
8600-                 +                     (tiitg%THREAD_PER_ROW)*16  + (i/8 )*8 ) \
8601-                 +                     (tiitg/THREAD_PER_ROW)%8   + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
8696+                 const  short  sx = 2 *il0 + i/8 ;
8697+                 const  short  sy = (tiitg/NL0)/8 ;
8698+ 
8699+                 const  short  lx = i%8 ;
8700+                 const  short  ly = (tiitg/NL0)%8 ;
8701+                 // const short lx = (tiitg/NL0)%8;
8702+                 // const short ly = i%8;
8703+ 
8704+                 *(sa + NK*(8 *sy + ly) + 8 *sx + lx) = temp_a[i/4 ][i%4 ];
86028705            }
86038706        }
86048707
86058708        if  (FC_mul_mm_bc_inp) {
86068709            for  (short  i = 0 ; i < 8 ; ++i) {
8607-                 sb[32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00  ? (S1) ((device T1 *) y)[i] : 0 ;
8710+                 const  short  sx = (tiitg%NL1);
8711+                 const  short  sy = (tiitg/NL1)/8 ;
8712+ 
8713+                 const  short  lx = i;
8714+                 const  short  ly = (tiitg/NL1)%8 ;
8715+                 // const short lx = (tiitg/NL1)%8;
8716+                 // const short ly = i;
8717+ 
8718+                 *(sb + NK*(8 *sy + ly) + 8 *sx + lx) = loop_k + iy + i < args.ne00  ? (S1) *((device T1 *) y + i) : 0 ;
86088719            }
86098720        } else  {
8610-             *(threadgroup S1_2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
8721+             const  short  sx = (tiitg%NL1);
8722+             const  short  sy = (tiitg/NL1)/8 ;
8723+ 
8724+             // const short lx = i;
8725+             const  short  ly = (tiitg/NL1)%8 ;
8726+             // const short lx = (tiitg/NL1)%8;
8727+             // const short ly = i;
8728+ 
8729+             *(threadgroup S1_2x4 *)(sb + NK*(8 *sy + ly) + 8 *sx) = (S1_2x4)(*((device T1_2x4 *) y));
86118730        }
8731+ #endif 
86128732
86138733        il = (il + 2  < nl) ? il + 2  : il % 2 ;
86148734        x  = (il < 2 ) ? x + (2  + nl - 1 )/nl : x;
8615-         y += BLOCK_SIZE_K;
8735+ 
8736+         y += NK;
86168737
86178738        threadgroup_barrier (mem_flags::mem_threadgroup);
86188739
8740+ #ifndef  GGML_METAL_HAS_TENSOR
86198741        //  load matrices from threadgroup memory and conduct outer products
8620-         threadgroup const  S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE *(sgitg%2 ));
8621-         threadgroup const  S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE *(sgitg/2 ));
8622- 
8623-         # pragma  unroll(4) 
8624-         for  ( short  ik =  0 ; ik < BLOCK_SIZE_K/ 8 ; ik++) { 
8625-             # pragma  unroll(4) 
8626-             for  (short  i = 0 ; i < 4 ; i++) {
8627-                 simdgroup_load (ma[i], lsma + SG_MAT_SIZE * i );
8742+         threadgroup const  S0 * lsma = (sa + 4 * 64 *(sgitg%2 ));
8743+         threadgroup const  S1 * lsmb = (sb + 2 * 64 *(sgitg/2 ));
8744+ 
8745+         FOR_UNROLL  ( short  ik =  0 ; ik < NK/ 8 ; ik++) { 
8746+              simdgroup_barrier (mem_flags::mem_none); 
8747+ 
8748+             FOR_UNROLL  (short  i = 0 ; i < 4 ; i++) {
8749+                 simdgroup_load (ma[i], lsma + 64 *i,  8 ,  0 ,  false );
86288750            }
86298751
86308752            simdgroup_barrier (mem_flags::mem_none);
86318753
8632-             #pragma  unroll(2)
8633-             for  (short  i = 0 ; i < 2 ; i++) {
8634-                 simdgroup_load (mb[i], lsmb + SG_MAT_SIZE * i);
8754+             FOR_UNROLL  (short  i = 0 ; i < 2 ; i++) {
8755+                 simdgroup_load (mb[i], lsmb + 64 *i, 8 , 0 , false );
86358756            }
86368757
8637-             #pragma  unroll(8)
8638-             for  (short  i = 0 ; i < 8 ; i++){
8758+             simdgroup_barrier (mem_flags::mem_none);
8759+ 
8760+             FOR_UNROLL  (short  i = 0 ; i < 8 ; i++){
86398761                simdgroup_multiply_accumulate (mc[i], mb[i/4 ], ma[i%4 ], mc[i]);
86408762            }
86418763
8642-             lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE ;
8643-             lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE ;
8764+             lsma += 8 * 64 ;
8765+             lsmb += 4 * 64 ;
86448766        }
8767+ #else 
8768+         auto  sA  = tA.slice (0 , 0 );
8769+         auto  sB  = tB.slice (0 , 0 );
8770+ 
8771+         mm.run (sB , sA , cT);
8772+ #endif 
86458773    }
86468774
8775+     //  block is smaller than 64x32, we should avoid writing data outside of the matrix
86478776    threadgroup_barrier (mem_flags::mem_threadgroup);
86488777
8649-     threadgroup float  * temp_str = ((threadgroup float  *) shmem) \
8650-                                  + 32 *(sgitg&1 ) + (16 *(sgitg >> 1 ))*BLOCK_SIZE_M;
8778+ #ifdef  GGML_METAL_HAS_TENSOR
8779+     auto  tC = tensor<threadgroup float , dextents<int32_t , 2 >, tensor_inline>(sc, dextents<int32_t , 2 >(NR0, NR1));
8780+     cT.store (tC);
8781+ #else 
8782+     threadgroup float  * temp_str = ((threadgroup float  *) shmem) + 32 *(sgitg&1 ) + (16 *(sgitg >> 1 ))*NR0;
86518783
8652-     #pragma  unroll(8)
86538784    for  (short  i = 0 ; i < 8 ; i++) {
8654-         simdgroup_store (mc[i], temp_str + 8 *(i%4 ) + 8 *BLOCK_SIZE_M *(i/4 ), BLOCK_SIZE_M );
8785+         simdgroup_store (mc[i], temp_str + 8 *(i%4 ) + 8 *NR0 *(i/4 ), NR0,  0 ,  false );
86558786    }
8787+ #endif 
86568788
86578789    threadgroup_barrier (mem_flags::mem_threadgroup);
86588790
8659-     for  (short  j = sgitg; j < n_cols ; j += 4 ) {
8660-         const  int  id = ids_i32[im*args.ne21  + r1*BLOCK_SIZE_N  + j];
8791+     for  (short  j = sgitg; j < nr1 ; j += 4 ) {
8792+         const  int  id = ids_i32[im*args.ne21  + r1 + j];
86618793
86628794        const  short  ide = id % args.ne20 ;
86638795        const  short  idt = id / args.ne20 ;
86648796
8665-         device float   * D  = (device float   *) dst + (r0*BLOCK_SIZE_M)  + ide*args.ne0  + idt*args.ne1 *args.ne0 ;
8797+         device float   * D  = (device float   *) dst + r0  + ide*args.ne0  + idt*args.ne1 *args.ne0 ;
86668798        device float4 * D4 = (device float4 *) D;
86678799
8668-         threadgroup float   * C  = (threadgroup float   *) shmem + (j*BLOCK_SIZE_M) ;
8800+         threadgroup float   * C  = (threadgroup float   *) shmem + j*NR0 ;
86698801        threadgroup float4 * C4 = (threadgroup float4 *) C;
86708802
86718803        int  i = tiisg;
8672-         for  (; i < n_rows /4 ; i += 32 ) {
8804+         for  (; i < nr0 /4 ; i += 32 ) {
86738805            *(D4 + i) = *(C4 + i);
86748806        }
86758807
8676-         i = (4 *(n_rows /4 )) + tiisg;
8677-         for  (; i < n_rows ; i += 32 ) {
8808+         i = (4 *(nr0 /4 )) + tiisg;
8809+         for  (; i < nr0 ; i += 32 ) {
86788810            *(D + i) = *(C + i);
86798811        }
86808812    }
0 commit comments