|
| 1 | +#pragma OPENCL EXTENSION cl_khr_fp16 : enable |
| 2 | +#pragma OPENCL EXTENSION cl_khr_subgroups : enable |
| 3 | + |
| 4 | +#define LM_FIRST_256B 0 |
| 5 | +#define LM_SECOND_256B 64 |
| 6 | +#define LM_THIRD_256B 128 |
| 7 | +#define LM_FOURTH_256B 192 |
| 8 | + |
| 9 | + |
| 10 | +inline float16 mm_load_a( |
| 11 | + image1d_buffer_t matrix_A, |
| 12 | + uint subMatrixAStartInElements, |
| 13 | + int nb01, |
| 14 | + int line_stride_matrix_A_in_bytes |
| 15 | +) { |
| 16 | + __private float8 regA; |
| 17 | + size_t sub_block_id_m = get_local_id(0); |
| 18 | + |
| 19 | +#ifdef KQV |
| 20 | + uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4); |
| 21 | +#else // KQ |
| 22 | + uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4); |
| 23 | +#endif |
| 24 | + |
| 25 | + regA.s0123 = read_imagef(matrix_A, a_texCoord/4); |
| 26 | + regA.s4567 = read_imagef(matrix_A, (a_texCoord+4)/4); |
| 27 | + |
| 28 | + return convert_float16(as_half16(regA)); |
| 29 | +} |
| 30 | + |
| 31 | +inline float4 alu_32( |
| 32 | + float16 regA, |
| 33 | + __local float4* matrix_B_vec |
| 34 | +) { |
| 35 | + |
| 36 | + __private float4 rC = 0; |
| 37 | + int i = get_sub_group_id() * 64; |
| 38 | + |
| 39 | + rC += regA.s0 * matrix_B_vec[i]; |
| 40 | + rC += regA.s1 * matrix_B_vec[i + 16]; |
| 41 | + rC += regA.s4 * matrix_B_vec[i + 1]; |
| 42 | + rC += regA.s5 * matrix_B_vec[i + 17]; |
| 43 | + rC += regA.s8 * matrix_B_vec[i + 2]; |
| 44 | + rC += regA.s9 * matrix_B_vec[i + 18]; |
| 45 | + rC += regA.sc * matrix_B_vec[i + 3]; |
| 46 | + rC += regA.sd * matrix_B_vec[i + 19]; |
| 47 | + |
| 48 | + i += 32; |
| 49 | + |
| 50 | + rC += regA.s2 * matrix_B_vec[i]; |
| 51 | + rC += regA.s3 * matrix_B_vec[i + 16]; |
| 52 | + rC += regA.s6 * matrix_B_vec[i + 1]; |
| 53 | + rC += regA.s7 * matrix_B_vec[i + 17]; |
| 54 | + rC += regA.sa * matrix_B_vec[i + 2]; |
| 55 | + rC += regA.sb * matrix_B_vec[i + 18]; |
| 56 | + rC += regA.se * matrix_B_vec[i + 3]; |
| 57 | + rC += regA.sf * matrix_B_vec[i + 19]; |
| 58 | + |
| 59 | + return rC; |
| 60 | +} |
| 61 | + |
| 62 | +inline float16 alu_16( |
| 63 | + float16 regA, |
| 64 | + __local float* matrix_B_local |
| 65 | +) { |
| 66 | + float16 out; |
| 67 | + __local float4* matrix_B_vec = (__local float4*)matrix_B_local; |
| 68 | + |
| 69 | + out.s0123 = alu_32(regA, matrix_B_vec); |
| 70 | + out.s4567 = alu_32(regA, matrix_B_vec + 4); |
| 71 | + out.s89ab = alu_32(regA, matrix_B_vec + 8); |
| 72 | + out.scdef = alu_32(regA, matrix_B_vec + 12); |
| 73 | + |
| 74 | + return out; |
| 75 | +} |
| 76 | + |
| 77 | +inline void mm_mad( |
| 78 | + __local float* matrix_B_local, |
| 79 | + float16 regA, |
| 80 | + float8 regB, |
| 81 | + uint b_localOffsetInWords, |
| 82 | + float16* regC0_ptr, |
| 83 | + float16* regC1_ptr |
| 84 | +) { |
| 85 | + int offset = b_localOffsetInWords + get_sub_group_id() * 256; |
| 86 | + |
| 87 | + matrix_B_local[offset + LM_FIRST_256B] = regB.s0; |
| 88 | + matrix_B_local[offset + LM_SECOND_256B] = regB.s1; |
| 89 | + matrix_B_local[offset + LM_THIRD_256B] = regB.s2; |
| 90 | + matrix_B_local[offset + LM_FOURTH_256B] = regB.s3; |
| 91 | + |
| 92 | + float16 add0 = alu_16(regA, matrix_B_local); |
| 93 | + *regC0_ptr += add0; |
| 94 | + |
| 95 | + matrix_B_local[offset + LM_FIRST_256B] = regB.s4; |
| 96 | + matrix_B_local[offset + LM_SECOND_256B] = regB.s5; |
| 97 | + matrix_B_local[offset + LM_THIRD_256B] = regB.s6; |
| 98 | + matrix_B_local[offset + LM_FOURTH_256B] = regB.s7; |
| 99 | + |
| 100 | + float16 add1 = alu_16(regA, matrix_B_local); |
| 101 | + *regC1_ptr += add1; |
| 102 | +} |
| 103 | + |
| 104 | +inline void mm_store_c_N( |
| 105 | + __write_only image1d_buffer_t matrix_C, |
| 106 | + float16 regC0, |
| 107 | + float16 regC1, |
| 108 | + uint subMatrixCStartInElements, |
| 109 | + int line_stride_matrix_C_in_bytes, |
| 110 | + int mask |
| 111 | +) { |
| 112 | + size_t sub_block_id_m = get_local_id(0); |
| 113 | + |
| 114 | + uint strideInWords = line_stride_matrix_C_in_bytes/4; |
| 115 | + uint c_coordInWords_0 = (subMatrixCStartInElements + sub_block_id_m); |
| 116 | + |
| 117 | + uint c_coordInWords_1 = c_coordInWords_0 + 1 * strideInWords; |
| 118 | + uint c_coordInWords_2 = c_coordInWords_0 + 2 * strideInWords; |
| 119 | + uint c_coordInWords_3 = c_coordInWords_0 + 3 * strideInWords; |
| 120 | + uint c_coordInWords_4 = c_coordInWords_0 + 4 * strideInWords; |
| 121 | + uint c_coordInWords_5 = c_coordInWords_0 + 5 * strideInWords; |
| 122 | + uint c_coordInWords_6 = c_coordInWords_0 + 6 * strideInWords; |
| 123 | + uint c_coordInWords_7 = c_coordInWords_0 + 7 * strideInWords; |
| 124 | + uint c_coordInWords_8 = c_coordInWords_0 + 8 * strideInWords; |
| 125 | + uint c_coordInWords_9 = c_coordInWords_0 + 9 * strideInWords; |
| 126 | + uint c_coordInWords_10 = c_coordInWords_0 + 10 * strideInWords; |
| 127 | + uint c_coordInWords_11 = c_coordInWords_0 + 11 * strideInWords; |
| 128 | + uint c_coordInWords_12 = c_coordInWords_0 + 12 * strideInWords; |
| 129 | + uint c_coordInWords_13 = c_coordInWords_0 + 13 * strideInWords; |
| 130 | + uint c_coordInWords_14 = c_coordInWords_0 + 14 * strideInWords; |
| 131 | + uint c_coordInWords_15 = c_coordInWords_0 + 15 * strideInWords; |
| 132 | + uint c_coordInWords_16 = c_coordInWords_0 + 16 * strideInWords; |
| 133 | + uint c_coordInWords_17 = c_coordInWords_0 + 17 * strideInWords; |
| 134 | + uint c_coordInWords_18 = c_coordInWords_0 + 18 * strideInWords; |
| 135 | + uint c_coordInWords_19 = c_coordInWords_0 + 19 * strideInWords; |
| 136 | + uint c_coordInWords_20 = c_coordInWords_0 + 20 * strideInWords; |
| 137 | + uint c_coordInWords_21 = c_coordInWords_0 + 21 * strideInWords; |
| 138 | + uint c_coordInWords_22 = c_coordInWords_0 + 22 * strideInWords; |
| 139 | + uint c_coordInWords_23 = c_coordInWords_0 + 23 * strideInWords; |
| 140 | + uint c_coordInWords_24 = c_coordInWords_0 + 24 * strideInWords; |
| 141 | + uint c_coordInWords_25 = c_coordInWords_0 + 25 * strideInWords; |
| 142 | + uint c_coordInWords_26 = c_coordInWords_0 + 26 * strideInWords; |
| 143 | + uint c_coordInWords_27 = c_coordInWords_0 + 27 * strideInWords; |
| 144 | + uint c_coordInWords_28 = c_coordInWords_0 + 28 * strideInWords; |
| 145 | + uint c_coordInWords_29 = c_coordInWords_0 + 29 * strideInWords; |
| 146 | + uint c_coordInWords_30 = c_coordInWords_0 + 30 * strideInWords; |
| 147 | + uint c_coordInWords_31 = c_coordInWords_0 + 31 * strideInWords; |
| 148 | + |
| 149 | + if (mask > 0) { write_imagef(matrix_C, c_coordInWords_0, regC0.s0); } |
| 150 | + if (mask > 1) { write_imagef(matrix_C, c_coordInWords_1, regC0.s1); } |
| 151 | + if (mask > 2) { write_imagef(matrix_C, c_coordInWords_2, regC0.s2); } |
| 152 | + if (mask > 3) { write_imagef(matrix_C, c_coordInWords_3, regC0.s3); } |
| 153 | + if (mask > 4) { write_imagef(matrix_C, c_coordInWords_4, regC0.s4); } |
| 154 | + if (mask > 5) { write_imagef(matrix_C, c_coordInWords_5, regC0.s5); } |
| 155 | + if (mask > 6) { write_imagef(matrix_C, c_coordInWords_6, regC0.s6); } |
| 156 | + if (mask > 7) { write_imagef(matrix_C, c_coordInWords_7, regC0.s7); } |
| 157 | + if (mask > 8) { write_imagef(matrix_C, c_coordInWords_8, regC0.s8); } |
| 158 | + if (mask > 9) { write_imagef(matrix_C, c_coordInWords_9, regC0.s9); } |
| 159 | + if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC0.sa); } |
| 160 | + if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC0.sb); } |
| 161 | + if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC0.sc); } |
| 162 | + if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC0.sd); } |
| 163 | + if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC0.se); } |
| 164 | + if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC0.sf); } |
| 165 | + if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC1.s0); } |
| 166 | + if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC1.s1); } |
| 167 | + if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC1.s2); } |
| 168 | + if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC1.s3); } |
| 169 | + if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC1.s4); } |
| 170 | + if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC1.s5); } |
| 171 | + if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC1.s6); } |
| 172 | + if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC1.s7); } |
| 173 | + if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC1.s8); } |
| 174 | + if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC1.s9); } |
| 175 | + if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC1.sa); } |
| 176 | + if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC1.sb); } |
| 177 | + if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC1.sc); } |
| 178 | + if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC1.sd); } |
| 179 | + if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC1.se); } |
| 180 | + if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC1.sf); } |
| 181 | +} |
| 182 | + |
| 183 | +#define TILESIZE_K 16 |
| 184 | +#define TILESIZE_M 64 |
| 185 | +#define TILESIZE_N 32 |
| 186 | +#ifdef KQV |
| 187 | +__kernel void mul_mm_f16_f32_kqv( |
| 188 | +#else |
| 189 | +__kernel void mul_mm_f16_f32_kq( |
| 190 | +#endif |
| 191 | + __read_only image1d_buffer_t matrix_A, |
| 192 | + int offset0, |
| 193 | + __global float* matrix_B, |
| 194 | + int offset1, |
| 195 | + __write_only image1d_buffer_t matrix_C, |
| 196 | + int offsetd, |
| 197 | + int M, int K, int N, |
| 198 | + int D_A, |
| 199 | + int D_B, |
| 200 | + int nb01 |
| 201 | +) { |
| 202 | + |
| 203 | + uint block_id_m = get_global_id(1); |
| 204 | + uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N); |
| 205 | + uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N); |
| 206 | + |
| 207 | + __private float16 regA; |
| 208 | + __private float8 regB; |
| 209 | + __private float16 regC0; |
| 210 | + __private float16 regC1; |
| 211 | + |
| 212 | + const uint col = block_id_m * TILESIZE_M; |
| 213 | + const uint row = block_id_n * TILESIZE_N; |
| 214 | + const uint depth_A = block_id_d / (D_B/D_A); |
| 215 | + const uint depth_B = block_id_d; |
| 216 | + |
| 217 | +#ifdef KQV |
| 218 | + int line_stride_matrix_A_in_bytes = nb01 * M; |
| 219 | + int line_stride_matrix_B_in_bytes = K * N * 4; |
| 220 | +#else |
| 221 | + int line_stride_matrix_A_in_bytes = K * D_A * 2; |
| 222 | + int line_stride_matrix_B_in_bytes = K * D_B * 4; |
| 223 | +#endif |
| 224 | + |
| 225 | + int line_stride_matrix_C_in_bytes = M * 4; |
| 226 | + |
| 227 | + const uint strideAinElements = line_stride_matrix_A_in_bytes / 2; |
| 228 | + const uint strideBinElements = line_stride_matrix_B_in_bytes / 4; |
| 229 | + |
| 230 | + size_t sub_block_id_m = get_local_id(0); |
| 231 | + |
| 232 | + uint b_localOffsetInWords = (sub_block_id_m/16)*16 |
| 233 | + + ((((sub_block_id_m)>>0)&1)<<2) |
| 234 | + + ((((sub_block_id_m)>>1)&1)<<3) |
| 235 | + + ((((sub_block_id_m)>>2)&1)<<0) |
| 236 | + + ((((sub_block_id_m)>>3)&1)<<1); |
| 237 | + |
| 238 | + uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)}; |
| 239 | + uint b_globalOffsetInWords00, b_globalOffsetInWords16; |
| 240 | +#ifdef KQV |
| 241 | + b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K; |
| 242 | + b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K); |
| 243 | + uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2; |
| 244 | + uint subMatrixBStartInElements = depth_B * strideBinElements + row * K; |
| 245 | +#else |
| 246 | + b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4; |
| 247 | + b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4); |
| 248 | + uint subMatrixAStartInElements = col * strideAinElements + depth_A * K; |
| 249 | + uint subMatrixBStartInElements = row * strideBinElements + depth_B * K; |
| 250 | +#endif |
| 251 | + |
| 252 | + __local float matrix_B_local[1024]; |
| 253 | + |
| 254 | + for (uint step=0; step < K; step+=TILESIZE_K) { |
| 255 | + size_t sub_block_id_m = get_local_id(0); |
| 256 | + regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes); |
| 257 | + |
| 258 | + uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00; |
| 259 | + uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16; |
| 260 | + |
| 261 | + regB.s0123 = vload4(b_coordInWords00/4, matrix_B); |
| 262 | + regB.s4567 = vload4(b_coordInWords16/4, matrix_B); |
| 263 | + |
| 264 | + mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, ®C0, ®C1); |
| 265 | + |
| 266 | + subMatrixAStartInElements += TILESIZE_K; |
| 267 | + subMatrixBStartInElements += TILESIZE_K; |
| 268 | + } |
| 269 | + |
| 270 | + uint subMatrixCStartInElements = depth_B * N * M + row * M + col; |
| 271 | + mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32)); |
| 272 | +} |
| 273 | + |
0 commit comments