@@ -690,12 +690,17 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
         return; \
     }

-#define DEFINE_LOAD_SCALAR_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, order, WI_rows) \
-    contrib_type *ptr = (contrib_type *)mem; \
+#define DEFINE_LOAD_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, order, WI_rows) \
     int slid = get_sub_group_local_id(); \
     int pack_factor = sizeof(contrib_type) / sizeof(element_type); \
-    long packed_stride = stride / pack_factor; \
     int sg_cols = K / pack_factor; \
+    if (_##layout == _PackedA_ColumnMajor && elem_bitwidth == 8 && contrib_bitwidth == 16) { \
+        for (int i = 0; i < M; i++) \
+            dst[i] = mem[((slid * pack_factor) % sg_size) * stride + (slid / sg_cols) + (i & ~1) + (i % pack_factor) * stride]; \
+        return; \
+    } \
+    contrib_type *ptr = (contrib_type *)mem; \
+    long packed_stride = stride / pack_factor; \
     __private contrib_type *wi_contrib = (__private contrib_type *)dst; \
     if (order == _VNNI_TX) { \
         GATHER_LOAD_PACK_##elem_bitwidth(element_type, M, wi_contrib, slid, stride) \
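
The new early-out above is the core of the change: for PackedA_ColumnMajor with 8-bit elements and a 16-bit contrib type, each work-item gathers its M bytes straight from the column-major source, doing the transpose and the VNNI-style pairing purely in the address arithmetic. Below is a minimal host-side sketch of that indexing, assuming the char/short 8x32 instantiation added further down (sg_size = 32, pack_factor = 2, sg_cols = 16); the helper name is ours, not IGC's.

#include <stdio.h>

/* Offset computed by the new scalar fallback for dst[i] of work-item slid.
 * For even i the last two terms give i; for odd i they give (i - 1) + stride,
 * so each dst pair (2j, 2j + 1) reads two elements sitting `stride` apart,
 * i.e. the pair that ends up VNNI-packed into one 16-bit contrib value. */
static long col_major_offset(int slid, int i, long stride,
                             int sg_size, int pack_factor, int sg_cols) {
    return (long)((slid * pack_factor) % sg_size) * stride
         + (slid / sg_cols)
         + (i & ~1)
         + (i % pack_factor) * stride;
}

int main(void) {
    const int sg_size = 32, M = 8, K = 32;
    const int pack_factor = 2;            /* sizeof(short) / sizeof(char) */
    const int sg_cols = K / pack_factor;  /* 16 */
    const long stride = 8;                /* arbitrary example stride */

    for (int slid = 0; slid < 2; slid++)  /* first two work-items only */
        for (int i = 0; i < M; i++)
            printf("slid=%d dst[%d] <- mem[%ld]\n", slid, i,
                   col_major_offset(slid, i, stride, sg_size, pack_factor, sg_cols));
    return 0;
}
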
@@ -727,20 +732,20 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
     } else { \
         DEFINE_LOAD_VECTORS_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, AS_LOCAL) \
     } \
-    DEFINE_LOAD_SCALAR_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
+    DEFINE_LOAD_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
     }
 #define DEFINE_LOAD_IMPL_AS_LOCAL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
     INLINE void MANGLE_LOAD_NAME_AS_LOCAL(layout, sg, elem_bitwidth, shape, WI_rows) (__private char *dst, char *mem, long stride, int cacheOpt) { \
         int sg_size = get_sub_group_size(); \
         DEFINE_LOAD_VECTORS_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, AS_LOCAL) \
-        DEFINE_LOAD_SCALAR_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
+        DEFINE_LOAD_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
     }
 #define DEFINE_LOAD_IMPL_AS_GLOBAL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
     INLINE void MANGLE_LOAD_NAME_AS_GLOBAL(layout, sg, elem_bitwidth, shape, WI_rows) (__private char *dst, char *mem, long stride, int cacheOpt) { \
         int sg_size = get_sub_group_size(); \
         DEFINE_LOAD_BLOCK2D_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows) \
         DEFINE_LOAD_VECTORS_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, AS_GLOBAL) \
-        DEFINE_LOAD_SCALAR_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
+        DEFINE_LOAD_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
     }

 #define DEFINE_LOAD_CHECKED_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, M, K, shape, order, WI_rows) \
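
Note the shape these DEFINE_LOAD_IMPL_AS_* macros give the generated functions: the fastest candidate runs first and returns early when it applies, the shared-local variant simply omits the 2d block step, and the scalar implementation is the unconditional last resort. That is why the scalar macro now takes the layout parameter: it must special-case column-major A itself. Schematically (the try_* helpers are illustrative stand-ins, not actual IGC functions):

#include <stdbool.h>

/* Illustrative stubs: each returns true if it handled the load. */
static bool try_block2d_load(char *dst, char *mem, long stride, int cacheOpt) { return false; }
static bool try_vector_load(char *dst, char *mem, long stride, int cacheOpt) { return false; }
static void scalar_load(char *dst, char *mem, long stride) { /* always works */ }

/* The generated AS_GLOBAL loader follows this dispatch pattern. */
void load_impl_as_global(char *dst, char *mem, long stride, int cacheOpt) {
    if (try_block2d_load(dst, mem, stride, cacheOpt))
        return;                        /* hardware 2d block load */
    if (try_vector_load(dst, mem, stride, cacheOpt))
        return;                        /* sub-group block reads */
    scalar_load(dst, mem, stride);     /* per-element fallback */
}
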
@@ -836,13 +841,18 @@ DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 16, ROW_MAJOR,

 DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 32, ROW_MAJOR, _us, 2)

-// This matrix is represented as <8xi16> in LLVM IR, but to be able to read it with 2d block load we have to use i32
+// PackedA Column Major -----
+//
+// These matrices are represented as <8xi16> in LLVM IR, but to be able to read them with a 2d block load we have to use i32
 // so, contrib type is `int` here and we read <4xi32> from memory, but then we use it as <8xi16>
 DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, short, int, 8, 16, COL_MAJOR, , 8)
 DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, char, int, 8, 32, COL_MAJOR, , 8)

-// PackedA load i16 SG16 subgroup size 32
+// Sub-group size 32
 DEFINE_LOAD(PackedA_ColumnMajor, _SG16, short, short, 8, 16, COL_MAJOR, , 4)
+DEFINE_LOAD(PackedA_ColumnMajor, _SG16, char, short, 8, 32, COL_MAJOR, , 4)
+//
+// end of PackedA Column Major -----

 /* PackedA load i16 SG16 for sub group size = 32*/
 DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 8, 16, ROW_MAJOR, _us, 4)
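
The <4xi32> vs <8xi16> remark in the comment above describes a pure reinterpretation, not a shuffle: the 2d block load fetches the same 16 bytes per fragment, which are then viewed at the narrower element width. A tiny host-side illustration (little-endian byte order assumed):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    /* The same 16 bytes viewed as <4 x i32> (what the 2d block load
     * returns) and as <8 x i16> (how the A fragment is then used). */
    uint32_t as_i32[4] = {0x00020001u, 0x00040003u, 0x00060005u, 0x00080007u};
    uint16_t as_i16[8];
    memcpy(as_i16, as_i32, sizeof as_i32); /* reinterpretation, no data movement */
    for (int i = 0; i < 8; i++)
        printf("i16[%d] = %u\n", i, as_i16[i]); /* prints 1..8 on little-endian */
    return 0;
}
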
@@ -910,6 +920,7 @@ DEFINE_LOAD_AND_CHECKED(PackedB_PackedB, _SG16, char, int, 8, 64, ROW_MAJOR,
 DEFINE_LOAD_AND_CHECKED(PackedB_RowMajor, _SG16, char, int, 8, 64, VNNI_TX, , 8)

 /* PackedB load i8 SG16 for sub group size 32*/
+DEFINE_LOAD(PackedB_ColumnMajor, _SG16, char, int, 8, 64, COL_MAJOR, , 4)
 DEFINE_LOAD(PackedB_PackedB, _SG16, char, int, 8, 64, ROW_MAJOR, , 4)

 /* B load tf32 SG16 */
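
For the two new column-major instantiations, the scalar path's derived constants follow directly from the macro above: pack_factor = sizeof(contrib_type) / sizeof(element_type) and sg_cols = K / pack_factor. A quick check (struct and names are ours, for illustration only):

#include <stdio.h>

struct inst { const char *name; int elem_bytes, contrib_bytes, K; };

int main(void) {
    /* The two column-major loads added in this patch. */
    struct inst insts[] = {
        {"PackedA_ColumnMajor char/short 8x32", 1, 2, 32},
        {"PackedB_ColumnMajor char/int   8x64", 1, 4, 64},
    };
    for (int i = 0; i < 2; i++) {
        int pack_factor = insts[i].contrib_bytes / insts[i].elem_bytes;
        int sg_cols = insts[i].K / pack_factor;
        printf("%s: pack_factor=%d sg_cols=%d\n",
               insts[i].name, pack_factor, sg_cols);
    }
    return 0; /* prints pack_factor=2/sg_cols=16 and pack_factor=4/sg_cols=16 */
}
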