Commit 3a0d105

Q4/Q8 Tiled Gemm Optimization. (#16999)
1 parent 6648989 commit 3a0d105

File tree

3 files changed: +390 -125 lines changed

Lines changed: 333 additions & 0 deletions
@@ -0,0 +1,333 @@
#pragma once

typedef vector unsigned char vec_t;
typedef __vector_quad acc_t;

template <typename TA>
class tinyBLAS_Q0_PPC {
  public:
    tinyBLAS_Q0_PPC(int64_t k,
                    const TA *A, int64_t lda,
                    const block_q8_0 *B, int64_t ldb,
                    float *C, int64_t ldc,
                    int ith, int nth);

    void matmul(int64_t m, int64_t n);

    // Tiled GEMM: the output is partitioned into mc x nc tiles, the tiles are
    // split evenly across nth threads, and the K dimension is walked in chunks
    // of kc quant blocks, packing A and B before each kernel call.
    void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
        vec_t A_pack[mc*kc*2];
        vec_t B_pack[nc*kc*2];
        int comparray[mc*kc];
        constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
        int64_t ytiles = m / mc;
        int64_t xtiles = n / nc;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles) {
            end = tiles;
        }
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = (job / xtiles) * mc;
            int64_t jj = (job % xtiles) * nc;
            for (int64_t kk = 0; kk < k; kk += kc) {
                if constexpr(is_Ablock_q4) {
                    packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
                } else {
                    packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
                }
                packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
            }
        }
    }

  private:
    // Store an RM x RN block of results to C (used for the first K-chunk of a tile).
    inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
        for (int I = 0; I < RM; I++) {
            for (int J = 0; J < RN; J++) {
                *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
            }
        }
    }

    // Accumulate an RM x RN block of results into C (used for later K-chunks).
    inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
        for (int I = 0; I < RM; I++) {
            for (int J = 0; J < RN; J++) {
                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
                *c_ptr += *((float*)&fin_res[idx+I]+J);
            }
        }
    }

    template<typename ArrayType>
    inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
        vector signed int vec_C[4];
        vector float CA[4] = {0};
        vector float res[4] = {0};
        __builtin_mma_disassemble_acc(vec_C, ACC);
        for (int i = 0; i < 4; i++) {
            CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
            fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
        }
    }

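    // Note on the -128.0 * comparray term in compute(): B tiles are packed with
    // flip = true, which XORs every quantized byte with 0x80 and so biases each
    // value by +128. The integer accumulator therefore holds
    // sum(a_i * (b_i + 128)) = sum(a_i * b_i) + 128 * sum(a_i); subtracting
    // 128 * sum(a_i), the per-block sums cached in comparray, recovers the true
    // dot product before it is scaled by the combined d_A * d_B factors in vs.
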
    // Unpack one q4_0 block held in c[1]: low nibbles go to c[0], high nibbles
    // stay in c[1], both re-centred by -8, and the sum of all 32 values is
    // cached in *ca.
    inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
        const vector signed char lowMask = vec_splats((signed char)0xF);
        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
        const vector signed char v8 = vec_splats((signed char)0x8);
        vector signed int vsum = {0};
        vector signed int vsum2 = {0};
        c[0] = vec_and(c[1], lowMask);
        c[1] = vec_sr(c[1], v4);
        c[0] = vec_sub(c[0], v8);
        c[1] = vec_sub(c[1], v8);
        vsum = vec_sum4s(c[0], vsum);
        vsum2 = vec_sum4s(c[1], vsum2);
        vsum = vec_add(vsum, vsum2);
        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
    }

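    // vector_permute_store() below interleaves four 16-byte source rows so that
    // each stored vector holds the k-th 4-byte group of all four rows
    // (row0 | row1 | row2 | row3), which is the operand layout consumed by
    // __builtin_mma_xvi8ger4pp. When flip is set, the bytes are additionally
    // XORed with 0x80 before being stored.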
    template <typename V1, typename V2>
    inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
        V2 t1, t2, t3, t4, t5, t6, t7, t8;
        vector unsigned char xor_vector;
        uint8_t flip_vec = 0x80;
        xor_vector = vec_splats(flip_vec);
        t1 = vec_perm(s1, s2, swiz1);
        t2 = vec_perm(s1, s2, swiz2);
        t3 = vec_perm(s3, s4, swiz1);
        t4 = vec_perm(s3, s4, swiz2);
        t5 = vec_perm(t1, t3, swiz3);
        t6 = vec_perm(t1, t3, swiz4);
        t7 = vec_perm(t2, t4, swiz3);
        t8 = vec_perm(t2, t4, swiz4);
        if (flip == true) {
            t5 = vec_xor(t5, xor_vector);
            t6 = vec_xor(t6, xor_vector);
            t7 = vec_xor(t7, xor_vector);
            t8 = vec_xor(t8, xor_vector);
        }
        vec_xst(t5, 0, vecOffset);
        vec_xst(t6, 0, vecOffset+16);
        vec_xst(t7, 0, vecOffset+32);
        vec_xst(t8, 0, vecOffset+48);
    }

    template<int RM, int RN>
    inline void kernel(int64_t ii, int64_t jj) {
        if constexpr(RM == 4 && RN == 8) {
            KERNEL_4x8(ii,jj);
        } else if constexpr(RM == 8 && RN == 4) {
            KERNEL_8x4(ii,jj);
        } else if constexpr(RM == 8 && RN == 8) {
            KERNEL_8x8(ii,jj);
        } else {
            assert(false && "RN/RM values not supported");
        }
    }

    template<int size>
    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray);
    template<typename VA, typename VB>
    void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip);
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n);
    void KERNEL_4x8(int64_t ii, int64_t jj);
    void KERNEL_8x4(int64_t ii, int64_t jj);
    void KERNEL_8x8(int64_t ii, int64_t jj);
    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN);
    template <int RM, int RN>
    void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n);

    // Combined scale factors d_A * d_B for an 8x8 sub-tile at quant block 'blk':
    // vs[0..7] cover columns jj..jj+3, vs[8..15] cover columns jj+4..jj+7.
    void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
        for (int I = 0; I<8; I++) {
            float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
            for (int J = 0; J<4; J++) {
                *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
                *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
            }
        }
    }

    // Load one q8_0 block (32 int8 values) and cache the sum of its elements in *ca.
    inline void process_q8_elements(const int8_t *qs, int *ca) {
        vector signed char c1 = vec_xl(0, qs);
        vector signed char c2 = vec_xl(16, qs);
        vector signed int vsum1 = {0};
        vector signed int vsum2 = {0};
        vsum1 = vec_sum4s(c1, vsum1);
        vsum2 = vec_sum4s(c2, vsum2);
        vector signed int vsum = vec_add(vsum1, vsum2);
        *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
    }

    // Pack kc consecutive q8_0 blocks from 8 rows at a time into the tiled MMA
    // operand layout. flip = true biases the bytes to unsigned (used for B);
    // a non-null comparray receives the per-block element sums (used for A).
    template<typename VA, typename VB>
    void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray = nullptr) {
        int64_t i, j;
        block_q8_0 *aoffset = NULL;
        VA *vecOffset = NULL;
        block_q8_0* aoffsets[8];
        __vector_pair arr[8];
        VB c[8][2] = {0};
        VB c1[8] = {0}; VB c2[8] = {0};
        aoffset = const_cast<block_q8_0*>(a);
        vecOffset = vec;
        j = (rows >> 3);
        int index = 0;
        if (j > 0) {
            do {
                for (int it = 0; it < 8; it++)
                    aoffsets[it] = aoffset + it*lda;
                aoffset += 8 * lda;
                for (int blk = 0; blk < kc; blk++) {
                    for (int it = 0; it < 8; it++) {
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                        c1[it] = c[it][0];
                        c2[it] = c[it][1];
                        if (comparray) {
                            process_q8_elements((aoffsets[it]+blk)->qs, &comparray[index + 8*blk + it]);
                        }
                    }
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
                    vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
                    vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
                    vecOffset += 256;
                }
                j--;
                index += 8*kc;
            } while(j > 0);
        }
    }

    // q4_0 variant of packNormal_large for the A operand: nibbles are expanded
    // to signed bytes and the per-block element sums are recorded in comparray.
    void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int* comparray) {
        int64_t i, j;
        TA *aoffset = NULL;
        int8_t *vecOffset = NULL;
        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
        aoffset = const_cast<TA*>(a);
        vecOffset = vec;
        int index = 0;
        j = (rows >> 3);
        if (j > 0) {
            do {
                aoffset1 = aoffset;
                aoffset2 = aoffset1 + lda;
                aoffset3 = aoffset2 + lda;
                aoffset4 = aoffset3 + lda;
                aoffset5 = aoffset4 + lda;
                aoffset6 = aoffset5 + lda;
                aoffset7 = aoffset6 + lda;
                aoffset8 = aoffset7 + lda;
                aoffset += 8 * lda;
                for (int blk = 0; blk < kc; blk++) {
                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
                    c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
                    c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
                    c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
                    c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));

                    process_q4_elements(c1, &comparray[index + 8*blk+0]);
                    process_q4_elements(c2, &comparray[index + 8*blk+1]);
                    process_q4_elements(c3, &comparray[index + 8*blk+2]);
                    process_q4_elements(c4, &comparray[index + 8*blk+3]);
                    process_q4_elements(c5, &comparray[index + 8*blk+4]);
                    process_q4_elements(c6, &comparray[index + 8*blk+5]);
                    process_q4_elements(c7, &comparray[index + 8*blk+6]);
                    process_q4_elements(c8, &comparray[index + 8*blk+7]);
                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                    vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
                    vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
                    vecOffset += 256;
                }
                j--;
                index += 8*kc;
            } while (j > 0);
        }
    }

    // Compute one mc x nc tile of C for the kc-block K-chunk starting at block l,
    // using the packed operands. The tile is processed in 8x8 sub-tiles, with two
    // K blocks unrolled per iteration across eight MMA accumulators.
    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
        acc_t acc[8];
        for (int i = 0; i < mc; i += 8) {
            for (int j = 0; j < nc; j += 8) {
                vector float fin_res[16] = {0};
                vector float vs[16] = {0};
                for (int64_t kk = 0; kk < kc; kk += 2) {
                    for (int x = 0; x < 8; x++) {
                        __builtin_mma_xxsetaccz(&acc[x]);
                    }
                    int A_block_idx = (i/8)*(16*kc) + kk*16;
                    int B_block_idx = (j/8)*(16*kc) + kk*16;
                    vec_t *A_block = &vec_A[A_block_idx];
                    vec_t *B_block = &vec_B[B_block_idx];
                    for (int x = 0; x < 8; x++) {
                        __builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
                        __builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
                        __builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
                        __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
                    }
                    compute_scale(ii+i, jj+j, l+kk, vs);
                    int c_index = (i/8)*(8*kc) + kk*8;
                    int* c_block = &comparray[c_index];
                    compute(&acc[0], 0, 0, c_block, vs, fin_res);
                    compute(&acc[1], 4, 4, c_block, vs, fin_res);
                    compute(&acc[2], 0, 8, c_block, vs, fin_res);
                    compute(&acc[3], 4, 12, c_block, vs, fin_res);

                    A_block_idx = (i/8)*(16*kc) + (kk+1)*16;
                    B_block_idx = (j/8)*(16*kc) + (kk+1)*16;
                    A_block = &vec_A[A_block_idx];
                    B_block = &vec_B[B_block_idx];
                    for (int x = 0; x < 8; x++) {
                        __builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]);
                        __builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]);
                        __builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]);
                        __builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]);
                    }
                    compute_scale(ii+i, jj+j, l+kk+1, vs);
                    c_index = (i/8)*(8*kc) + (kk+1)*8;
                    c_block = &comparray[c_index];
                    compute(&acc[4], 0, 0, c_block, vs, fin_res);
                    compute(&acc[5], 4, 4, c_block, vs, fin_res);
                    compute(&acc[6], 0, 8, c_block, vs, fin_res);
                    compute(&acc[7], 4, 12, c_block, vs, fin_res);
                }
                // The first K-chunk overwrites C; later chunks accumulate into it.
                if (l == 0) {
                    save_res(ii+i, jj+j, 0, fin_res);
                    save_res(ii+i+4, jj+j, 4, fin_res);
                    save_res(ii+i, jj+j+4, 8, fin_res);
                    save_res(ii+i+4, jj+j+4, 12, fin_res);
                } else {
                    add_save_res(ii+i, jj+j, 0, fin_res);
                    add_save_res(ii+i+4, jj+j, 4, fin_res);
                    add_save_res(ii+i, jj+j+4, 8, fin_res);
                    add_save_res(ii+i+4, jj+j+4, 12, fin_res);
                }
            }
        }
    }

    const TA *const A;
    const block_q8_0 *const B;
    float *C;
    const int64_t k;
    int64_t kc;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
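
For reference, below is a minimal standalone sketch of the thread/tile partitioning scheme used by matmul_tiled_q0 above: output tiles are numbered row-major, split evenly across the threads, and each job index is mapped back to a tile origin (ii, jj). The dimensions, tile sizes, and thread count are illustrative values chosen for this sketch, not taken from the commit.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Standalone illustration of the thread/tile split in matmul_tiled_q0.
    // m, n, mc, nc and nth below are made-up example values.
    int main() {
        const int64_t m = 32, n = 64;   // output dimensions (rows x columns)
        const int64_t mc = 8, nc = 8;   // tile sizes along M and N
        const int     nth = 4;          // number of threads

        const int64_t ytiles = m / mc;
        const int64_t xtiles = n / nc;
        const int64_t tiles  = xtiles * ytiles;
        const int64_t duty   = (tiles + nth - 1) / nth;  // tiles per thread, rounded up

        for (int ith = 0; ith < nth; ith++) {
            const int64_t start = duty * ith;
            const int64_t end   = std::min(start + duty, tiles);
            printf("thread %d handles jobs [%lld, %lld)\n", ith, (long long)start, (long long)end);
            for (int64_t job = start; job < end; ++job) {
                const int64_t ii = (job / xtiles) * mc;  // tile origin row
                const int64_t jj = (job % xtiles) * nc;  // tile origin column
                printf("  job %lld -> C[%lld..%lld, %lld..%lld]\n",
                       (long long)job, (long long)ii, (long long)(ii + mc - 1),
                       (long long)jj, (long long)(jj + nc - 1));
            }
        }
        return 0;
    }

In the class itself the same mapping runs once per thread (ith and nth are members), and each (ii, jj) tile then walks the K dimension in kc-block chunks, packing A and B and calling KERNEL_Q0.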
