@@ -29,6 +29,10 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define SBGEMV_COMMON_MMA_C
 #include "sbgemv_common.c"
 
+#if defined(_AIX) || defined(__clang__)
+#define USE_MERGE_MMA
+#endif
+
 FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
 {
   vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
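
Note (not part of the patch): the new USE_MERGE_MMA gate, enabled here for AIX and clang builds, selects between two sets of helpers in the hunks below — the existing vec_mult4 / vec_load_mult18 / vec_reduce8 / vec_load_mult28 path when it is undefined, and the new merged vec_*44* / vec_load_mult184 / vec_load_mult284 / vec_make_mult* path when it is defined — so each build compiles exactly one of the two variants.
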
@@ -69,11 +73,13 @@ FORCEINLINE void vec_mult2_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
   __builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01);
 }
 
+#ifndef USE_MERGE_MMA
 FORCEINLINE void vec_mult4_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 inp)
 {
   vec_mult2_mma(out + 0, in0[0], inp);
   vec_mult2_mma(out + 2, in0[1], inp);
 }
+#endif
 
 FORCEINLINE void vec_loadN_mult11_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
 {
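
Note (illustration, not part of the patch): the helpers above all feed __builtin_mma_xvbf16ger2 / __builtin_mma_xvbf16ger2pp, which accumulate a 4x4 single-precision tile from two vectors that each hold four pairs of bf16 values. A rough scalar model of one `pp` step, under an assumed index mapping taken from the rank-2 GER description (the exact lane ordering is endian-dependent and not spelled out in this diff):

    /* Illustrative scalar model of one xvbf16ger2pp step (assumed indexing,
     * not the builtin's authoritative definition): each operand is read as
     * four bf16 pairs widened to float, and tile entry (i, j) gains the sum
     * of the two per-pair products. */
    static void bf16ger2pp_model(float acc[4][4],
                                 const float a[4][2],   /* pairs from the first operand  */
                                 const float b[4][2])   /* pairs from the second operand */
    {
      for (int i = 0; i < 4; i++)
        for (int j = 0; j < 4; j++)
          acc[i][j] += a[i][0] * b[j][0] + a[i][1] * b[j][1];
    }
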
@@ -96,6 +102,7 @@ FORCEINLINE void vec_load_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16
   vec_mult2_mma(out, in0, inp);
 }
 
+#ifndef USE_MERGE_MMA
 FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
 {
   vec_bf16 in0[4];
@@ -106,6 +113,7 @@ FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16
   vec_mult4_mma(&out[0], in0 + 0, inp);
   vec_mult4_mma(&out[4], in0 + 2, inp);
 }
+#endif
 
 FORCEINLINE void vec_reduce1_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
 {
@@ -120,13 +128,31 @@ FORCEINLINE void vec_reduce2_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_al
   vec_reduce1_mma(&out[1], &temp[4], v_alpha, &vy0[1]);
 }
 
+#ifndef USE_MERGE_MMA
 FORCEINLINE void vec_reduce8_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
 {
   vec_reduce2_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
   vec_reduce2_mma(&out[2], &temp[8], v_alpha, vy0 + 2);
   vec_reduce2_mma(&out[4], &temp[16], v_alpha, vy0 + 4);
   vec_reduce2_mma(&out[6], &temp[24], v_alpha, vy0 + 6);
 }
+#else
+FORCEINLINE void vec_reduce44_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
+{
+  __builtin_mma_disassemble_acc((void *)temp, &out[0]);
+
+  vy0[0] += (temp[0] * v_alpha);
+  vy0[2] += (temp[1] * v_alpha);
+  vy0[4] += (temp[2] * v_alpha);
+  vy0[6] += (temp[3] * v_alpha);
+}
+
+FORCEINLINE void vec_reduce84_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
+{
+  vec_reduce44_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
+  vec_reduce44_mma(&out[1], &temp[4], v_alpha, vy0 + 1);
+}
+#endif
 
 FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
 {
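
Note (illustration, not part of the patch): the merged reduction writes its results with a stride of two vectors and interleaves two accumulators, whereas vec_reduce8_mma above appears to fill consecutive vy0 slots. A plain-C sketch of the layout, with hypothetical names:

    /* vec_reduce44_mma in scalar form: one accumulator is disassembled into
     * four float4 rows; row r is scaled by alpha and added to vy0[2*r], so a
     * single call touches vy0[0], vy0[2], vy0[4], vy0[6].  vec_reduce84_mma
     * then applies this to two accumulators with vy0 + 0 and vy0 + 1. */
    typedef struct { float f[4]; } f32x4_model;

    static void reduce44_model(const f32x4_model rows[4], float alpha, f32x4_model vy0[8])
    {
      for (int r = 0; r < 4; r++)
        for (int k = 0; k < 4; k++)
          vy0[2 * r].f[k] += rows[r].f[k] * alpha;
    }
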
@@ -166,18 +192,25 @@ FORCEINLINE void vec_load_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf1
   vec_mult2a_mma(out, in0, in1, inp);
 }
 
-FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
+FORCEINLINE void vec_load4_mma(vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *ina, vec_bf16 *inb)
 {
-  vec_bf16 in0[4], in1[4];
-
   vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0));
   vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0));
   vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2));
   vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2));
+}
+
+#ifndef USE_MERGE_MMA
+FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
+{
+  vec_bf16 in0[4], in1[4];
+
+  vec_load4_mma(in0, in1, ina, inb);
 
   vec_mult4a_mma(&out[0], in0 + 0, in1 + 0, inp);
   vec_mult4a_mma(&out[4], in0 + 2, in1 + 2, inp);
 }
+#endif
 
 FORCEINLINE void vec_loadN_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
 {
@@ -209,6 +242,48 @@ FORCEINLINE void vec_mult4b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1
   vec_mult2b_mma(out + 2, in0[1], in1[1], inp);
 }
 
+#ifdef USE_MERGE_MMA
+FORCEINLINE void vec_mult1c_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
+{
+  vec_bf16 in00 = vec_mergeh(in0, in0);
+
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)inp, (vec_uc8)in00);
+}
+
+FORCEINLINE void vec_mult2c_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
+{
+  vec_bf16 in01 = vec_mergel(in0, in0);
+
+  vec_mult1c_mma(&out[0], in0, inp);
+
+  __builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)inp, (vec_uc8)in01);
+}
+
+FORCEINLINE void vec_mult44_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
+{
+  vec_mult2_mma(out, in[0], inp[0]);
+  vec_mult2c_mma(out, in[1], inp[1]);
+}
+
+FORCEINLINE void vec_mult44c_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
+{
+  vec_mult2c_mma(out, in[0], inp[0]);
+  vec_mult2c_mma(out, in[1], inp[1]);
+}
+
+FORCEINLINE void vec_mult44a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
+{
+  vec_mult2a_mma(out, in0[0], in1[0], inp[0]);
+  vec_mult2b_mma(out, in0[1], in1[1], inp[1]);
+}
+
+FORCEINLINE void vec_mult44b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
+{
+  vec_mult2b_mma(out, in0[0], in1[0], inp[0]);
+  vec_mult2b_mma(out, in0[1], in1[1], inp[1]);
+}
+#endif
+
 FORCEINLINE void vec_loadN_mult11b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
 {
   vec_bf16 in0 = vec_loadN(ina, n);
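
Note (illustration, not part of the patch): vec_mult1c_mma and vec_mult2c_mma merge a vector with itself via vec_mergeh(in0, in0) / vec_mergel(in0, in0), which simply duplicates each bf16 element within its 32-bit pair (first half of the vector in the mergeh case, second half in the mergel case). A scalar picture, with hypothetical names:

    /* Self-merge of an 8-element bf16 vector, modelled on uint16_t lanes.
     * Element order is written in big-endian convention; the intrinsics
     * take care of the endian details on LE targets. */
    #include <stdint.h>

    static void merge_self_model(const uint16_t x[8], uint16_t hi[8], uint16_t lo[8])
    {
      for (int i = 0; i < 4; i++) {
        hi[2 * i]     = x[i];        /* vec_mergeh(x, x): x0 x0 x1 x1 x2 x2 x3 x3 */
        hi[2 * i + 1] = x[i];
        lo[2 * i]     = x[i + 4];    /* vec_mergel(x, x): x4 x4 x5 x5 x6 x6 x7 x7 */
        lo[2 * i + 1] = x[i + 4];
      }
    }
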
@@ -225,18 +300,48 @@ FORCEINLINE void vec_load_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf1
   vec_mult2b_mma(out, in0, in1, inp);
 }
 
+#ifndef USE_MERGE_MMA
 FORCEINLINE void vec_load_mult28b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
 {
   vec_bf16 in0[4], in1[4];
 
-  vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0));
-  vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0));
-  vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2));
-  vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2));
+  vec_load4_mma(in0, in1, ina, inb);
 
   vec_mult4b_mma(&out[0], in0 + 0, in1 + 0, inp);
   vec_mult4b_mma(&out[4], in0 + 2, in1 + 2, inp);
 }
+#else
+FORCEINLINE void vec_load_mult184_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
+{
+  vec_bf16 in0[4];
+
+  vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0));
+  vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2));
+
+  vec_mult44_mma(out, in0 + 0, inp + 0);
+  vec_mult44c_mma(out, in0 + 2, inp + 2);
+}
+
+FORCEINLINE void vec_load_mult284a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
+{
+  vec_bf16 in0[4], in1[4];
+
+  vec_load4_mma(in0, in1, ina, inb);
+
+  vec_mult44a_mma(out, in0 + 0, in1 + 0, inp + 0);
+  vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2);
+}
+
+FORCEINLINE void vec_load_mult284b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp)
+{
+  vec_bf16 in0[4], in1[4];
+
+  vec_load4_mma(in0, in1, ina, inb);
+
+  vec_mult44b_mma(out, in0 + 0, in1 + 0, inp + 0);
+  vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2);
+}
+#endif
 
 FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
 {
@@ -262,4 +367,64 @@ FORCEINLINE void vec_store4_pair(vec_f32 *v_y, vec_f32 *vy0)
   vec_store_pair(v_y + 6, vy0 + 6);
 }
 
+#ifdef USE_MERGE_MMA
+FORCEINLINE void vec_load8_pair(vec_f32 *vy0, vec_f32 *v_y)
+{
+  vec_load4_pair(vy0 + 0, v_y + 0);
+  vec_load4_pair(vy0 + 8, v_y + 8);
+}
+
+FORCEINLINE void vec_store8_pair(vec_f32 *v_y, vec_f32 *vy0)
+{
+  vec_store4_pair(v_y + 0, vy0 + 0);
+  vec_store4_pair(v_y + 8, vy0 + 8);
+}
+
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+#define VEC_SHIFT(data, shift)  vec_sld(data, data, 16 - shift)
+#else
+#define VEC_SHIFT(data, shift)  vec_sld(data, data, shift)
+#endif
+
+typedef __vector unsigned int vec_ui32;
+
+static vec_ui32 mask_0 = { 0xffffffff, 0x00000000, 0x00000000, 0x00000000 };
+static vec_ui32 mask_1 = { 0x00000000, 0xffffffff, 0x00000000, 0x00000000 };
+static vec_ui32 mask_2 = { 0x00000000, 0x00000000, 0xffffffff, 0x00000000 };
+static vec_ui32 mask_3 = { 0x00000000, 0x00000000, 0x00000000, 0xffffffff };
+
+FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0)
+{
+  v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)mask_0);
+
+  v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 4);
+  v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 8);
+  v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 12);
+}
+
+FORCEINLINE void vec_make_mult2(vec_bf16 *v_x0)
+{
+  v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)mask_1);
+  vec_make_mult1(v_x0);
+
+  v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 12);
+  v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 4);
+  v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 8);
+}
+
+FORCEINLINE void vec_make_mult4(vec_bf16 *v_x0)
+{
+  v_x0[10] = vec_and(v_x0[0], (vec_bf16)mask_2);
+  v_x0[15] = vec_and(v_x0[0], (vec_bf16)mask_3);
+  vec_make_mult2(v_x0);
+
+  v_x0[ 8] = VEC_SHIFT(v_x0[10], 8);
+  v_x0[ 9] = VEC_SHIFT(v_x0[10], 12);
+  v_x0[11] = VEC_SHIFT(v_x0[10], 4);
+  v_x0[12] = VEC_SHIFT(v_x0[15], 4);
+  v_x0[13] = VEC_SHIFT(v_x0[15], 8);
+  v_x0[14] = VEC_SHIFT(v_x0[15], 12);
+}
+#endif
+
 #endif
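
Note (illustration, not part of the patch): the vec_make_mult* helpers expand one vector of eight bf16 values (four 32-bit pairs) into 16 vectors: each mask_k keeps exactly one pair, and VEC_SHIFT rotates that pair into the other lanes, with the BE/LE macro variants keeping the rotation direction consistent. The index pattern suggests the invariant sketched below; this is a model under that assumption, not code from the patch.

    #include <stdint.h>

    typedef struct { uint32_t w[4]; } v128_model;

    /* Assumed net effect of vec_make_mult4(v_x0): entry v_x0[4*k + j] holds
     * 32-bit pair k of the original vector placed in lane j, with every other
     * lane zero.  This matches the mask_k/VEC_SHIFT pattern if one VEC_SHIFT
     * step of 4 bytes moves a pair by one lane. */
    static void make_mult4_model(v128_model out[16], v128_model x)
    {
      for (int k = 0; k < 4; k++)       /* which source pair      */
        for (int j = 0; j < 4; j++) {   /* which destination lane */
          v128_model r = { { 0, 0, 0, 0 } };
          r.w[j] = x.w[k];
          out[4 * k + j] = r;
        }
    }
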