#pragma once

// Standard headers used below; the ggml block types (block_q4_0, block_q8_0) and
// unhalf() are expected to be provided by the including translation unit.
#include <altivec.h>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>

typedef vector unsigned char vec_t;
typedef __vector_quad acc_t;

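// tinyBLAS_Q0_PPC: matrix multiplication for quantized operands on PowerPC using
// MMA (ISA 3.1) outer-product accumulators. A is block_q4_0 or block_q8_0 (TA),
// B is block_q8_0, and C is written as float. Each instance computes the share of
// output tiles assigned to thread `ith` out of `nth`.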
template <typename TA>
class tinyBLAS_Q0_PPC {
  public:
    tinyBLAS_Q0_PPC(int64_t k,
                    const TA *A, int64_t lda,
                    const block_q8_0 *B, int64_t ldb,
                    float *C, int64_t ldc,
                    int ith, int nth);

    void matmul(int64_t m, int64_t n);
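    // Cache-blocked path: the m x n output is split into mc x nc tiles and the tile
    // list is divided evenly across nth threads; each tile walks k in chunks of kc
    // blocks. Per chunk, the A and B quants are repacked into an MMA-friendly layout
    // (A_pack, B_pack) and comparray collects per-row quant sums of A, used later to
    // undo the 0x80 flip applied to B. Note: the packed buffers are stack VLAs (a GCC
    // extension), so mc/nc/kc are expected to stay small; remainder tiles
    // (m % mc, n % nc) are not handled by this path.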
    void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
        vec_t A_pack[mc*kc*2];
        vec_t B_pack[nc*kc*2];
        int comparray[mc*kc];
        constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
        int64_t ytiles = m / mc;
        int64_t xtiles = n / nc;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles) {
            end = tiles;
        }
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = (job / xtiles) * mc;
            int64_t jj = (job % xtiles) * nc;
            for (int64_t kk = 0; kk < k; kk += kc) {
                if constexpr(is_Ablock_q4) {
                    packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
                } else {
                    packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
                }
                packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
            }
        }
    }

  private:
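    // Write an RM x RN block of fin_res to C at (ii, jj): save_res overwrites,
    // add_save_res accumulates onto results stored by an earlier k-chunk.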
    inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
        for (int I = 0; I < RM; I++) {
            for (int J = 0; J < RN; J++) {
                *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
            }
        }
    }

    inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
        for (int I = 0; I < RM; I++) {
            for (int J = 0; J < RN; J++) {
                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
                *c_ptr += *((float*)&fin_res[idx+I]+J);
            }
        }
    }

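    // Convert one disassembled MMA accumulator to float and fold it into fin_res:
    // each accumulator row gets a -128 * (row sum of A quants) correction for the
    // 0x80 bias applied to B during packing, then is scaled by the combined block
    // scales in vs.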
    template<typename ArrayType>
    inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
        vector signed int vec_C[4];
        vector float CA[4] = {0};
        vector float res[4] = {0};
        __builtin_mma_disassemble_acc(vec_C, ACC);
        for (int i = 0; i < 4; i++) {
            CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
            fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
        }
    }

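    // Split a q4_0 block held in c[1] into low/high nibbles (recentred by -8) and
    // store the sum of all 32 quants in *ca for the -128 correction in compute().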
    inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
        const vector signed char lowMask = vec_splats((signed char)0xF);
        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
        const vector signed char v8 = vec_splats((signed char)0x8);
        vector signed int vsum = {0};
        vector signed int vsum2 = {0};
        c[0] = vec_and(c[1], lowMask);
        c[1] = vec_sr(c[1], v4);
        c[0] = vec_sub(c[0], v8);
        c[1] = vec_sub(c[1], v8);
        vsum = vec_sum4s(c[0], vsum);
        vsum2 = vec_sum4s(c[1], vsum2);
        vsum = vec_add(vsum, vsum2);
        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
    }

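    // Interleave four 16-byte vectors into the 4x4 word-transposed layout expected
    // by the MMA kernels and store 64 bytes at vecOffset; with flip set, XOR 0x80
    // first to re-bias signed q8 values into unsigned range.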
    template <typename V1, typename V2>
    inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
        V2 t1, t2, t3, t4, t5, t6, t7, t8;
        vector unsigned char xor_vector;
        uint8_t flip_vec = 0x80;
        xor_vector = vec_splats(flip_vec);
        t1 = vec_perm(s1, s2, swiz1);
        t2 = vec_perm(s1, s2, swiz2);
        t3 = vec_perm(s3, s4, swiz1);
        t4 = vec_perm(s3, s4, swiz2);
        t5 = vec_perm(t1, t3, swiz3);
        t6 = vec_perm(t1, t3, swiz4);
        t7 = vec_perm(t2, t4, swiz3);
        t8 = vec_perm(t2, t4, swiz4);
        if (flip) {
            t5 = vec_xor(t5, xor_vector);
            t6 = vec_xor(t6, xor_vector);
            t7 = vec_xor(t7, xor_vector);
            t8 = vec_xor(t8, xor_vector);
        }
        vec_xst(t5, 0, vecOffset);
        vec_xst(t6, 0, vecOffset+16);
        vec_xst(t7, 0, vecOffset+32);
        vec_xst(t8, 0, vecOffset+48);
    }

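    // Compile-time dispatch to the fixed-size kernels of the non-tiled path.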
    template<int RM, int RN>
    inline void kernel(int64_t ii, int64_t jj) {
        if constexpr(RM == 4 && RN == 8) {
            KERNEL_4x8(ii,jj);
        } else if constexpr(RM == 8 && RN == 4) {
            KERNEL_8x4(ii,jj);
        } else if constexpr(RM == 8 && RN == 8) {
            KERNEL_8x8(ii,jj);
        } else {
            assert(false && "RN/RM values not supported");
        }
    }

    template<int size>
    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray);
    template<typename VA, typename VB>
    void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip);
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n);
    void KERNEL_4x8(int64_t ii, int64_t jj);
    void KERNEL_8x4(int64_t ii, int64_t jj);
    void KERNEL_8x8(int64_t ii, int64_t jj);
    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN);
    template <int RM, int RN>
    void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n);

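    // Build the combined scales d(A row) * d(B col) for an 8x8 tile at block blk:
    // vs[0..7] holds columns jj..jj+3, vs[8..15] holds columns jj+4..jj+7.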
    void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs) {
        for (int I = 0; I < 8; I++) {
            float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
            for (int J = 0; J < 4; J++) {
                *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
                *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
            }
        }
    }

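    // Sum the 32 quants of a q8_0 block into *ca; used when packing a q8_0 A operand
    // to build the same correction array process_q4_elements produces for q4_0.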
    inline void process_q8_elements(const int8_t *qs, int *ca) {
        vector signed char c1 = vec_xl(0, qs);
        vector signed char c2 = vec_xl(16, qs);
        vector signed int vsum1 = {0};
        vector signed int vsum2 = {0};
        vsum1 = vec_sum4s(c1, vsum1);
        vsum2 = vec_sum4s(c2, vsum2);
        vector signed int vsum = vec_add(vsum1, vsum2);
        *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
    }

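    // Pack an 8-row strip of q8_0 blocks, kc blocks deep, into the interleaved layout
    // consumed by KERNEL_Q0 (256 bytes per 8 rows per block). Used for B with
    // flip=true, and for a q8_0 A operand with flip=false plus per-row block sums
    // collected in comparray.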
    template<typename VA, typename VB>
    void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
        int64_t j;
        block_q8_0 *aoffset = NULL;
        VA *vecOffset = NULL;
        block_q8_0* aoffsets[8];
        __vector_pair arr[8];
        VB c[8][2] = {0};
        VB c1[8] = {0};
        VB c2[8] = {0};
        aoffset = const_cast<block_q8_0*>(a);
        vecOffset = vec;
        j = (rows >> 3);
        int index = 0;
        if (j > 0) {
            do {
                for (int it = 0; it < 8; it++)
                    aoffsets[it] = aoffset + it*lda;
                aoffset += 8 * lda;
                for (int blk = 0; blk < kc; blk++) {
                    for (int it = 0; it < 8; it++) {
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                        c1[it] = c[it][0];
                        c2[it] = c[it][1];
                        if (comparray) {
                            process_q8_elements((aoffsets[it]+blk)->qs, &comparray[index + 8*blk + it]);
                        }
                    }
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
                    vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
                    vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
                    vecOffset += 256;
                }
                j--;
                index += 8*kc;
            } while (j > 0);
        }
    }

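    // q4_0 counterpart of packNormal_large: loads 8 rows of q4_0 blocks, expands the
    // nibbles (filling comparray via process_q4_elements) and stores both halves
    // through vector_permute_store in the same interleaved layout.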
    void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int* comparray) {
        int64_t j;
        TA *aoffset = NULL;
        int8_t *vecOffset = NULL;
        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
        aoffset = const_cast<TA*>(a);
        vecOffset = vec;
        int index = 0;
        j = (rows >> 3);
        if (j > 0) {
            do {
                aoffset1 = aoffset;
                aoffset2 = aoffset1 + lda;
                aoffset3 = aoffset2 + lda;
                aoffset4 = aoffset3 + lda;
                aoffset5 = aoffset4 + lda;
                aoffset6 = aoffset5 + lda;
                aoffset7 = aoffset6 + lda;
                aoffset8 = aoffset7 + lda;
                aoffset += 8 * lda;
                for (int blk = 0; blk < kc; blk++) {
                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
                    c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
                    c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
                    c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
                    c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));

                    process_q4_elements(c1, &comparray[index + 8*blk+0]);
                    process_q4_elements(c2, &comparray[index + 8*blk+1]);
                    process_q4_elements(c3, &comparray[index + 8*blk+2]);
                    process_q4_elements(c4, &comparray[index + 8*blk+3]);
                    process_q4_elements(c5, &comparray[index + 8*blk+4]);
                    process_q4_elements(c6, &comparray[index + 8*blk+5]);
                    process_q4_elements(c7, &comparray[index + 8*blk+6]);
                    process_q4_elements(c8, &comparray[index + 8*blk+7]);
                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                    vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
                    vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
                    vecOffset += 256;
                }
                j--;
                index += 8*kc;
            } while (j > 0);
        }
    }

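    // Micro-kernel for one mc x nc tile: for every 8x8 sub-tile it walks the kc packed
    // blocks two at a time, accumulates xvi8ger4pp outer products into MMA
    // accumulators, then rescales them via compute_scale()/compute(). The first
    // k-chunk (l == 0) overwrites C, later chunks accumulate into it.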
    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
        acc_t acc[8];
        for (int i = 0; i < mc; i += 8) {
            for (int j = 0; j < nc; j += 8) {
                vector float fin_res[16] = {0};
                vector float vs[16] = {0};
                for (int64_t kk = 0; kk < kc; kk += 2) {
                    for (int x = 0; x < 8; x++) {
                        __builtin_mma_xxsetaccz(&acc[x]);
                    }
                    int A_block_idx = (i/8)*(16*kc) + kk*16;
                    int B_block_idx = (j/8)*(16*kc) + kk*16;
                    vec_t *A_block = &vec_A[A_block_idx];
                    vec_t *B_block = &vec_B[B_block_idx];
                    for (int x = 0; x < 8; x++) {
                        __builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
                        __builtin_mma_xvi8ger4pp(&acc[1], A_block[x+8], B_block[x]);
                        __builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
                        __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
                    }
                    compute_scale(ii+i, jj+j, l+kk, vs);
                    int c_index = (i/8)*(8*kc) + kk*8;
                    int* c_block = &comparray[c_index];
                    compute(&acc[0], 0, 0, c_block, vs, fin_res);
                    compute(&acc[1], 4, 4, c_block, vs, fin_res);
                    compute(&acc[2], 0, 8, c_block, vs, fin_res);
                    compute(&acc[3], 4, 12, c_block, vs, fin_res);

                    A_block_idx = (i/8)*(16*kc) + (kk+1)*16;
                    B_block_idx = (j/8)*(16*kc) + (kk+1)*16;
                    A_block = &vec_A[A_block_idx];
                    B_block = &vec_B[B_block_idx];
                    for (int x = 0; x < 8; x++) {
                        __builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]);
                        __builtin_mma_xvi8ger4pp(&acc[5], A_block[x+8], B_block[x]);
                        __builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]);
                        __builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]);
                    }
                    compute_scale(ii+i, jj+j, l+kk+1, vs);
                    c_index = (i/8)*(8*kc) + (kk+1)*8;
                    c_block = &comparray[c_index];
                    compute(&acc[4], 0, 0, c_block, vs, fin_res);
                    compute(&acc[5], 4, 4, c_block, vs, fin_res);
                    compute(&acc[6], 0, 8, c_block, vs, fin_res);
                    compute(&acc[7], 4, 12, c_block, vs, fin_res);
                }
                if (l == 0) {
                    save_res(ii+i, jj+j, 0, fin_res);
                    save_res(ii+i+4, jj+j, 4, fin_res);
                    save_res(ii+i, jj+j+4, 8, fin_res);
                    save_res(ii+i+4, jj+j+4, 12, fin_res);
                } else {
                    add_save_res(ii+i, jj+j, 0, fin_res);
                    add_save_res(ii+i+4, jj+j, 4, fin_res);
                    add_save_res(ii+i, jj+j+4, 8, fin_res);
                    add_save_res(ii+i+4, jj+j+4, 12, fin_res);
                }
            }
        }
    }

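    // Operands and leading dimensions (lda/ldb in blocks, ldc in elements), depth k
    // in blocks, k-chunk size kc, and this worker's thread id/count.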
    const TA *const A;
    const block_q8_0 *const B;
    float *C;
    const int64_t k;
    int64_t kc;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
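
// Minimal usage sketch (illustrative only; assumes A holds q4_0 blocks, B holds
// q8_0 blocks, and k/lda/ldb are given in blocks):
//
//   tinyBLAS_Q0_PPC<block_q4_0> tb(k, A, lda, B, ldb, C, ldc, ith, nth);
//   tb.matmul(m, n);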