Commit 05aa63e

Author: Chip Kerchner

More MMA BF16 GEMV code.

1 parent: c9ce37d

2 files changed: +372, -53 lines

kernel/power/sbgemv_common_power10.c

Lines changed: 172 additions & 7 deletions
@@ -29,6 +29,10 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define SBGEMV_COMMON_MMA_C
 #include "sbgemv_common.c"
 
+#if defined(_AIX) || defined(__clang__)
+#define USE_MERGE_MMA
+#endif
+
 FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
 {
   vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
@@ -69,11 +73,13 @@ FORCEINLINE void vec_mult2_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
   __builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01);
 }
 
+#ifndef USE_MERGE_MMA
 FORCEINLINE void vec_mult4_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 inp)
 {
   vec_mult2_mma(out + 0, in0[0], inp);
   vec_mult2_mma(out + 2, in0[1], inp);
 }
+#endif
 
 FORCEINLINE void vec_loadN_mult11_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
 {
@@ -96,6 +102,7 @@ FORCEINLINE void vec_load_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16
   vec_mult2_mma(out, in0, inp);
 }
 
+#ifndef USE_MERGE_MMA
 FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
 {
   vec_bf16 in0[4];
@@ -106,6 +113,7 @@ FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16
   vec_mult4_mma(&out[0], in0 + 0, inp);
   vec_mult4_mma(&out[4], in0 + 2, inp);
 }
+#endif
 
 FORCEINLINE void vec_reduce1_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
 {
@@ -120,13 +128,31 @@ FORCEINLINE void vec_reduce2_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_al
   vec_reduce1_mma(&out[1], &temp[4], v_alpha, &vy0[1]);
 }
 
+#ifndef USE_MERGE_MMA
 FORCEINLINE void vec_reduce8_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
 {
   vec_reduce2_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
   vec_reduce2_mma(&out[2], &temp[8], v_alpha, vy0 + 2);
   vec_reduce2_mma(&out[4], &temp[16], v_alpha, vy0 + 4);
   vec_reduce2_mma(&out[6], &temp[24], v_alpha, vy0 + 6);
 }
+#else
+FORCEINLINE void vec_reduce44_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
+{
+  __builtin_mma_disassemble_acc((void*)temp, &out[0]);
+
+  vy0[0] += (temp[0] * v_alpha);
+  vy0[2] += (temp[1] * v_alpha);
+  vy0[4] += (temp[2] * v_alpha);
+  vy0[6] += (temp[3] * v_alpha);
+}
+
+FORCEINLINE void vec_reduce84_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
+{
+  vec_reduce44_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
+  vec_reduce44_mma(&out[1], &temp[4], v_alpha, vy0 + 1);
+}
+#endif
 
 FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
 {
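
The merge-path reduce above spills one MMA accumulator into four vec_f32 rows with __builtin_mma_disassemble_acc and then scales each row by alpha into y. The standalone sketch below is not part of the commit; it is a minimal re-statement of that pattern, assuming a POWER10 toolchain (e.g. gcc -mcpu=power10) and re-declaring the vector typedefs that the kernel normally gets from sbgemv_common.c.

#include <altivec.h>
#include <stdio.h>

typedef __vector unsigned char  vec_uc8;
typedef __vector float          vec_f32;
typedef __vector unsigned short vec_bf16;   /* raw bf16 bit patterns */

int main(void)
{
  /* 0x3F80 is 1.0 in bfloat16, so both operands are vectors of eight 1.0s. */
  vec_bf16 a = vec_splats((unsigned short)0x3F80);
  vec_bf16 x = vec_splats((unsigned short)0x3F80);

  /* Rank-2 update: each of the 4x4 f32 accumulator entries becomes a
     2-term dot product of bf16 pairs, i.e. 2.0 with these inputs. */
  __vector_quad acc;
  __builtin_mma_xvbf16ger2(&acc, (vec_uc8)a, (vec_uc8)x);

  /* Same steps as vec_reduce44_mma(): disassemble, scale by alpha, accumulate. */
  vec_f32 temp[4];
  __builtin_mma_disassemble_acc((void *)temp, &acc);

  vec_f32 v_alpha = vec_splats(0.5f);
  vec_f32 y[4] = { vec_splats(0.0f), vec_splats(0.0f),
                   vec_splats(0.0f), vec_splats(0.0f) };
  for (int i = 0; i < 4; i++) {
    y[i] += (temp[i] * v_alpha);
    printf("%f %f %f %f\n", (double)y[i][0], (double)y[i][1],
                            (double)y[i][2], (double)y[i][3]);   /* expect 1.0s */
  }
  return 0;
}
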
@@ -166,18 +192,25 @@ FORCEINLINE void vec_load_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf1
   vec_mult2a_mma(out, in0, in1, inp);
 }
 
-FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
+FORCEINLINE void vec_load4_mma(vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *ina, vec_bf16 *inb)
 {
-  vec_bf16 in0[4], in1[4];
-
   vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0));
   vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0));
   vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2));
   vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2));
+}
+
+#ifndef USE_MERGE_MMA
+FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
+{
+  vec_bf16 in0[4], in1[4];
+
+  vec_load4_mma(in0, in1, ina, inb);
 
   vec_mult4a_mma(&out[0], in0 + 0, in1 + 0, inp);
   vec_mult4a_mma(&out[4], in0 + 2, in1 + 2, inp);
 }
+#endif
 
 FORCEINLINE void vec_loadN_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
 {
@@ -209,6 +242,48 @@ FORCEINLINE void vec_mult4b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1
   vec_mult2b_mma(out + 2, in0[1], in1[1], inp);
 }
 
+#ifdef USE_MERGE_MMA
+FORCEINLINE void vec_mult1c_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
+{
+  vec_bf16 in00 = vec_mergeh(in0, in0);
+
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)inp, (vec_uc8)in00);
+}
+
+FORCEINLINE void vec_mult2c_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
+{
+  vec_bf16 in01 = vec_mergel(in0, in0);
+
+  vec_mult1c_mma(&out[0], in0, inp);
+
+  __builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)inp, (vec_uc8)in01);
+}
+
+FORCEINLINE void vec_mult44_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
+{
+  vec_mult2_mma(out, in[0], inp[0]);
+  vec_mult2c_mma(out, in[1], inp[1]);
+}
+
+FORCEINLINE void vec_mult44c_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
+{
+  vec_mult2c_mma(out, in[0], inp[0]);
+  vec_mult2c_mma(out, in[1], inp[1]);
+}
+
+FORCEINLINE void vec_mult44a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
+{
+  vec_mult2a_mma(out, in0[0], in1[0], inp[0]);
+  vec_mult2b_mma(out, in0[1], in1[1], inp[1]);
+}
+
+FORCEINLINE void vec_mult44b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
+{
+  vec_mult2b_mma(out, in0[0], in1[0], inp[0]);
+  vec_mult2b_mma(out, in0[1], in1[1], inp[1]);
+}
+#endif
+
 FORCEINLINE void vec_loadN_mult11b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
 {
   vec_bf16 in0 = vec_loadN(ina, n);
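
The new vec_mult1c_mma / vec_mult2c_mma above merge each matrix vector with itself (vec_mergeh(in0, in0) and vec_mergel(in0, in0)) before feeding __builtin_mma_xvbf16ger2pp. Merging a 16-bit vector with itself simply duplicates lanes, one half of the lanes per intrinsic. The tiny sketch below is not from the commit and needs only plain AltiVec/VSX (no MMA); it just prints that lane pattern.

#include <altivec.h>
#include <stdio.h>

typedef __vector unsigned short vec_u16;   /* stand-in for the bf16 lanes */

int main(void)
{
  vec_u16 m  = { 0, 1, 2, 3, 4, 5, 6, 7 };
  vec_u16 hi = vec_mergeh(m, m);   /* one half of the lanes, each duplicated  */
  vec_u16 lo = vec_mergel(m, m);   /* the other half of the lanes, duplicated */

  for (int i = 0; i < 8; i++) printf("%u ", (unsigned)hi[i]);   /* e.g. 0 0 1 1 2 2 3 3 */
  printf("\n");
  for (int i = 0; i < 8; i++) printf("%u ", (unsigned)lo[i]);   /* e.g. 4 4 5 5 6 6 7 7 */
  printf("\n");
  return 0;
}
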
@@ -225,18 +300,48 @@ FORCEINLINE void vec_load_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf1
   vec_mult2b_mma(out, in0, in1, inp);
 }
 
+#ifndef USE_MERGE_MMA
 FORCEINLINE void vec_load_mult28b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
 {
   vec_bf16 in0[4], in1[4];
 
-  vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0));
-  vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0));
-  vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2));
-  vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2));
+  vec_load4_mma(in0, in1, ina, inb);
 
   vec_mult4b_mma(&out[0], in0 + 0, in1 + 0, inp);
   vec_mult4b_mma(&out[4], in0 + 2, in1 + 2, inp);
 }
+#else
+FORCEINLINE void vec_load_mult184_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
+{
+  vec_bf16 in0[4];
+
+  vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0));
+  vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2));
+
+  vec_mult44_mma(out, in0 + 0, inp + 0);
+  vec_mult44c_mma(out, in0 + 2, inp + 2);
+}
+
+FORCEINLINE void vec_load_mult284a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
+{
+  vec_bf16 in0[4], in1[4];
+
+  vec_load4_mma(in0, in1, ina, inb);
+
+  vec_mult44a_mma(out, in0 + 0, in1 + 0, inp + 0);
+  vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2);
+}
+
+FORCEINLINE void vec_load_mult284b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
+{
+  vec_bf16 in0[4], in1[4];
+
+  vec_load4_mma(in0, in1, ina, inb);
+
+  vec_mult44b_mma(out, in0 + 0, in1 + 0, inp + 0);
+  vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2);
+}
+#endif
 
 FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
 {
@@ -262,4 +367,64 @@ FORCEINLINE void vec_store4_pair(vec_f32 *v_y, vec_f32 *vy0)
   vec_store_pair(v_y + 6, vy0 + 6);
 }
 
+#ifdef USE_MERGE_MMA
+FORCEINLINE void vec_load8_pair(vec_f32 *vy0, vec_f32 *v_y)
+{
+  vec_load4_pair(vy0 + 0, v_y + 0);
+  vec_load4_pair(vy0 + 8, v_y + 8);
+}
+
+FORCEINLINE void vec_store8_pair(vec_f32 *v_y, vec_f32 *vy0)
+{
+  vec_store4_pair(v_y + 0, vy0 + 0);
+  vec_store4_pair(v_y + 8, vy0 + 8);
+}
+
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+#define VEC_SHIFT(data, shift) vec_sld(data, data, 16 - shift)
+#else
+#define VEC_SHIFT(data, shift) vec_sld(data, data, shift)
+#endif
+
+typedef __vector unsigned int vec_ui32;
+
+static vec_ui32 mask_0 = { 0xffffffff, 0x00000000, 0x00000000, 0x00000000 };
+static vec_ui32 mask_1 = { 0x00000000, 0xffffffff, 0x00000000, 0x00000000 };
+static vec_ui32 mask_2 = { 0x00000000, 0x00000000, 0xffffffff, 0x00000000 };
+static vec_ui32 mask_3 = { 0x00000000, 0x00000000, 0x00000000, 0xffffffff };
+
+FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0)
+{
+  v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)mask_0);
+
+  v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 4);
+  v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 8);
+  v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 12);
+}
+
+FORCEINLINE void vec_make_mult2(vec_bf16 *v_x0)
+{
+  v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)mask_1);
+  vec_make_mult1(v_x0);
+
+  v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 12);
+  v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 4);
+  v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 8);
+}
+
+FORCEINLINE void vec_make_mult4(vec_bf16 *v_x0)
+{
+  v_x0[10] = vec_and(v_x0[0], (vec_bf16)mask_2);
+  v_x0[15] = vec_and(v_x0[0], (vec_bf16)mask_3);
+  vec_make_mult2(v_x0);
+
+  v_x0[ 8] = VEC_SHIFT(v_x0[10], 8);
+  v_x0[ 9] = VEC_SHIFT(v_x0[10], 12);
+  v_x0[11] = VEC_SHIFT(v_x0[10], 4);
+  v_x0[12] = VEC_SHIFT(v_x0[15], 4);
+  v_x0[13] = VEC_SHIFT(v_x0[15], 8);
+  v_x0[14] = VEC_SHIFT(v_x0[15], 12);
+}
+#endif
+
 #endif
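
vec_make_mult1() above keeps a single 32-bit pair of bf16 values (mask_0) and rotates it into each of the four word slots with VEC_SHIFT, so v_x0[0..3] end up holding the same pair at four different offsets with zeros elsewhere; vec_make_mult2()/vec_make_mult4() repeat the idea for the remaining pairs, filling slots 4..7 and 8..15. The hypothetical driver below is not part of the commit; it re-declares the mask and VEC_SHIFT locally (VSX only, no MMA) to make the effect visible.

#include <altivec.h>
#include <stdio.h>

typedef __vector unsigned int   vec_ui32;
typedef __vector unsigned short vec_bf16;   /* raw bf16 bit patterns */

#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define VEC_SHIFT(data, shift) vec_sld(data, data, 16 - shift)
#else
#define VEC_SHIFT(data, shift) vec_sld(data, data, shift)
#endif

int main(void)
{
  const vec_ui32 mask_0 = { 0xffffffff, 0x00000000, 0x00000000, 0x00000000 };
  vec_bf16 v_x0[4];

  v_x0[0] = (vec_bf16){ 0x1111, 0x2222, 0x3333, 0x4444,
                        0x5555, 0x6666, 0x7777, 0x8888 };

  /* Same steps as vec_make_mult1(): keep one 32-bit pair, rotate it around. */
  v_x0[0] = vec_and(v_x0[0], (vec_bf16)mask_0);
  v_x0[1] = VEC_SHIFT(v_x0[0], 4);
  v_x0[2] = VEC_SHIFT(v_x0[0], 8);
  v_x0[3] = VEC_SHIFT(v_x0[0], 12);

  /* The same non-zero word should show up in a different slot of each vector. */
  for (int i = 0; i < 4; i++) {
    vec_ui32 w = (vec_ui32)v_x0[i];
    printf("v_x0[%d] = %08x %08x %08x %08x\n", i, w[0], w[1], w[2], w[3]);
  }
  return 0;
}
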
