@@ -6391,7 +6391,7 @@ kernel void kernel_get_rows_i32(
63916391#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
63926392#define SG_MAT_ROW 8
63936393
6394- // each block_q contains 16*nl weights
6394+ // Matrix multiplication kernel: stages tiles of dequantized src0 and of src1 in threadgroup memory; each thread accumulates one scalar output element
63956395template <typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread T4x4 &)>
63966396kernel void kernel_mul_mm (
63976397 constant ggml_metal_kargs_mul_mm & args,
@@ -6403,134 +6403,86 @@ kernel void kernel_mul_mm(
64036403 ushort tiitg[[thread_index_in_threadgroup]],
64046404 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
64056405
6406- threadgroup T * sa = (threadgroup T *)(shmem);
6407- threadgroup float * sb = (threadgroup float *)(shmem + 4096 );
6408-
6409- const int r0 = tgpig.y ;
6410- const int r1 = tgpig.x ;
6406+ const int block_idx_m = tgpig.y ;
6407+ const int block_idx_n = tgpig.x ;
64116408 const int im = tgpig.z ;
64126409
6413- // if this block is of 64x32 shape or smaller
6414- const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
6415- const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
6410+ const int local_row = tiitg / THREAD_PER_ROW;
6411+ const int local_col = tiitg % THREAD_PER_ROW;
64166412
6417- // a thread shouldn't load data outside of the matrix
6418- const short thread_row = ((short )tiitg/THREAD_PER_ROW) < n_rows ? ((short )tiitg/THREAD_PER_ROW) : n_rows - 1 ;
6419- const short thread_col = ((short )tiitg/THREAD_PER_COL) < n_cols ? ((short )tiitg/THREAD_PER_COL) : n_cols - 1 ;
6413+ const int global_row = block_idx_m * BLOCK_SIZE_M + local_row;
6414+ const int global_col = block_idx_n * BLOCK_SIZE_N + local_col;
64206415
6421- simdgroup_T8x8 ma[4 ];
6422- simdgroup_float8x8 mb[2 ];
6423- simdgroup_float8x8 mc[8 ];
6416+ // Get matrix dimensions
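6417+ const int64_t m = args.ne00 ; // NOTE(review): used as the row bound of A and C; the removed code bounded rows by args.ne0 - confirm
6418+ const int64_t n = args.ne12 ; // NOTE(review): used as the column bound of B and C; the removed code bounded columns by args.ne1, and ne12 is also the batch divisor below - confirm
6419+ const int64_t k = args.ne02 ; // NOTE(review): used as the reduction length (columns of A, rows of B); the removed code looped loop_k up to args.ne00 - confirm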
64246420
6425- for (short i = 0 ; i < 8 ; i++){
6426- mc[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
6427- }
6428-
6429- short il = (tiitg % THREAD_PER_ROW);
6421+ // Cast shared memory to appropriate types
6422+ threadgroup T * sa = (threadgroup T *)(shmem);
6423+ threadgroup float * sb = (threadgroup float *)(shmem + 4096 );
64306424
6431- const int i12 = im%args. ne12 ;
6432- const int i13 = im/args. ne12 ;
6425+ // Initialize output tile
6426+ float acc = 0 . 0f ;
64336427
6428+ // Get batch indices
6429+ const int i12 = im % args.ne12 ;
6430+ const int i13 = im / args.ne12 ;
64346431 const uint64_t offset0 = (i12/args.r2 )*args.nb02 + (i13/args.r3 )*args.nb03 ;
6435- const short offset1 = il/nl;
6436-
6437- device const block_q * x = (device const block_q *)(src0
6438- + args.nb01 *(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
6439-
6440- device const float * y = (device const float *)(src1
6441- + args.nb13 *i13
6442- + args.nb12 *i12
6443- + args.nb11 *(r1*BLOCK_SIZE_N + thread_col)
6444- + args.nb10 *(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
6445-
6446- for (int loop_k = 0 ; loop_k < args.ne00 ; loop_k += BLOCK_SIZE_K) {
6447- // load data and store to threadgroup memory
6448- T4x4 temp_a;
6449- dequantize_func (x, il, temp_a);
6450-
6451- threadgroup_barrier (mem_flags::mem_threadgroup);
64526432
6453- #pragma unroll(16)
6454- for (short i = 0 ; i < 16 ; i++) {
6455- *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
6456- + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
6457- + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
6458- }
6459-
6460- *(threadgroup float2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
6461-
6462- il = (il + 2 < nl) ? il + 2 : il % 2 ;
6463- x = (il < 2 ) ? x + (2 + nl - 1 )/nl : x;
6464- y += BLOCK_SIZE_K;
6465-
6466- threadgroup_barrier (mem_flags::mem_threadgroup);
6467-
6468- // load matrices from threadgroup memory and conduct outer products
6469- threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2 ));
6470- threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2 ));
6471-
6472- #pragma unroll(4)
6473- for (short ik = 0 ; ik < BLOCK_SIZE_K/8 ; ik++) {
6474- #pragma unroll(4)
6475- for (short i = 0 ; i < 4 ; i++) {
6476- simdgroup_load (ma[i], lsma + SG_MAT_SIZE * i);
6477- }
6478-
6479- simdgroup_barrier (mem_flags::mem_none);
6480-
6481- #pragma unroll(2)
6482- for (short i = 0 ; i < 2 ; i++) {
6483- simdgroup_load (mb[i], lsmb + SG_MAT_SIZE * i);
6484- }
6485-
6486- #pragma unroll(8)
6487- for (short i = 0 ; i < 8 ; i++){
6488- simdgroup_multiply_accumulate (mc[i], mb[i/4 ], ma[i%4 ], mc[i]);
6433+ // Main computation loop over tiles
6434+ for (int tile_idx = 0 ; tile_idx < (k + BLOCK_SIZE_K - 1 ) / BLOCK_SIZE_K; ++tile_idx) {
6435+ const int tiled_k = tile_idx * BLOCK_SIZE_K + local_col;
6436+ const int tiled_m = tile_idx * BLOCK_SIZE_K + local_row;
6437+
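6438+ // Load tile of A into shared memory with dequantization (NOTE(review): x and il below do not depend on tile_idx, so every iteration dequantizes the same block - confirm)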
6439+ if (global_row < m && tiled_k < k) {
6440+ const short il = tiitg % nl;
6441+ const short offset1 = il / nl;
6442+ device const block_q * x = (device const block_q *)(src0 + args.nb01 * global_row + offset0) + offset1;
6443+
6444+ T4x4 temp_a;
6445+ dequantize_func (x, il, temp_a);
6446+
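6447+ // Store dequantized values to shared memory (NOTE(review): the destination index below does not depend on i, so all 16 iterations write the same sa element - confirm intent)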
6448+ #pragma unroll(16)
6449+ for (short i = 0 ; i < 16 ; i++) {
6450+ if (local_row * 16 + i < BLOCK_SIZE_M * BLOCK_SIZE_K) {
6451+ sa[local_row * BLOCK_SIZE_K + local_col] = temp_a[i/4 ][i%4 ];
6452+ }
64896453 }
6490-
6491- lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
6492- lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
6454+ } else {
6455+ sa[local_row * BLOCK_SIZE_K + local_col] = 0 ;
64936456 }
6494- }
6495-
6496- if ((r0 + 1 ) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1 ) * BLOCK_SIZE_N <= args.ne1 ) {
6497- device float * C = (device float *) dst +
6498- (BLOCK_SIZE_M * r0 + 32 *(sgitg & 1 )) + \
6499- (BLOCK_SIZE_N * r1 + 16 *(sgitg >> 1 )) * args.ne0 + im*args.ne1 *args.ne0 ;
65006457
6501- for ( short i = 0 ; i < 8 ; i++) {
6502- simdgroup_store (mc[i], C + 8 * (i% 4 ) + 8 * args. ne0 * (i/ 4 ), args. ne0 );
6503- }
6504- } else {
6505- // block is smaller than 64x32, we should avoid writing data outside of the matrix
6506- threadgroup_barrier (mem_flags::mem_threadgroup);
6507- threadgroup float * temp_str = ((threadgroup float *) shmem) \
6508- + 32 *(sgitg& 1 ) + ( 16 *(sgitg >> 1 ))*BLOCK_SIZE_M ;
6509- for ( short i = 0 ; i < 8 ; i++) {
6510- simdgroup_store (mc[i], temp_str + 8 *(i% 4 ) + 8 *BLOCK_SIZE_M*(i/ 4 ), BLOCK_SIZE_M) ;
6458+ // Load tile of B into shared memory
6459+ if (tiled_k < k && global_col < n) {
6460+ device const float * y = (device const float *)(src1
6461+ + args. nb13 * i13
6462+ + args. nb12 * i12
6463+ + args. nb11 * global_col
6464+ + args. nb10 * tiled_k);
6465+ sb[local_row * BLOCK_SIZE_N + local_col] = *y ;
6466+ } else {
6467+ sb[local_row * BLOCK_SIZE_N + local_col] = 0 ;
65116468 }
65126469
65136470 threadgroup_barrier (mem_flags::mem_threadgroup);
65146471
6515- if (sgitg == 0 ) {
6516- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
6517- device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1 *args.ne0 ;
6518- device float4 * D4 = (device float4 *) D;
6519-
6520- threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
6521- threadgroup float4 * C4 = (threadgroup float4 *) C;
6472+ // Compute partial dot product
6473+ for (int i = 0 ; i < BLOCK_SIZE_K; ++i) {
6474+ acc += (float )sa[local_row * BLOCK_SIZE_K + i] *
6475+ (float )sb[i * BLOCK_SIZE_N + local_col];
6476+ }
65226477
6523- int i = 0 ;
6524- for (; i < n_rows/4 ; i++) {
6525- *(D4 + i) = *(C4 + i);
6526- }
6478+ threadgroup_barrier (mem_flags::mem_threadgroup);
6479+ }
65276480
6528- i *= 4 ;
6529- for (; i < n_rows; i++) {
6530- *(D + i) = *(C + i);
6531- }
6532- }
6533- }
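6481+ // Store the result with bounds checking (NOTE(review): indexes dst as global_row * n + global_col, whereas the removed code used row + col * args.ne0 - confirm layout)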
6482+ if (global_row < m && global_col < n) {
6483+ device float * C = (device float *)dst + global_row * n + global_col +
6484+ im * args.ne1 * args.ne0 ;
6485+ *C = acc;
65346486 }
65356487}
65366488
0 commit comments