@@ -46,7 +46,6 @@ SPDX-License-Identifier: MIT
 #define _PackedB_PackedB 3
 #define _Accumulator_RowMajor 4
 #define _Accumulator_ColumnMajor 5
-#define _PackedA_ColumnMajor 6
 
 #define ATTRIBUTE_AS_GENERIC __global /* the branch using this will be dead,
                                          however we still need a valid address
@@ -180,13 +179,9 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
 #define MATH_8_DIV_4 2
 #define MATH_8_DIV_2 4
 #define MATH_8_DIV_1 8
-#define MATH_7_DIV_1 7
-#define MATH_6_DIV_1 6
-#define MATH_5_DIV_1 5
 #define MATH_4_DIV_4 1
 #define MATH_4_DIV_2 2
 #define MATH_4_DIV_1 4
-#define MATH_3_DIV_1 3
 #define MATH_2_DIV_2 1
 #define MATH_2_DIV_1 2
 #define MATH_1_DIV_1 1
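The MATH_*_DIV_* constants above do compile-time division by name lookup rather than arithmetic: a two-level macro pastes the operands into a constant name, so the result is a plain literal that can itself be token-pasted further (e.g. into a vector width). A minimal standalone sketch of that pattern, with illustrative names rather than the library's:

    #include <stdio.h>

    #define MY_DIV__(a, b) MY_##a##_DIV_##b   /* paste operands into a constant name */
    #define MY_DIV(a, b)   MY_DIV__(a, b)     /* indirection so a and b expand first */
    #define MY_8_DIV_2 4
    #define MY_8_DIV_1 8

    int main(void) {
        printf("%d %d\n", MY_DIV(8, 2), MY_DIV(8, 1)); /* prints: 4 8 */
        return 0;
    }

The deleted odd-valued entries (7, 6, 5, 3 divided by 1) were presumably only reachable through the MATH_DIV(M, ...) computation that this commit also removes below.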
@@ -230,7 +225,6 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
 #define SHAPE_CONCAT_VNNI(M, K, vnni_factor) SHAPE_CONCAT_VNNI__(MATH_MUL(M, vnni_factor), MATH_DIV(K, vnni_factor))
 
 #define SHAPE_PackedA_RowMajor(M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT(M, K)
-#define SHAPE_PackedA_ColumnMajor(M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT(M, K)
 #define SHAPE_PackedB_RowMajor(M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
 #define SHAPE_PackedB_ColumnMajor(M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
 #define SHAPE_PackedB_PackedB(M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
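SHAPE_CONCAT_VNNI folds the VNNI factor into the reported shape: rows are multiplied and columns divided by contrib_bitwidth/elem_bitwidth. A worked example, assuming 16-bit elements packed into a 32-bit contrib type:

    #include <stdio.h>

    int main(void) {
        int M = 16, K = 16;                                 /* logical B tile */
        int elem_bitwidth = 16, contrib_bitwidth = 32;      /* short packed into int */
        int vnni_factor = contrib_bitwidth / elem_bitwidth; /* = 2 */
        /* SHAPE_CONCAT_VNNI(16, 16, 2) would concatenate 32 and 8 */
        printf("%dx%d\n", M * vnni_factor, K / vnni_factor); /* prints: 32x8 */
        return 0;
    }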
@@ -413,13 +407,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
 
 /* For platforms without SG16 JointMatrix support block2d is not available. The
  * implementation remains empty, will fallthrough to vector implementation. */
-#define IMPLEMENT_BLOCK2D_LOAD_ROW_MAJOR_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
+#define IMPLEMENT_BLOCK2D_LOAD_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
     /* not supported, fallthrough */
-#define IMPLEMENT_BLOCK2D_LOAD_COL_MAJOR_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
+#define IMPLEMENT_BLOCK2D_LOAD_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
     /* not supported, fallthrough */
-#define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
+#define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
     /* not supported, fallthrough */
-#define IMPLEMENT_BLOCK2D_STORE(p1, p2, p3, p4, p5, p6, p7, p8) \
+#define IMPLEMENT_BLOCK2D_STORE(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
     /* not supported, fallthrough */
 
 // contrib_K - calculated in BLOCK2D loads; contrib_K = K/(contrib_bitwidth/elem_bitwidth);
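A worked instance of that formula, assuming 8-bit elements read through a 32-bit contrib type: contrib_K = K / (contrib_bitwidth / elem_bitwidth) = 32 / (32 / 8) = 8, i.e. a row of 32 chars is fetched as 8 ints.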
@@ -428,13 +422,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
 
 #define MAX_ROW_BYTES_2D_BLOCK_LOAD 64 // maximum per-row size in bytes supported by 2D block load
 
-// M and K are the original tile dimensions for A and C.
-// M and K are the VNNI-ed dimensions for matrix B (hence they are also aliased as M_VNNI and K_VNNI in some places).
-// We therefore also have orig_M, which is the non-VNNI'ed M dimension of matrix B.
-// When smaller elements are packed into bigger types, contrib_M and contrib_K are the M and K dimensions of the packed matrix.
-// Not all of these values (orig_M, contrib_M, contrib_K) are valid for every layout and macro, so before using one,
-// please check that it is valid for the current configuration.
-#define IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
     if (contrib_K*sizeof(contrib_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
         long offset = as_long(mem); \
         long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
@@ -449,8 +437,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
         return; \
     }
 
-// 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose
-#define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
     if (M*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
         long offset = as_long(mem); \
         long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
@@ -459,45 +446,29 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
         int height = 16 - 1; /* taken from SG16 */ \
         long x = (offset - baseoffset) / (sizeof(contrib_type)); /* in elements */ \
         int2 coords = (int2)(x, 0); \
-        \
-        if (elem_bitwidth == 32) { \
-            OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(long, int, int, int, int2, int); \
-            OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
-            *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
-            return; \
-        } \
-        \
-        if (elem_bitwidth == 16 && _##layout == _PackedA_ColumnMajor) { \
-            OUT_VEC##contrib_M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, contrib_M)(long, int, int, int, int2, int); \
-            OUT_VEC##contrib_M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
-            *(__private OUT_VEC##WI_rows(u##element_type) *)dst = *(__private OUT_VEC##WI_rows(u##element_type) *)&res; \
-            return; \
-        } \
-        \
-        if (elem_bitwidth == 16 && _##layout == _PackedB_ColumnMajor) { \
-            OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(long, int, int, int, int2, int); \
-            OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
-            *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
-            return; \
-        } \
+        /* 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose */ \
+        OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, M)(long, int, int, int, int2, int); \
+        OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
+        *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
+        return; \
     }
 
-#define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M_VNNI, K_VNNI, WI_rows, orig_M, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
     if (contrib_K*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
         long offset = as_long(mem); \
         long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
         int width = (sizeof(element_type)) * stride - 1; /* in bytes */ \
         int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
-        int height = orig_M - 1; /* row count */ \
+        int height = contrib_M - 1; /* row count */ \
         long x = (offset - baseoffset) / (sizeof(element_type)); /* in elements */ \
        int2 coords = (int2)(x, 0); \
-        OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
-        OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
-        *(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
+        OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(long, int, int, int, int2, int); \
+        OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
+        *(__private OUT_VEC##M(u##contrib_type) *)dst = res; \
         return; \
     }
 
-#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
     long offset = as_long(mem); \
     int width_size = sizeof(element_type) * width - 1; /* in bytes */ \
     int pitch = sizeof(element_type) * stride - 1; /* in bytes */ \
@@ -509,7 +480,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
     *(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
     return;
 
-#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
     long offset = as_long(mem); \
     int width_size = sizeof(element_type) * width - 1; /* in bytes */ \
     int pitch = sizeof(element_type) * stride - 1; /* in bytes */ \
@@ -521,29 +492,25 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
     *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
     return;
 
-#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_VNNI_TX_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M_VNNI, K_VNNI, WI_rows, orig_M, contrib_M, contrib_K) \
+#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
     long offset = as_long(mem); \
     int width_size = sizeof(element_type) * width - 1; /* in bytes */ \
     int pitch = sizeof(element_type) * stride - 1; /* in bytes */ \
     int height_size = height - 1; \
     int2 coords = (int2)(x, y); \
-    OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
-    OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
+    OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(long, int, int, int, int2, int); \
+    OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
     *(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
     return;
 
-#define IMPLEMENT_BLOCK2D_LOAD___(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
-    IMPLEMENT_BLOCK2D_LOAD##checked##sg##order(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, \
-                                               M, K, WI_rows, orig_M, contrib_M, contrib_K)
-
-#define IMPLEMENT_BLOCK2D_LOAD__(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows) \
-    IMPLEMENT_BLOCK2D_LOAD___(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, \
+#define IMPLEMENT_BLOCK2D_LOAD__(checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows) \
+    IMPLEMENT_BLOCK2D_LOAD##checked##sg##order(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, \
+                                               M, K, WI_rows, \
                                                MATH_MUL(M, MATH_DIV(contrib_bitwidth, elem_bitwidth)), \
-                                               MATH_DIV(M, MATH_DIV(contrib_bitwidth, elem_bitwidth)), \
                                                MATH_DIV(K, MATH_DIV(contrib_bitwidth, elem_bitwidth)))
 
-#define IMPLEMENT_BLOCK2D_LOAD(layout, checked, sg, order, element_type, contrib_type, M, K, WI_rows) \
-    IMPLEMENT_BLOCK2D_LOAD__(layout, checked, sg, order, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
+#define IMPLEMENT_BLOCK2D_LOAD(checked, sg, order, element_type, contrib_type, M, K, WI_rows) \
+    IMPLEMENT_BLOCK2D_LOAD__(checked, sg, order, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
                              M, K, WI_rows)
 
 #define IMPLEMENT_BLOCK2D_STORE_SG16(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
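The IMPLEMENT_BLOCK2D_LOAD__ dispatcher above selects its leaf macro purely by token pasting; after this change the layout argument no longer takes part in the paste. A self-contained sketch of the mechanism (the leaf body and names are illustrative, not the library's):

    #include <stdio.h>

    #define MY_DIV__(a, b) MY_##a##_DIV_##b
    #define MY_DIV(a, b)   MY_DIV__(a, b)
    #define MY_16_DIV_2 8

    /* Leaf selected by name, standing in for IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR_. */
    #define LOAD_SG16_ROW_MAJOR_(M, K, contrib_K) \
        printf("M=%d K=%d contrib_K=%d\n", M, K, contrib_K);

    /* Pastes 'checked', 'sg' and 'order' into the leaf name, like
       IMPLEMENT_BLOCK2D_LOAD##checked##sg##order above. */
    #define LOAD__(checked, sg, order, M, K, ratio) \
        LOAD##checked##sg##order(M, K, MY_DIV(K, ratio))

    int main(void) {
        /* empty 'checked' selects the unchecked path:
           expands to LOAD_SG16_ROW_MAJOR_(8, 16, 8) */
        LOAD__(, _SG16_, ROW_MAJOR_, 8, 16, 2)
        return 0;
    }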
@@ -590,13 +557,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
 #define DEFINE_LOAD_BLOCK2D_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
     if (BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL \
         && (M == 1 || M == 2 || M == 4 || M == 8 || M == 16 || M == 32) \
-        && (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && WI_rows == M)) \
+        && (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && contrib_bitwidth == 32 && WI_rows == M)) \
        ) { \
-        IMPLEMENT_BLOCK2D_LOAD(layout, , sg, order##_, element_type, contrib_type, M, K, WI_rows) \
+        IMPLEMENT_BLOCK2D_LOAD(, sg, order##_, element_type, contrib_type, M, K, WI_rows) \
    }
 
 #define DEFINE_LOAD_CHECKED_BLOCK2D_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
-    IMPLEMENT_BLOCK2D_LOAD(layout, _CHECKED, sg, order##_, element_type, contrib_type, M, K, WI_rows)
+    IMPLEMENT_BLOCK2D_LOAD(_CHECKED, sg, order##_, element_type, contrib_type, M, K, WI_rows)
 
 #define DEFINE_LOAD_VECTORS_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, address_space) \
     if (WI_rows >= M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL \
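The contrib_bitwidth == 32 term added above narrows the block2d path for column-major loads: with the PackedA_ColumnMajor cases gone, only 32-bit contrib types go through the 2D block col-major load, while 16-bit col-major loads fall through to the vector implementation. A standalone sketch of the same guard logic (values illustrative):

    #include <stdbool.h>
    #include <stdio.h>

    enum order { ROW_MAJOR, COL_MAJOR, VNNI_TX };

    static bool use_block2d(enum order o, int contrib_bitwidth, int WI_rows, int M) {
        bool m_ok = (M == 1 || M == 2 || M == 4 || M == 8 || M == 16 || M == 32);
        return m_ok && (o == ROW_MAJOR || o == VNNI_TX ||
                        (o == COL_MAJOR && contrib_bitwidth == 32 && WI_rows == M));
    }

    int main(void) {
        printf("%d\n", use_block2d(COL_MAJOR, 32, 8, 8)); /* 1: block2d path taken */
        printf("%d\n", use_block2d(COL_MAJOR, 16, 8, 8)); /* 0: vector fallback */
        return 0;
    }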
@@ -768,10 +735,6 @@ DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 16, ROW_MAJOR,
 
 DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 32, ROW_MAJOR, _us, 2)
 
-// This matrix is represented as <8 x i16> in LLVM IR, but to read it with a 2d block load we have to use i32,
-// so the contrib type is `int` here and we read <4 x i32> from memory, then use it as <8 x i16>.
-DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, short, int, 8, 16, COL_MAJOR, , 8)
-
 /* PackedA load i16 SG16 for sub group size = 32 */
 DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 8, 16, ROW_MAJOR, _us, 4)
 DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 7, 16, ROW_MAJOR, _us, 4)