@@ -7856,6 +7856,8 @@ kernel void kernel_set_rows_f(
78567856 }
78577857}
78587858
7859+ constant bool FC_mul_mm_bounds_check [[function_constant(FC_MUL_MM + 0 )]];
7860+
78597861#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
78607862#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
78617863#define BLOCK_SIZE_K 32
@@ -7913,27 +7915,58 @@ kernel void kernel_mul_mm(
79137915 device const block_q * x = (device const block_q *)(src0
79147916 + args.nb01 *(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
79157917
7918+ const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
7919+
79167920 device const U * y = (device const U *)(src1
79177921 + args.nb13 *i13
79187922 + args.nb12 *i12
79197923 + args.nb11 *(r1*BLOCK_SIZE_N + thread_col)
7920- + args.nb10 *(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)) );
7924+ + args.nb10 *iy );
79217925
79227926 for (int loop_k = 0 ; loop_k < args.ne00 ; loop_k += BLOCK_SIZE_K) {
79237927 // load data and store to threadgroup memory
7924- T4x4 temp_a;
7925- dequantize_func (x, il, temp_a);
7928+ if (is_same<T4x4, block_q>::value) {
7929+ // no need for dequantization
7930+ threadgroup_barrier (mem_flags::mem_threadgroup);
79267931
7927- threadgroup_barrier (mem_flags::mem_threadgroup);
7932+ if (FC_mul_mm_bounds_check) {
7933+ // bounds checks are required
7934+ #pragma unroll(16)
7935+ for (short i = 0 ; i < 16 ; i++) {
7936+ *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
7937+ + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
7938+ + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = loop_k + 16 *il + i < args.ne00 ? ((device T *) x)[16 *il + i] : 0 ;
7939+ }
7940+ } else {
7941+ // do not perform bounds checks
7942+ #pragma unroll(16)
7943+ for (short i = 0 ; i < 16 ; i++) {
7944+ *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
7945+ + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
7946+ + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = ((device T *) x)[i];
7947+ }
7948+ }
7949+ } else {
7950+ T4x4 temp_a;
7951+ dequantize_func (x, il, temp_a);
79287952
7929- #pragma unroll(16)
7930- for (short i = 0 ; i < 16 ; i++) {
7931- *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
7932- + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
7933- + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
7953+ threadgroup_barrier (mem_flags::mem_threadgroup);
7954+
7955+ #pragma unroll(16)
7956+ for (short i = 0 ; i < 16 ; i++) {
7957+ *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
7958+ + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
7959+ + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
7960+ }
79347961 }
79357962
7936- *(threadgroup half2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = (half2x4)(*((device U2x4 *) y));
7963+ if (FC_mul_mm_bounds_check) {
7964+ for (short i = 0 ; i < 8 ; ++i) {
7965+ sb[32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? ((device U *) y)[i] : 0 ;
7966+ }
7967+ } else {
7968+ *(threadgroup half2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = (half2x4)(*((device U2x4 *) y));
7969+ }
79377970
79387971 il = (il + 2 < nl) ? il + 2 : il % 2 ;
79397972 x = (il < 2 ) ? x + (2 + nl - 1 )/nl : x;