@@ -5755,31 +5755,32 @@ kernel void kernel_mul_mm(
57555755}
57565756
57575757// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
5758+ // TODO: this kernel needs to be reimplemented from scratch for better performance
57585759template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread half4x4 &)>
57595760void kernel_mul_mm_id_impl (
5760- device const uchar * src0,
5761- device const uchar * src1,
5761+ int32_t ne00,
5762+ int32_t ne02,
5763+ uint64_t nb01,
5764+ uint64_t nb02,
5765+ int32_t ne11,
5766+ int32_t ne12,
5767+ uint64_t nb10,
5768+ uint64_t nb11,
5769+ uint64_t nb12,
5770+ int32_t ne0,
5771+ int32_t ne1,
5772+ int64_t ne0ne1,
5773+ device const char * src0,
5774+ device const char * src1,
57625775 threadgroup ushort2 * rowids,
5763- device float * dst,
5764- constant int64_t & ne00,
5765- constant int64_t & ne02,
5766- constant uint64_t & nb01,
5767- constant uint64_t & nb02,
5768- constant int64_t & ne11,
5769- constant int64_t & ne12,
5770- constant uint64_t & nb10,
5771- constant uint64_t & nb11,
5772- constant uint64_t & nb12,
5773- constant int64_t & ne0,
5774- int64_t ne1,
5775- int64_t ne0ne1,
5776- threadgroup uchar * shared_memory,
5777- uint3 tgpig[[threadgroup_position_in_grid]],
5778- uint tiitg[[thread_index_in_threadgroup]],
5779- uint sgitg[[simdgroup_index_in_threadgroup]]) {
5780-
5781- threadgroup half * sa = (threadgroup half *)(shared_memory);
5782- threadgroup float * sb = (threadgroup float *)(shared_memory + 4096 );
5776+ device char * dst,
5777+ threadgroup char * shmem,
5778+ uint3 tgpig[[threadgroup_position_in_grid]],
5779+ ushort tiitg[[thread_index_in_threadgroup]],
5780+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5781+
5782+ threadgroup half * sa = (threadgroup half *)(shmem);
5783+ threadgroup float * sb = (threadgroup float *)(shmem + 4096 );
57835784
57845785 const uint r0 = tgpig.y ;
57855786 const uint r1 = tgpig.x ;
@@ -5796,9 +5797,9 @@ void kernel_mul_mm_id_impl(
57965797
57975798 simdgroup_half8x8 ma[4 ];
57985799 simdgroup_float8x8 mb[2 ];
5799- simdgroup_float8x8 c_res [8 ];
5800+ simdgroup_float8x8 mc [8 ];
58005801 for (int i = 0 ; i < 8 ; i++){
5801- c_res [i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
5802+ mc [i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
58025803 }
58035804 short il = (tiitg % THREAD_PER_ROW);
58045805
@@ -5836,41 +5837,57 @@ void kernel_mul_mm_id_impl(
58365837 threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2 ));
58375838 threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2 ));
58385839
5840+ #pragma unroll(BLOCK_SIZE_K/8)
58395841 for (int ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
5842+ #pragma unroll(4)
58405843 for (int i = 0 ; i < 4 ; i++) {
58415844 simdgroup_load (ma[i], lsma + SG_MAT_SIZE * i);
58425845 }
58435846 simdgroup_barrier (mem_flags::mem_none);
5847+ #pragma unroll(2)
58445848 for (int i = 0 ; i < 2 ; i++) {
58455849 simdgroup_load (mb[i], lsmb + SG_MAT_SIZE * i);
58465850 }
58475851
58485852 lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
58495853 lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
58505854
5855+ #pragma unroll(8)
58515856 for (int i = 0 ; i < 8 ; i++){
5852- simdgroup_multiply_accumulate (c_res [i], mb[i/4 ], ma[i%4 ], c_res [i]);
5857+ simdgroup_multiply_accumulate (mc [i], mb[i/4 ], ma[i%4 ], mc [i]);
58535858 }
58545859 }
58555860 }
58565861
58575862 {
58585863 threadgroup_barrier (mem_flags::mem_threadgroup);
5859- threadgroup float * temp_str = ((threadgroup float *)shared_memory ) \
5864+ threadgroup float * temp_str = ((threadgroup float *) shmem ) \
58605865 + 32 * (sgitg&1 ) + (16 * (sgitg>>1 )) * BLOCK_SIZE_M;
58615866 for (int i = 0 ; i < 8 ; i++) {
5862- simdgroup_store (c_res [i], temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
5867+ simdgroup_store (mc [i], temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
58635868 }
58645869
58655870 threadgroup_barrier (mem_flags::mem_threadgroup);
58665871
5867- device float * C = dst + (BLOCK_SIZE_M * r0);
58685872 if (sgitg == 0 ) {
58695873 for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
58705874 threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
5871- int joff = jid[0 ] * ne0 + jid[1 ] * ne0ne1;
5872- for (int i = 0 ; i < n_rows; i++) {
5873- *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
5875+ int64_t joff = jid[0 ]*ne0 + jid[1 ]*ne0ne1;
5876+
5877+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
5878+ device float4 * D4 = (device float4 *) D;
5879+
5880+ threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
5881+ threadgroup float4 * C4 = (threadgroup float4 *) C;
5882+
5883+ int i = 0 ;
5884+ for (; i < n_rows/4 ; i++) {
5885+ *(D4 + i) = *(C4 + i);
5886+ }
5887+
5888+ i *= 4 ;
5889+ for (; i < n_rows; i++) {
5890+ *(D + i) = *(C + i);
58745891 }
58755892 }
58765893 }
@@ -5879,48 +5896,34 @@ void kernel_mul_mm_id_impl(
58795896
58805897template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread half4x4 &)>
58815898kernel void kernel_mul_mm_id (
5882- device const uchar * src0s,
5883- device const uchar * src1,
5884- device float * dst,
5885- device const uchar * ids,
5886- constant int64_t & nei0,
5887- constant int64_t & nei1,
5888- constant uint64_t & nbi1,
5889- constant int64_t & ne00,
5890- constant int64_t & ne02,
5891- constant uint64_t & nb01,
5892- constant uint64_t & nb02,
5893- constant int64_t & ne11,
5894- constant int64_t & ne12,
5895- constant int64_t & ne13,
5896- constant uint64_t & nb10,
5897- constant uint64_t & nb11,
5898- constant uint64_t & nb12,
5899- constant int64_t & ne0,
5900- constant int64_t & ne1,
5901- constant uint64_t & nb1,
5902- threadgroup uchar * shared_memory [[threadgroup(0 )]],
5903- uint3 tgpig[[threadgroup_position_in_grid]],
5904- uint tiitg[[thread_index_in_threadgroup]],
5905- uint sgitg[[simdgroup_index_in_threadgroup]]) {
5899+ constant ggml_metal_kargs_mul_mm_id & args,
5900+ device const char * src0s,
5901+ device const char * src1,
5902+ device char * dst,
5903+ device const char * ids,
5904+ threadgroup char * shmem [[threadgroup(0 )]],
5905+ uint3 tgpig[[threadgroup_position_in_grid]],
5906+ ushort tiitg[[thread_index_in_threadgroup]],
5907+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
59065908
59075909 const int32_t i02 = tgpig.z ;
5910+
59085911 tgpig.z = 0 ;
59095912
5910- device const uchar * src0 = src0s + i02*nb02;
5913+ device const char * src0 = src0s + i02*args. nb02 ;
59115914
59125915 // row indices
5913- threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192 );
5916+ threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192 );
59145917
59155918 // TODO: parallelize this loop
59165919 int64_t _ne1 = 0 ;
5917- for (ushort ii1 = 0 ; ii1 < nei1; ii1++) {
5918- for (ushort ii0 = 0 ; ii0 < nei0; ii0++) {
5919- int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
5920+ for (ushort ii1 = 0 ; ii1 < args. nei1 ; ii1++) {
5921+ for (ushort ii0 = 0 ; ii0 < args. nei0 ; ii0++) {
5922+ int32_t id = ((device int32_t *) (ids + ii1*args. nbi1 ))[ii0];
59205923 if (id == i02) {
5921- // if (tiitg == 0) {
5924+ if (tiitg == 0 ) {
59225925 rowids[_ne1] = ushort2 (ii0, ii1);
5923- // }
5926+ }
59245927 _ne1++;
59255928 }
59265929 }
@@ -5929,23 +5932,23 @@ kernel void kernel_mul_mm_id(
59295932 threadgroup_barrier (mem_flags::mem_threadgroup);
59305933
59315934 kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
5935+ args.ne00 ,
5936+ args.ne02 ,
5937+ args.nb01 ,
5938+ args.nb02 ,
5939+ args.ne11 ,
5940+ args.ne12 ,
5941+ args.nb10 ,
5942+ args.nb11 ,
5943+ args.nb12 ,
5944+ args.ne0 ,
5945+ _ne1,
5946+ (int64_t )args.ne0 *args.ne1 ,
59325947 src0,
59335948 src1,
59345949 rowids,
59355950 dst,
5936- ne00,
5937- ne02,
5938- nb01,
5939- nb02,
5940- ne11,
5941- ne12,
5942- nb10,
5943- nb11,
5944- nb12,
5945- ne0,
5946- _ne1,
5947- ne0*ne1,
5948- shared_memory,
5951+ shmem,
59495952 tgpig,
59505953 tiitg,
59515954 sgitg);
0 commit comments