@@ -46,6 +46,7 @@ SPDX-License-Identifier: MIT
 #define _PackedB_PackedB 3
 #define _Accumulator_RowMajor 4
 #define _Accumulator_ColumnMajor 5
+#define _PackedA_ColumnMajor 6
 
 #define ATTRIBUTE_AS_GENERIC __global /* the branch using this will be dead,
                                          however we still need a valid address
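These numeric layout tokens exist so the load macros can compare a pasted layout name against a specific layout at compile time; the `_PackedA_ColumnMajor` value added here is what the new 16-bit column-major branch further down keys on. A minimal illustration (values from the defines above):

    /* `_##layout` pastes a leading underscore onto the layout parameter, so
     * with layout = PackedA_ColumnMajor the branch condition
     *     if (elem_bitwidth == 16 && _##layout == _PackedA_ColumnMajor)
     * expands to a constant 6 == 6 comparison the compiler folds away. */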
@@ -179,9 +180,13 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
 #define MATH_8_DIV_4 2
 #define MATH_8_DIV_2 4
 #define MATH_8_DIV_1 8
+#define MATH_7_DIV_1 7
+#define MATH_6_DIV_1 6
+#define MATH_5_DIV_1 5
 #define MATH_4_DIV_4 1
 #define MATH_4_DIV_2 2
 #define MATH_4_DIV_1 4
+#define MATH_3_DIV_1 3
 #define MATH_2_DIV_2 1
 #define MATH_2_DIV_1 2
 #define MATH_1_DIV_1 1
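The preprocessor cannot do arithmetic whose result must become part of an identifier, so division is spelled out as a lookup table; the new MATH_7/6/5/3_DIV_1 entries cover the additional row counts introduced by this patch. A plausible sketch of the dispatch (the real MATH_DIV/MATH_MUL definitions live elsewhere in this file and may differ):

    /* Two-level paste so the arguments are expanded first (sketch): */
    #define MATH_DIV__(a, b) MATH_##a##_DIV_##b
    #define MATH_DIV(a, b)   MATH_DIV__(a, b)
    /* MATH_DIV(8, 2) -> MATH_8_DIV_2 -> 4, a plain integer token that can
     * itself be pasted into vector-type or builtin names. */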
@@ -225,6 +230,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
 #define SHAPE_CONCAT_VNNI(M, K, vnni_factor) SHAPE_CONCAT_VNNI__(MATH_MUL(M, vnni_factor), MATH_DIV(K, vnni_factor))
 
 #define SHAPE_PackedA_RowMajor(M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT(M, K)
+#define SHAPE_PackedA_ColumnMajor(M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT(M, K)
 #define SHAPE_PackedB_RowMajor(M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
 #define SHAPE_PackedB_ColumnMajor(M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
 #define SHAPE_PackedB_PackedB(M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
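The new PackedA_ColumnMajor shape is plain M x K, unlike the PackedB shapes, which undo the VNNI transform in the name. A worked expansion, assuming SHAPE_CONCAT_VNNI__ pastes its arguments as <rows>x<cols> the same way SHAPE_CONCAT does:

    /* SHAPE_PackedB_RowMajor(8, 16, 16, 32)
     *   -> SHAPE_CONCAT_VNNI(8, 16, MATH_DIV(32, 16))          // vnni_factor = 2
     *   -> SHAPE_CONCAT_VNNI__(MATH_MUL(8, 2), MATH_DIV(16, 2))
     *   -> 16x8
     * SHAPE_PackedA_ColumnMajor(8, 16, 16, 32) -> 8x16, unchanged. */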
@@ -407,13 +413,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
 
 /* For platforms without SG16 JointMatrix support block2d is not available. The
  * implementation remains empty, will fallthrough to vector implementation. */
-#define IMPLEMENT_BLOCK2D_LOAD_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_ROW_MAJOR_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
   /* not supported, fallthrough */
-#define IMPLEMENT_BLOCK2D_LOAD_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_COL_MAJOR_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
   /* not supported, fallthrough */
-#define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
   /* not supported, fallthrough */
-#define IMPLEMENT_BLOCK2D_STORE(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
+#define IMPLEMENT_BLOCK2D_STORE(p1, p2, p3, p4, p5, p6, p7, p8) \
   /* not supported, fallthrough */
 
 // contrib_K - calculated in BLOCK2D loads; contrib_K = K/(contrib_bitwidth/elem_bitwidth);
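A quick numeric check of that formula: with 8-bit elements packed into a 32-bit contrib type, K = 32 gives contrib_K = 32 / (32/8) = 8, i.e. eight dword columns per 2D block row.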
@@ -422,7 +428,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
 
 #define MAX_ROW_BYTES_2D_BLOCK_LOAD 64 // maximum per row size in bytes supported by 2D block load
 
-#define IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+// M and K are the original tile dimensions for A and C.
+// M and K are the VNNI'ed dimensions for matrix B (hence they are also aliased as M_VNNI and K_VNNI in some places).
+// We therefore also have orig_M, which can be used as the non-VNNI'ed M dimension of matrix B.
+// When smaller elements are packed into bigger types, contrib_M and contrib_K are the M and K dimensions of the packed matrix.
+// Not all of these values (orig_M, contrib_M, contrib_K) are meaningful for every layout and macro, so before using one,
+// check that it is valid for the current configuration.
+#define IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
     if (contrib_K*sizeof(contrib_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
         long offset = as_long(mem); \
         long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
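To make those parameters concrete (using the derivations in IMPLEMENT_BLOCK2D_LOAD__ further down, for the PackedA_ColumnMajor instantiation added at the end of this patch):

    /* element_type = short (16-bit), contrib_type = int (32-bit), M = 8, K = 16:
     *   pack factor = 32/16 = 2
     *   orig_M      = M * 2 = 16
     *   contrib_M   = M / 2 = 4
     *   contrib_K   = K / 2 = 8 */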
@@ -437,7 +449,8 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
         return; \
     }
 
-#define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+// The 2D block read transpose builtin requires the K value _after_ the transpose is done - which is equal to M before the transpose.
+#define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
     if (M*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
         long offset = as_long(mem); \
         long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
@@ -446,29 +459,45 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
         int height = 16 - 1; /* taken from SG16 */ \
         long x = (offset - baseoffset) / (sizeof(contrib_type)); /* in elements */ \
         int2 coords = (int2)(x, 0); \
-        /* 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose */ \
-        OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, M)(long, int, int, int, int2, int); \
-        OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
-        *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
-        return; \
+        \
+        if (elem_bitwidth == 32) { \
+            OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(long, int, int, int, int2, int); \
+            OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
+            *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
+            return; \
+        } \
+        \
+        if (elem_bitwidth == 16 && _##layout == _PackedA_ColumnMajor) { \
+            OUT_VEC##contrib_M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, contrib_M)(long, int, int, int, int2, int); \
+            OUT_VEC##contrib_M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
+            *(__private OUT_VEC##WI_rows(u##element_type) *)dst = *(__private OUT_VEC##WI_rows(u##element_type) *)&res; \
+            return; \
+        } \
+        \
+        if (elem_bitwidth == 16 && _##layout == _PackedB_ColumnMajor) { \
+            OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(long, int, int, int, int2, int); \
+            OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
+            *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
+            return; \
+        } \
     }
 
-#define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M_VNNI, K_VNNI, WI_rows, orig_M, contrib_M, contrib_K) \
     if (contrib_K*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
         long offset = as_long(mem); \
         long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
         int width = (sizeof(element_type)) * stride - 1; /* in bytes */ \
         int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
-        int height = contrib_M - 1; /* row count */ \
+        int height = orig_M - 1; /* row count */ \
         long x = (offset - baseoffset) / (sizeof(element_type)); /* in elements */ \
         int2 coords = (int2)(x, 0); \
-        OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(long, int, int, int, int2, int); \
-        OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
-        *(__private OUT_VEC##M(u##contrib_type) *)dst = res; \
+        OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
+        OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
+        *(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
         return; \
     }
 
-#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
     long offset = as_long(mem); \
     int width_size = sizeof(element_type) * width - 1; /* in bytes */ \
     int pitch = sizeof(element_type) * stride - 1; /* in bytes */ \
@@ -480,7 +509,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
     *(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
     return;
 
-#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
     long offset = as_long(mem); \
     int width_size = sizeof(element_type) * width - 1; /* in bytes */ \
     int pitch = sizeof(element_type) * stride - 1; /* in bytes */ \
@@ -492,25 +521,29 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
     *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
     return;
 
-#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_VNNI_TX_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M_VNNI, K_VNNI, WI_rows, orig_M, contrib_M, contrib_K) \
     long offset = as_long(mem); \
     int width_size = sizeof(element_type) * width - 1; /* in bytes */ \
     int pitch = sizeof(element_type) * stride - 1; /* in bytes */ \
     int height_size = height - 1; \
     int2 coords = (int2)(x, y); \
-    OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(long, int, int, int, int2, int); \
-    OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
+    OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
+    OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
     *(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
     return;
 
-#define IMPLEMENT_BLOCK2D_LOAD__(checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows) \
-    IMPLEMENT_BLOCK2D_LOAD##checked##sg##order(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, \
-                                               M, K, WI_rows, \
+#define IMPLEMENT_BLOCK2D_LOAD___(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
+    IMPLEMENT_BLOCK2D_LOAD##checked##sg##order(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, \
+                                               M, K, WI_rows, orig_M, contrib_M, contrib_K)
+
+#define IMPLEMENT_BLOCK2D_LOAD__(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows) \
+    IMPLEMENT_BLOCK2D_LOAD___(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, \
                                                MATH_MUL(M, MATH_DIV(contrib_bitwidth, elem_bitwidth)), \
+                                               MATH_DIV(M, MATH_DIV(contrib_bitwidth, elem_bitwidth)), \
                                                MATH_DIV(K, MATH_DIV(contrib_bitwidth, elem_bitwidth)))
 
-#define IMPLEMENT_BLOCK2D_LOAD(checked, sg, order, element_type, contrib_type, M, K, WI_rows) \
-    IMPLEMENT_BLOCK2D_LOAD__(checked, sg, order, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
+#define IMPLEMENT_BLOCK2D_LOAD(layout, checked, sg, order, element_type, contrib_type, M, K, WI_rows) \
+    IMPLEMENT_BLOCK2D_LOAD__(layout, checked, sg, order, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
                              M, K, WI_rows)
 
 #define IMPLEMENT_BLOCK2D_STORE_SG16(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
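The extra IMPLEMENT_BLOCK2D_LOAD___ level is the usual indirection trick: arguments adjacent to ## are not macro-expanded, so the MATH_MUL/MATH_DIV invocations must pass through one more macro call to collapse into plain integer tokens before the layout-specific macro pastes them into vector types and builtin names. A rough expansion trace for the PackedA_ColumnMajor instantiation added at the end of this patch:

    /* IMPLEMENT_BLOCK2D_LOAD(PackedA_ColumnMajor, , _SG16, _COL_MAJOR_, short, int, 8, 16, 8)
     *   -> IMPLEMENT_BLOCK2D_LOAD__(..., 16, int, 32, 8, 16, 8)          // bitwidths added
     *   -> IMPLEMENT_BLOCK2D_LOAD___(..., orig_M = 16, contrib_M = 4, contrib_K = 8)
     *   -> IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_(PackedA_ColumnMajor, short, 16, int, 32,
     *                                             8, 16, 8, 16, 4, 8) */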
@@ -557,13 +590,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
 #define DEFINE_LOAD_BLOCK2D_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
     if (BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL \
         && (M == 1 || M == 2 || M == 4 || M == 8 || M == 16 || M == 32) \
-        && (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && contrib_bitwidth == 32 && WI_rows == M)) \
+        && (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && WI_rows == M)) \
         ) { \
-        IMPLEMENT_BLOCK2D_LOAD(, sg, order##_, element_type, contrib_type, M, K, WI_rows) \
+        IMPLEMENT_BLOCK2D_LOAD(layout, , sg, order##_, element_type, contrib_type, M, K, WI_rows) \
     }
 
 #define DEFINE_LOAD_CHECKED_BLOCK2D_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
-    IMPLEMENT_BLOCK2D_LOAD(_CHECKED, sg, order##_, element_type, contrib_type, M, K, WI_rows)
+    IMPLEMENT_BLOCK2D_LOAD(layout, _CHECKED, sg, order##_, element_type, contrib_type, M, K, WI_rows)
 
 #define DEFINE_LOAD_VECTORS_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, address_space) \
     if (WI_rows >= M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL \
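The column-major guard no longer pins contrib_bitwidth to 32; the per-configuration choice of transpose builtin now lives inside IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_, which simply emits no return for unsupported combinations so execution falls through to the vector implementation. For the new instantiation below the guard evaluates as:

    /* DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, short, int, 8, 16, COL_MAJOR, , 8):
     *   M == 8, order == _COL_MAJOR and WI_rows == M, so the block2d path is
     *   attempted; inside it, elem_bitwidth == 16 together with
     *   layout == PackedA_ColumnMajor selects the contrib_M-row d32 transpose. */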
@@ -735,6 +768,10 @@ DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 16, ROW_MAJOR,
 
 DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 32, ROW_MAJOR, _us, 2)
 
+// This matrix is represented as <8 x i16> in LLVM IR, but to be able to read it with a 2D block load we have to use i32.
+// So the contrib type is `int` here: we read <4 x i32> from memory, then use the result as <8 x i16>.
+DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, short, int, 8, 16, COL_MAJOR, , 8)
+
 /* PackedA load i16 SG16 for sub group size = 32*/
 DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 8, 16, ROW_MAJOR, _us, 4)
 DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 7, 16, ROW_MAJOR, _us, 4)
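A minimal OpenCL C sketch of the reinterpretation that comment describes (the `demo` helper is hypothetical; it mirrors the pointer cast in the PackedA_ColumnMajor branch of the COL_MAJOR macro):

    void demo(__private void *dst) {
        uint4 res = (uint4)(0); /* stands in for the <4 x i32> block-load result */
        /* Same 16 bytes, viewed bit-for-bit as <8 x i16>: */
        *(__private ushort8 *)dst = *(__private ushort8 *)&res;
    }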