@@ -794,7 +794,7 @@ The storage in memory will be: 0 0 1 1 2 2 ... 7 7
794794// R - number of rows
795795// C - number of columns
796796// VF - VNNI Factor
797- #define DEFINE_GET_COORD (layout , sg , elem_bitwidth , contrib_bitwidth , R , C , VF ) \
797+ #define DEFINE_GET_COORD_ROWPACKED (layout , sg , elem_bitwidth , contrib_bitwidth , R , C , VF ) \
798798 INLINE int2 MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C) (int index) { \
799799 int sg_size = get_sub_group_size(); \
800800 int wi_id = get_sub_group_local_id(); \
@@ -807,32 +807,67 @@ The storage in memory will be: 0 0 1 1 2 2 ... 7 7
807807 return result; \
808808 }
809809
810- // ------ PVC -------
811- // layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF
812- //int8
813- DEFINE_GET_COORD (PackedA , _SG16 , 8 , 16 , 8 , 32 , 1 )
814- DEFINE_GET_COORD (PackedB , _SG16 , 8 , 32 , 32 , 16 , 4 )
810+ #define DEFINE_GET_COORD (layout , sg , elem_bitwidth , R , C , slices ) \
811+ INLINE int2 MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C) (int index) { \
812+ int sg_size = get_sub_group_size(); \
813+ int wi_id = get_sub_group_local_id(); \
814+ int elems_per_slice = (R * C / sg_size) / slices; \
815+ int slice_cols = (C / slices); \
816+ int sg_cols_per_wi = slice_cols / sg_size; \
817+ int row = (index % elems_per_slice) / sg_cols_per_wi; \
818+ int col = wi_id + ((index % elems_per_slice) % sg_cols_per_wi) * sg_size + (index / elems_per_slice * slice_cols); \
819+ int2 result = (int2)(row, col); \
820+ return result; \
821+ }
815822
816- //bfloat16
817- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 16 , 8 , 16 , 1 )
818- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 16 , 16 , 16 , 1 )
819- DEFINE_GET_COORD (PackedB , _SG16 , 16 , 32 , 16 , 16 , 2 )
823+ // ------ PVC -------
824+ // DEFINE_GET_COORD_ROWPACKED(layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF)
825+ // DEFINE_GET_COORD(layout, sg, elem_bitwidth, R, C, slices)
826+ // int8
827+ DEFINE_GET_COORD_ROWPACKED (PackedA , _SG16 , 8 , 16 , 8 , 32 , 1 )
828+ DEFINE_GET_COORD (PackedB , _SG16 , 8 , 32 , 16 , 1 )
829+
830+ // 16bit A
831+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 1 , 16 , 1 )
832+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 8 , 16 , 1 )
833+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 16 , 16 , 1 )
834+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 1 , 32 , 1 )
835+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 32 , 16 , 1 )
836+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 32 , 32 , 2 )
837+
838+ // 16bit PackedB
839+ DEFINE_GET_COORD (PackedB , _SG16 , 16 , 16 , 16 , 1 )
840+ DEFINE_GET_COORD (PackedB , _SG16 , 16 , 16 , 64 , 4 )
841+ DEFINE_GET_COORD (PackedB , _SG16 , 16 , 32 , 64 , 4 )
842+
843+ // 16bit Row_major B
844+ DEFINE_GET_COORD (PackedB_RowMajor , _SG16 , 16 , 16 , 16 , 1 )
845+ DEFINE_GET_COORD (PackedB_RowMajor , _SG16 , 16 , 16 , 64 , 4 )
846+ DEFINE_GET_COORD (PackedB_RowMajor , _SG16 , 16 , 32 , 64 , 4 )
820847
821848// Accumulator
822- DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 32 , 8 , 16 , 1 )
823- DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 32 , 16 , 16 , 1 )
849+ DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 8 , 16 , 1 )
850+ DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 16 , 16 , 1 )
851+ DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 32 , 64 , 4 )
852+ DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 1 , 64 , 4 )
853+
854+ // Accumulator 16bit
855+ DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 8 , 16 , 1 )
856+ DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 16 , 16 , 1 )
857+ DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 32 , 64 , 2 )
858+ DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 1 , 64 , 4 )
824859
825860// --------- XMX8 ------------
826861//int8
827- DEFINE_GET_COORD (PackedA , , 8 , 32 , 8 , 32 , 1 )
828- DEFINE_GET_COORD (PackedB , , 8 , 32 , 32 , 8 , 4 )
862+ DEFINE_GET_COORD_ROWPACKED (PackedA , , 8 , 32 , 8 , 32 , 1 )
863+ DEFINE_GET_COORD_ROWPACKED (PackedB , , 8 , 32 , 32 , 8 , 4 )
829864
830865//bfloat16
831- DEFINE_GET_COORD (PackedA , , 16 , 32 , 8 , 16 , 1 )
832- DEFINE_GET_COORD (PackedB , , 16 , 32 , 16 , 8 , 2 )
866+ DEFINE_GET_COORD_ROWPACKED (PackedA , , 16 , 32 , 8 , 16 , 1 )
867+ DEFINE_GET_COORD_ROWPACKED (PackedB , , 16 , 32 , 16 , 8 , 2 )
833868
834869// Accumulator
835- DEFINE_GET_COORD (Accumulator , , 32 , 32 , 8 , 8 , 1 )
870+ DEFINE_GET_COORD_ROWPACKED (Accumulator , , 32 , 32 , 8 , 8 , 1 )
836871
837872/* experimental large slice support: */
838873
0 commit comments