@@ -46,6 +46,7 @@ SPDX-License-Identifier: MIT
#define _PackedB_PackedB 3
#define _Accumulator_RowMajor 4
#define _Accumulator_ColumnMajor 5
+ #define _PackedA_ColumnMajor 6

#define ATTRIBUTE_AS_GENERIC __global /* the branch using this will be dead,
                                         however we still need a valid address
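/* The tags above are compared through token pasting in the load macros further down,
 * e.g. `_##layout == _PackedA_ColumnMajor` becomes an ordinary compile-time integer
 * comparison (6 == 6) once `layout` is PackedA_ColumnMajor. */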
@@ -179,9 +180,13 @@ typedef uint   __attribute__((ext_vector_type(32))) uint32;
#define MATH_8_DIV_4 2
#define MATH_8_DIV_2 4
#define MATH_8_DIV_1 8
+ #define MATH_7_DIV_1 7
+ #define MATH_6_DIV_1 6
+ #define MATH_5_DIV_1 5
#define MATH_4_DIV_4 1
#define MATH_4_DIV_2 2
#define MATH_4_DIV_1 4
+ #define MATH_3_DIV_1 3
#define MATH_2_DIV_2 1
#define MATH_2_DIV_1 2
#define MATH_1_DIV_1 1
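/* Sketch of how these lookup defines are used (assuming MATH_DIV/MATH_MUL token-paste
 * their arguments, as the naming suggests):
 *   MATH_DIV(8, 2) -> MATH_8_DIV_2 -> 4
 *   MATH_DIV(7, 1) -> MATH_7_DIV_1 -> 7  // new entries let M = 3, 5, 6, 7 pass through unchanged
 * All of this is compile-time integer math; no runtime division is emitted. */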
@@ -225,6 +230,7 @@ typedef uint   __attribute__((ext_vector_type(32))) uint32;
#define SHAPE_CONCAT_VNNI(M, K, vnni_factor) SHAPE_CONCAT_VNNI__(MATH_MUL(M, vnni_factor), MATH_DIV(K, vnni_factor))

#define SHAPE_PackedA_RowMajor(       M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT(M, K)
+ #define SHAPE_PackedA_ColumnMajor(    M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT(M, K)
#define SHAPE_PackedB_RowMajor(       M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
#define SHAPE_PackedB_ColumnMajor(    M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
#define SHAPE_PackedB_PackedB(        M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
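/* Worked example (assuming SHAPE_CONCAT pastes its arguments into "MxK"):
 *   SHAPE_PackedA_ColumnMajor(8, 16, 16, 32) -> 8x16      (A is never VNNI'ed)
 *   SHAPE_PackedB_RowMajor(8, 32, 16, 32)    -> vnni_factor = 32/16 = 2 -> 16x16 */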
@@ -407,13 +413,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a

/* For platforms without SG16 JointMatrix support block2d is not available. The
 * implementation remains empty, will fallthrough to vector implementation. */
- #define IMPLEMENT_BLOCK2D_LOAD_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+ #define IMPLEMENT_BLOCK2D_LOAD_ROW_MAJOR_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
  /* not supported, fallthrough */
- #define IMPLEMENT_BLOCK2D_LOAD_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+ #define IMPLEMENT_BLOCK2D_LOAD_COL_MAJOR_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
  /* not supported, fallthrough */
- #define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+ #define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
  /* not supported, fallthrough */
- #define IMPLEMENT_BLOCK2D_STORE(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
+ #define IMPLEMENT_BLOCK2D_STORE(p1, p2, p3, p4, p5, p6, p7, p8) \
  /* not supported, fallthrough */

// contrib_K - calculated in BLOCK2D loads; contrib_K = K/(contrib_bitwidth/elem_bitwidth);
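/* Example, using the formula above: K = 32 elements of 16 bits packed into 32-bit
 * contrib values gives contrib_K = 32 / (32/16) = 16, i.e. the block2d width is
 * counted in i32 columns rather than i16 elements. */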
@@ -422,7 +428,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a

#define MAX_ROW_BYTES_2D_BLOCK_LOAD 64 // maximum per row size in bytes supported by 2D block load

- #define IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+ // M and K are the original tile dimensions for A and C.
+ // For matrix B, M and K are the VNNI'ed dimensions (hence they are also aliased as M_VNNI and K_VNNI in some places).
+ // We therefore also have orig_M, which can be used as the non-VNNI'ed M dimension of matrix B.
+ // When smaller elements are packed into bigger types, contrib_M and contrib_K represent the M and K dimensions of the packed matrix.
+ // Not all of these values (orig_M, contrib_M, contrib_K) are valid for every layout and macro, so before using any of them,
+ // please check that they are valid for the current configuration.
+ #define IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
  if (contrib_K*sizeof(contrib_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
    long offset = as_long(mem); \
    long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
@@ -437,7 +449,8 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
    return; \
  }

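/* Worked example of the dimension names for matrix B stored as VNNI'ed i16 (factor 2):
 * M (M_VNNI) = 8, K (K_VNNI) = 16 -> orig_M = 16, contrib_M = 4, contrib_K = 8,
 * matching the MATH_MUL/MATH_DIV arguments supplied by IMPLEMENT_BLOCK2D_LOAD__ below. */
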
- #define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+ // The 2D block read transpose builtin requires the K value as it is _after_ the transpose is done, which is equal to M before the transpose.
+ #define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
  if (M*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
    long offset = as_long(mem); \
    long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
@@ -446,29 +459,45 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
    int height = 16 - 1; /* taken from SG16 */ \
    long x = (offset - baseoffset) / (sizeof (contrib_type)); /* in elements */ \
    int2 coords = (int2)(x, 0); \
-     /* 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose */ \
-     OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, M)(long, int, int, int, int2, int); \
-     OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
-     *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
-     return; \
+ \
+     if (elem_bitwidth == 32) { \
+       OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(long, int, int, int, int2, int); \
+       OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
+       *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
+       return; \
+     } \
+ \
+     if (elem_bitwidth == 16 && _##layout == _PackedA_ColumnMajor) { \
+       OUT_VEC##contrib_M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, contrib_M)(long, int, int, int, int2, int); \
+       OUT_VEC##contrib_M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
+       *(__private OUT_VEC##WI_rows(u##element_type) *)dst = *(__private OUT_VEC##WI_rows(u##element_type) *)&res; \
+       return; \
+     } \
+ \
+     if (elem_bitwidth == 16 && _##layout == _PackedB_ColumnMajor) { \
+       OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(long, int, int, int, int2, int); \
+       OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
+       *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
+       return; \
+     } \
  }

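/* The three branches above, by example:
 *   - 32-bit elements (e.g. the accumulator): plain i32 transpose, <M x i32> result.
 *   - 16-bit PackedA_ColumnMajor with contrib type int: contrib_M = M/2, so a
 *     <contrib_M x i32> transpose result is re-viewed as <WI_rows x i16> in dst.
 *   - 16-bit PackedB_ColumnMajor: keeps the M-sized i32 result as before. */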
- #define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+ #define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M_VNNI, K_VNNI, WI_rows, orig_M, contrib_M, contrib_K) \
  if (contrib_K*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
    long offset = as_long(mem); \
    long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
    int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
    int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
-     int height = contrib_M - 1; /* row count */ \
+     int height = orig_M - 1; /* row count */ \
    long x = (offset - baseoffset) / (sizeof (element_type)); /* in elements */ \
    int2 coords = (int2)(x, 0); \
-     OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(long, int, int, int, int2, int); \
-     OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
-     *(__private OUT_VEC##M(u##contrib_type) *)dst = res; \
+     OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
+     OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
+     *(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
    return; \
  }

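/* Example for the height fix above: B as i16 with vnni factor 2 and M_VNNI = 8 reads
 * orig_M = 16 source rows; the VNNI transform packs row pairs so each work-item ends
 * up with WI_rows 32-bit values, hence the OUT_VEC##WI_rows result type. */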
- #define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+ #define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
  long offset = as_long(mem); \
  int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
  int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
@@ -480,7 +509,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
  *(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
  return;

- #define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+ #define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
  long offset = as_long(mem); \
  int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
  int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
@@ -492,25 +521,29 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
  *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
  return;

- #define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
+ #define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_VNNI_TX_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M_VNNI, K_VNNI, WI_rows, orig_M, contrib_M, contrib_K) \
  long offset = as_long(mem); \
  int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
  int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
  int height_size = height - 1; \
  int2 coords = (int2)(x, y); \
-   OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(long, int, int, int, int2, int); \
-   OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
+   OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
+   OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
  *(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
  return;

- #define IMPLEMENT_BLOCK2D_LOAD__(checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows) \
-   IMPLEMENT_BLOCK2D_LOAD##checked##sg##order(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, \
-                                     M, K, WI_rows, \
+ #define IMPLEMENT_BLOCK2D_LOAD___(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
+   IMPLEMENT_BLOCK2D_LOAD##checked##sg##order(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, \
+                                     M, K, WI_rows, orig_M, contrib_M, contrib_K)
+ 
+ #define IMPLEMENT_BLOCK2D_LOAD__(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows) \
+   IMPLEMENT_BLOCK2D_LOAD___(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, \
                                    MATH_MUL(M, MATH_DIV(contrib_bitwidth, elem_bitwidth)), \
+                                     MATH_DIV(M, MATH_DIV(contrib_bitwidth, elem_bitwidth)), \
                                    MATH_DIV(K, MATH_DIV(contrib_bitwidth, elem_bitwidth)))

- #define IMPLEMENT_BLOCK2D_LOAD(checked, sg, order, element_type, contrib_type, M, K, WI_rows) \
-   IMPLEMENT_BLOCK2D_LOAD__(checked, sg, order, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
+ #define IMPLEMENT_BLOCK2D_LOAD(layout, checked, sg, order, element_type, contrib_type, M, K, WI_rows) \
+   IMPLEMENT_BLOCK2D_LOAD__(layout, checked, sg, order, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
                           M, K, WI_rows)

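/* Expansion sketch for the new column-major A load (arguments taken from the
 * DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, ...) instantiation further down):
 *   IMPLEMENT_BLOCK2D_LOAD(PackedA_ColumnMajor, , _SG16, _COL_MAJOR_, short, int, 8, 16, 8)
 *     -> IMPLEMENT_BLOCK2D_LOAD__(... with bitwidths 16/32 ..., 8, 16, 8)
 *     -> IMPLEMENT_BLOCK2D_LOAD___(..., orig_M = 16, contrib_M = 4, contrib_K = 8)
 *     -> IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_(PackedA_ColumnMajor, short, 16, int, 32, 8, 16, 8, 16, 4, 8) */
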
#define IMPLEMENT_BLOCK2D_STORE_SG16(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
@@ -557,13 +590,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
#define DEFINE_LOAD_BLOCK2D_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
    if (BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL \
        && (M == 1 || M == 2 || M == 4 || M == 8 || M == 16 || M == 32) \
-         && (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && contrib_bitwidth == 32 && WI_rows == M)) \
+         && (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && WI_rows == M)) \
        ) { \
-         IMPLEMENT_BLOCK2D_LOAD(, sg, order##_, element_type, contrib_type, M, K, WI_rows) \
+         IMPLEMENT_BLOCK2D_LOAD(layout, , sg, order##_, element_type, contrib_type, M, K, WI_rows) \
    }

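/* Note: with the `contrib_bitwidth == 32` requirement dropped from the guard above,
 * 16-bit column-major loads also reach the block2d path; the elem_bitwidth/layout
 * checks inside IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_ then select the right variant. */
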
#define DEFINE_LOAD_CHECKED_BLOCK2D_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
-     IMPLEMENT_BLOCK2D_LOAD(_CHECKED, sg, order##_, element_type, contrib_type, M, K, WI_rows)
+     IMPLEMENT_BLOCK2D_LOAD(layout, _CHECKED, sg, order##_, element_type, contrib_type, M, K, WI_rows)

#define DEFINE_LOAD_VECTORS_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, address_space) \
    if (WI_rows >= M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL \
@@ -735,6 +768,10 @@ DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 16, ROW_MAJOR,

DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 32, ROW_MAJOR, _us, 2)

+ // This matrix is represented as <8xi16> in LLVM IR, but to be able to read it with a 2D block load we have to use i32.
+ // So the contrib type here is `int`: we read <4xi32> from memory and then use the result as <8xi16>.
+ DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, short, int, 8, 16, COL_MAJOR, , 8)
+ 
/* PackedA load i16 SG16 for sub group size = 32*/
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 8, 16, ROW_MAJOR, _us, 4)
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 7, 16, ROW_MAJOR, _us, 4)
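
/* Minimal sketch of the reinterpret described above (block2d_transpose_load is a
 * hypothetical stand-in for the DEFINE_BLOCK2D_TRANSPOSE_NAME builtin):
 *   uint4   packed = block2d_transpose_load(...); // <4 x i32> read from memory
 *   ushort8 rows   = as_ushort8(packed);          // used as <8 x i16> elements */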