@@ -794,7 +794,7 @@ The storage in memory will be: 0 0 1 1 2 2 ... 7 7
794794// R - number of rows
795795// C - number of columns
796796// VF - VNNI Factor
797- #define DEFINE_GET_COORD_ROWPACKED (layout , sg , elem_bitwidth , contrib_bitwidth , R , C , VF ) \
797+ #define DEFINE_GET_COORD (layout , sg , elem_bitwidth , contrib_bitwidth , R , C , VF ) \
798798 INLINE int2 MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C) (int index) { \
799799 int sg_size = get_sub_group_size(); \
800800 int wi_id = get_sub_group_local_id(); \
@@ -807,67 +807,32 @@ The storage in memory will be: 0 0 1 1 2 2 ... 7 7
807807 return result; \
808808 }
809809
810- #define DEFINE_GET_COORD (layout , sg , elem_bitwidth , R , C , slices ) /* Emits one get-coord function for an R x C matrix whose C columns are split into `slices` equal column groups. */ \
811- INLINE int2 MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C) (int index) { /* Maps `index` to a (row, col) pair for the calling work-item; the function name is produced by MANGLE_GETCOORD_NAME, which is defined outside this view. */ \
812- int sg_size = get_sub_group_size(); \
813- int wi_id = get_sub_group_local_id(); \
814- int elems_per_slice = (R * C / sg_size) / slices; /* R x C elements divided by the sub-group size, then by the slice count */ \
815- int slice_cols = (C / slices); /* number of columns in one slice */ \
816- int sg_cols_per_wi = slice_cols / sg_size; /* NOTE(review): integer division gives 0 when slice_cols < sg_size, which would make the row/col computations below divide by zero - confirm no instantiation hits this */ \
817- int row = (index % elems_per_slice) / sg_cols_per_wi; /* row derived from the offset of `index` inside its slice */ \
818- int col = wi_id + ((index % elems_per_slice) % sg_cols_per_wi) * sg_size + (index / elems_per_slice * slice_cols); /* sub-group local id + sg_size-strided step inside the slice + first column of the slice (slice number * slice_cols) */ \
819- int2 result = (int2)(row, col); /* packed as (row, col) */ \
820- return result; \
821- }
822-
823810// ------ PVC -------
824- // DEFINE_GET_COORD_ROWPACKED(layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF)
825- // DEFINE_GET_COORD(layout, sg, elem_bitwidth, R, C, slices)
826- // int8
827- DEFINE_GET_COORD_ROWPACKED (PackedA , _SG16 , 8 , 16 , 8 , 32 , 1 )
828- DEFINE_GET_COORD (PackedB , _SG16 , 8 , 32 , 16 , 1 )
829-
830- // 16bit A
831- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 1 , 16 , 1 )
832- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 8 , 16 , 1 )
833- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 16 , 16 , 1 )
834- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 1 , 32 , 1 )
835- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 32 , 16 , 1 )
836- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 32 , 32 , 2 )
837-
838- // 16bit PackedB
839- DEFINE_GET_COORD (PackedB , _SG16 , 16 , 16 , 16 , 1 )
840- DEFINE_GET_COORD (PackedB , _SG16 , 16 , 16 , 64 , 4 )
841- DEFINE_GET_COORD (PackedB , _SG16 , 16 , 32 , 64 , 4 )
842-
843- // 16bit Row_major B
844- DEFINE_GET_COORD (PackedB_RowMajor , _SG16 , 16 , 16 , 16 , 1 )
845- DEFINE_GET_COORD (PackedB_RowMajor , _SG16 , 16 , 16 , 64 , 4 )
846- DEFINE_GET_COORD (PackedB_RowMajor , _SG16 , 16 , 32 , 64 , 4 )
811+ // Argument order: layout, sg, elem_bitwidth, contrib_bitwidth, R (number of rows), C (number of columns), VF (VNNI factor)
812+ // int8 (elem_bitwidth = 8)
813+ DEFINE_GET_COORD (PackedA , _SG16 , 8 , 16 , 8 , 32 , 1 ) // R=8, C=32, VF=1
814+ DEFINE_GET_COORD (PackedB , _SG16 , 8 , 32 , 32 , 16 , 4 ) // R=32, C=16, VF=4
847815
848- // Accumulator
849- DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 8 , 16 , 1 )
850- DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 16 , 16 , 1 )
851- DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 32 , 64 , 4 )
852- DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 1 , 64 , 4 )
816+ // bfloat16 (elem_bitwidth = 16)
817+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 16 , 8 , 16 , 1 ) // R=8, C=16, VF=1
818+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 16 , 16 , 16 , 1 ) // R=16, C=16, VF=1
819+ DEFINE_GET_COORD (PackedB , _SG16 , 16 , 32 , 16 , 16 , 2 ) // R=16, C=16, VF=2
853820
854- // Accumulator 16bit
855- DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 8 , 16 , 1 )
856- DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 16 , 16 , 1 )
857- DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 32 , 64 , 2 )
858- DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 1 , 64 , 4 )
821+ // Accumulator (elem_bitwidth = 32, contrib_bitwidth = 32)
822+ DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 32 , 8 , 16 , 1 ) // R=8, C=16, VF=1
823+ DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 32 , 16 , 16 , 1 ) // R=16, C=16, VF=1
859824
860825// --------- XMX8 ------------
861826// int8 (elem_bitwidth = 8); argument order: layout, sg (left empty in this section), elem_bitwidth, contrib_bitwidth, R, C, VF
862- DEFINE_GET_COORD_ROWPACKED (PackedA , , 8 , 32 , 8 , 32 , 1 )
863- DEFINE_GET_COORD_ROWPACKED (PackedB , , 8 , 32 , 32 , 8 , 4 )
827+ DEFINE_GET_COORD (PackedA , , 8 , 32 , 8 , 32 , 1 ) // R=8, C=32, VF=1
828+ DEFINE_GET_COORD (PackedB , , 8 , 32 , 32 , 8 , 4 ) // R=32, C=8, VF=4
864829
865830// bfloat16 (elem_bitwidth = 16)
866- DEFINE_GET_COORD_ROWPACKED (PackedA , , 16 , 32 , 8 , 16 , 1 )
867- DEFINE_GET_COORD_ROWPACKED (PackedB , , 16 , 32 , 16 , 8 , 2 )
831+ DEFINE_GET_COORD (PackedA , , 16 , 32 , 8 , 16 , 1 ) // R=8, C=16, VF=1
832+ DEFINE_GET_COORD (PackedB , , 16 , 32 , 16 , 8 , 2 ) // R=16, C=8, VF=2
868833
869834// Accumulator (elem_bitwidth = 32, contrib_bitwidth = 32)
870- DEFINE_GET_COORD_ROWPACKED (Accumulator , , 32 , 32 , 8 , 8 , 1 )
835+ DEFINE_GET_COORD (Accumulator , , 32 , 32 , 8 , 8 , 1 ) // R=8, C=8, VF=1
871836
872837/* experimental large slice support: */
873838
0 commit comments