@@ -9,6 +9,18 @@ __embed_ggml-common.h__
 
 #include <metal_stdlib>
 
+#define GGML_METAL_USE_METAL4
+
+#ifdef GGML_METAL_USE_METAL4
+#include <metal_stdlib>
+#include <metal_tensor>
+
+#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
+
+using namespace metal;
+using namespace mpp::tensor_ops;
+#endif
+
 using namespace metal;
 
 #define MAX(x, y) ((x) > (y) ? (x) : (y))
@@ -8145,6 +8157,8 @@ kernel void kernel_mul_mm(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
+    threadgroup float * sc = (threadgroup float *)(shmem);
+
     constexpr int NR0 = 64;
     constexpr int NR1 = 32;
 
@@ -8164,15 +8178,6 @@ kernel void kernel_mul_mm(
     const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
     const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
 
-    S0_8x8 ma[4];
-    S1_8x8 mb[2];
-
-    simdgroup_float8x8 mc[8];
-
-    for (short i = 0; i < 8; i++){
-        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
-    }
-
     const short il0 = (tiitg % NL0);
 
     short il = il0;
@@ -8193,7 +8198,28 @@ kernel void kernel_mul_mm(
         + args.nb11*(r1 + lr1)
         + args.nb10*iy);
 
+#ifndef GGML_METAL_USE_METAL4
+    S0_8x8 ma[4];
+    S1_8x8 mb[2];
+
+    simdgroup_float8x8 mc[8];
+
+    for (short i = 0; i < 8; i++){
+        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+#else
+    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
+    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK));
+
+    constexpr auto desc = matmul2d_descriptor(NR1, NR0, NK, false, true, false, matmul2d_descriptor::mode::multiply_accumulate);
+
+    matmul2d<desc, execution_simdgroups<4>> mm;
+
+    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
+#endif
+
     for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
+#ifndef GGML_METAL_USE_METAL4
         // load data and store to threadgroup memory
         if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
             threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -8297,26 +8323,100 @@ kernel void kernel_mul_mm(
             lsma += 8*64;
             lsmb += 4*64;
         }
+#else
+        // load data and store to threadgroup memory
+        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // no need for dequantization
+            for (short i = 0; i < 16; i++) {
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+                const short lx = i%8;
+                const short ly = (tiitg/NL0)%8;
+                // const short lx = (tiitg/NL0)%8;
+                // const short ly = i%8;
+
+                *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
+            }
+        } else {
+            S0_4x4 temp_a;
+            dequantize_func(x, il, temp_a);
+
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            FOR_UNROLL (short i = 0; i < 16; i++) {
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+                const short lx = i%8;
+                const short ly = (tiitg/NL0)%8;
+                // const short lx = (tiitg/NL0)%8;
+                // const short ly = i%8;
+
+                *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
+            }
+        }
+
+        for (short i = 0; i < 8; ++i) {
+            const short sx = (tiitg%NL1);
+            const short sy = (tiitg/NL1)/8;
+
+            const short lx = i;
+            const short ly = (tiitg/NL1)%8;
+            // const short lx = (tiitg/NL1)%8;
+            // const short ly = i;
+
+            *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
+        }
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
+
+        y += NK;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        auto sA = tA.slice(0, 0);
+        auto sB = tB.slice(0, 0);
+
+        mm.run(sB, sA, cT);
+#endif
     }
 
     if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
         // if no bounds checks on the output are needed, we can directly write to device memory
+#ifdef GGML_METAL_USE_METAL4
+        device float * C = (device float *) dst +
+            r0 + \
+            r1 * args.ne0 + im*args.ne1*args.ne0;
+
+        auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
+        cT.store(tC);
+#else
         device float * C = (device float *) dst +
             (r0 + 32*(sgitg &  1)) + \
             (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
 
         for (short i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0, 0, false);
+            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0, 0, false);
         }
+#endif
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
 
+#ifdef GGML_METAL_USE_METAL4
+        auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
+        cT.store(tC);
+#else
         for (short i = 0; i < 8; i++) {
             simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
         }
+#endif
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
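
For reference, below is a minimal, hypothetical sketch (not part of the patch) of how the Metal 4 mpp::tensor_ops path used above fits together in isolation: two fp16 tiles staged in threadgroup memory, wrapped as inline 2D tensors, multiplied with matmul2d across 4 simdgroups, and stored back as fp32. The kernel name, buffer bindings, 32x32 tile size, and 128-thread threadgroup are assumptions for illustration; the tensor and descriptor calls and the operand order of mm.run mirror the kernel_mul_mm changes above.

// Hypothetical standalone kernel (illustrative only, not from the patch).
#include <metal_stdlib>
#include <metal_tensor>

#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>

using namespace metal;
using namespace mpp::tensor_ops;

kernel void mm_tile_example(
        device const half  * A   [[buffer(0)]],   // 32x32 operand tile
        device const half  * B   [[buffer(1)]],   // 32x32 operand tile
        device       float * C   [[buffer(2)]],   // 32x32 result tile
        threadgroup  half  * shm [[threadgroup(0)]],
        ushort tiitg [[thread_index_in_threadgroup]]) {
    constexpr int N = 32; // assumed square tile: M = N = K = 32

    threadgroup half * sa = shm;
    threadgroup half * sb = shm + N*N;

    // stage both operands into threadgroup memory (assumes 128 threads per threadgroup)
    for (int i = tiitg; i < N*N; i += 128) {
        sa[i] = A[i];
        sb[i] = B[i];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // wrap the threadgroup tiles as inline 2D tensors, as the patch does for sa/sb
    auto tA = tensor<threadgroup half, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(N, N));
    auto tB = tensor<threadgroup half, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(N, N));

    // same flag layout as the patch, with all three problem dimensions equal to N
    constexpr auto desc = matmul2d_descriptor(N, N, N, false, true, false, matmul2d_descriptor::mode::multiply_accumulate);

    // cooperative matmul executed across 4 simdgroups
    matmul2d<desc, execution_simdgroups<4>> mm;

    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();

    auto sA = tA.slice(0, 0);
    auto sB = tB.slice(0, 0);

    mm.run(sB, sA, cT);

    // write the accumulated fp32 tile back to device memory
    auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(N, N));
    cT.store(tC);
}

As in the patch, the descriptor lists the src1 tile dimension first and mm.run receives sB as its first operand, so the sketch keeps that ordering. That the destination cooperative tensor starts out zeroed is an assumption here, matching how the patch creates cT once before its K loop and accumulates into it without an explicit clear.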