@@ -7933,16 +7933,14 @@ kernel void kernel_mul_mm(
79337933 // no need for dequantization
79347934 if (FC_mul_mm_bounds_check) {
79357935 // bounds checks are required
7936- #pragma unroll(16)
79377936 for (short i = 0 ; i < 16 ; i++) {
79387937 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
79397938 + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
79407939 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = loop_k + 16 *il + i < args.ne00 ? ((device T0 *) x)[i] : 0 ;
79417940 }
79427941 } else {
79437942 // do not perform bounds checks
7944- #pragma unroll(16)
7945- for (short i = 0 ; i < 16 ; i++) {
7943+ FOR_UNROLL (short i = 0 ; i < 16 ; i++) {
79467944 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
79477945 + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
79487946 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = ((device T0 *) x)[i];
@@ -7954,8 +7952,7 @@ kernel void kernel_mul_mm(
79547952
79557953 threadgroup_barrier (mem_flags::mem_threadgroup);
79567954
7957- #pragma unroll(16)
7958- for (short i = 0 ; i < 16 ; i++) {
7955+ FOR_UNROLL (short i = 0 ; i < 16 ; i++) {
79597956 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
79607957 + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
79617958 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
@@ -8188,16 +8185,14 @@ kernel void kernel_mul_mm_id(
81888185 // no need for dequantization
81898186 if (FC_mul_mm_bounds_check) {
81908187 // bounds checks are required
8191- #pragma unroll(16)
81928188 for (short i = 0 ; i < 16 ; i++) {
81938189 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
81948190 + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
81958191 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = loop_k + 16 *il + i < args.ne00 ? ((device T0 *) x)[i] : 0 ;
81968192 }
81978193 } else {
81988194 // do not perform bounds checks
8199- #pragma unroll(16)
8200- for (short i = 0 ; i < 16 ; i++) {
8195+ FOR_UNROLL (short i = 0 ; i < 16 ; i++) {
82018196 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
82028197 + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
82038198 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = ((device T0 *) x)[i];
@@ -8209,8 +8204,7 @@ kernel void kernel_mul_mm_id(
82098204
82108205 threadgroup_barrier (mem_flags::mem_threadgroup);
82118206
8212- #pragma unroll(16)
8213- for (short i = 0 ; i < 16 ; i++) {
8207+ FOR_UNROLL (short i = 0 ; i < 16 ; i++) {
82148208 *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8 ) \
82158209 + (tiitg%THREAD_PER_ROW)*16 + (i/8 )*8 ) \
82168210 + (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
0 commit comments