@@ -7526,6 +7526,7 @@ kernel void kernel_mul_mm_id(
75267526 threadgroup char * shmem [[threadgroup(0 )]],
75277527 uint3 tgpig[[threadgroup_position_in_grid]],
75287528 ushort tiitg[[thread_index_in_threadgroup]],
7529+ ushort tiisg[[thread_index_in_simdgroup]],
75297530 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
75307531
75317532 threadgroup T * sa = (threadgroup T *)(shmem);
@@ -7631,36 +7632,36 @@ kernel void kernel_mul_mm_id(
76317632 }
76327633
76337634 threadgroup_barrier (mem_flags::mem_threadgroup);
7635+
76347636 threadgroup float * temp_str = ((threadgroup float *) shmem) \
76357637 + 32 *(sgitg&1 ) + (16 *(sgitg >> 1 ))*BLOCK_SIZE_M;
7638+
76367639 for (short i = 0 ; i < 8 ; i++) {
76377640 simdgroup_store (mc[i], temp_str + 8 *(i%4 ) + 8 *BLOCK_SIZE_M*(i/4 ), BLOCK_SIZE_M);
76387641 }
76397642
76407643 threadgroup_barrier (mem_flags::mem_threadgroup);
76417644
7642- if (sgitg == 0 ) {
7643- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
7644- const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
7645+ for (int j = sgitg; j < n_cols; j += 4 ) {
7646+ const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
76457647
7646- const int ide = id % args.ne20 ;
7647- const int idt = id / args.ne20 ;
7648+ const int ide = id % args.ne20 ;
7649+ const int idt = id / args.ne20 ;
76487650
7649- device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1 *args.ne0 ;
7650- device float4 * D4 = (device float4 *) D;
7651+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1 *args.ne0 ;
7652+ device float4 * D4 = (device float4 *) D;
76517653
7652- threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
7653- threadgroup float4 * C4 = (threadgroup float4 *) C;
7654+ threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M);
7655+ threadgroup float4 * C4 = (threadgroup float4 *) C;
76547656
7655- int i = 0 ;
7656- for (; i < n_rows/4 ; i++ ) {
7657- *(D4 + i) = *(C4 + i);
7658- }
7657+ int i = tiisg ;
7658+ for (; i < n_rows/4 ; i += 32 ) {
7659+ *(D4 + i) = *(C4 + i);
7660+ }
76597661
7660- i *= 4 ;
7661- for (; i < n_rows; i++) {
7662- *(D + i) = *(C + i);
7663- }
7662+ i = (4 *(n_rows/4 )) + tiisg;
7663+ for (; i < n_rows; i += 32 ) {
7664+ *(D + i) = *(C + i);
76647665 }
76657666 }
76667667}
0 commit comments