@@ -794,7 +794,7 @@ The storage in memory will be: 0 0 1 1 2 2 ... 7 7
794
794
// R - number of rows
795
795
// C - number of columns
796
796
// VF - VNNI Factor
797
- #define DEFINE_GET_COORD (layout , sg , elem_bitwidth , contrib_bitwidth , R , C , VF ) \
797
+ #define DEFINE_GET_COORD_ROWPACKED (layout , sg , elem_bitwidth , contrib_bitwidth , R , C , VF ) \
798
798
INLINE int2 MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C) (int index) { \
799
799
int sg_size = get_sub_group_size(); \
800
800
int wi_id = get_sub_group_local_id(); \
@@ -807,32 +807,67 @@ The storage in memory will be: 0 0 1 1 2 2 ... 7 7
807
807
return result; \
808
808
}
809
809
810
- // ------ PVC -------
811
- // layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF
812
- //int8
813
- DEFINE_GET_COORD (PackedA , _SG16 , 8 , 16 , 8 , 32 , 1 )
814
- DEFINE_GET_COORD (PackedB , _SG16 , 8 , 32 , 32 , 16 , 4 )
810
/*
 * DEFINE_GET_COORD(layout, sg, elem_bitwidth, R, C, slices)
 *
 * Expands to an INLINE function, named by
 * MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C), that maps a
 * work-item-local element `index` to an int2 holding (row, col).
 *
 *   layout, sg, elem_bitwidth - used only to build the generated function name.
 *   R, C                      - number of rows / columns (also part of the name).
 *   slices                    - number of column groups the C columns are split into;
 *                               each group is slice_cols = C / slices columns wide.
 *
 * How the mapping is computed (all arithmetic is integer arithmetic):
 *   - sg_size and wi_id are read at run time from get_sub_group_size() and
 *     get_sub_group_local_id().
 *   - elems_per_slice = (R * C / sg_size) / slices consecutive index values
 *     belong to one slice; index / elems_per_slice selects the slice and adds
 *     slice * slice_cols to the column.
 *   - Inside a slice, sg_cols_per_wi = slice_cols / sg_size consecutive index
 *     values share a row, each one sg_size columns further right; after that
 *     the row advances by one.
 *   - wi_id is the column offset inside each sg_size-wide stride.
 *
 * NOTE(review): sg_cols_per_wi evaluates to 0 when C / slices is smaller than
 * the run-time sub-group size, and it is then used as a divisor. Every
 * instantiation visible next to this macro has C / slices >= 16 and carries
 * the _SG16 tag - presumably the sub-group size is always 16 here; confirm.
 *
 * NOTE: R, C and slices are substituted without extra parentheses, so they
 * are expected to be plain integer literals, as in the instantiations below.
 */
+ #define DEFINE_GET_COORD (layout , sg , elem_bitwidth , R , C , slices ) \
811
+ INLINE int2 MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C) (int index) { \
812
+ int sg_size = get_sub_group_size(); \
813
+ int wi_id = get_sub_group_local_id(); \
814
+ int elems_per_slice = (R * C / sg_size) / slices; \
815
+ int slice_cols = (C / slices); \
816
+ int sg_cols_per_wi = slice_cols / sg_size; \
817
+ int row = (index % elems_per_slice) / sg_cols_per_wi; \
818
+ int col = wi_id + ((index % elems_per_slice) % sg_cols_per_wi) * sg_size + (index / elems_per_slice * slice_cols); \
819
+ int2 result = (int2)(row, col); \
820
+ return result; \
821
+ }
815
822
816
- //bfloat16
817
- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 16 , 8 , 16 , 1 )
818
- DEFINE_GET_COORD (PackedA , _SG16 , 16 , 16 , 16 , 16 , 1 )
819
- DEFINE_GET_COORD (PackedB , _SG16 , 16 , 32 , 16 , 16 , 2 )
823
+ // ------ PVC -------
824
+ // DEFINE_GET_COORD_ROWPACKED(layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF)
825
+ // DEFINE_GET_COORD(layout, sg, elem_bitwidth, R, C, slices)
826
+ // int8
827
+ DEFINE_GET_COORD_ROWPACKED (PackedA , _SG16 , 8 , 16 , 8 , 32 , 1 )
828
+ DEFINE_GET_COORD (PackedB , _SG16 , 8 , 32 , 16 , 1 )
829
+
830
+ // 16bit PackedA
831
+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 1 , 16 , 1 )
832
+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 8 , 16 , 1 )
833
+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 16 , 16 , 1 )
834
+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 1 , 32 , 1 )
835
+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 32 , 16 , 1 )
836
+ DEFINE_GET_COORD (PackedA , _SG16 , 16 , 32 , 32 , 2 )
837
+
838
+ // 16bit PackedB
839
+ DEFINE_GET_COORD (PackedB , _SG16 , 16 , 16 , 16 , 1 )
840
+ DEFINE_GET_COORD (PackedB , _SG16 , 16 , 16 , 64 , 4 )
841
+ DEFINE_GET_COORD (PackedB , _SG16 , 16 , 32 , 64 , 4 )
842
+
843
+ // 16bit PackedB_RowMajor
844
+ DEFINE_GET_COORD (PackedB_RowMajor , _SG16 , 16 , 16 , 16 , 1 )
845
+ DEFINE_GET_COORD (PackedB_RowMajor , _SG16 , 16 , 16 , 64 , 4 )
846
+ DEFINE_GET_COORD (PackedB_RowMajor , _SG16 , 16 , 32 , 64 , 4 )
820
847
821
848
// Accumulator
822
- DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 32 , 8 , 16 , 1 )
823
- DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 32 , 16 , 16 , 1 )
849
+ DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 8 , 16 , 1 )
850
+ DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 16 , 16 , 1 )
851
+ DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 32 , 64 , 4 )
852
+ DEFINE_GET_COORD (Accumulator , _SG16 , 32 , 1 , 64 , 4 )
853
+
854
+ // Accumulator 16bit
// Arguments: (layout, sg, elem_bitwidth, R, C, slices).
// NOTE(review): the 32x64 entry in this group passes slices = 2, while the
// 32-bit Accumulator 32x64 entry and the 16-bit 1x64 entry both pass
// slices = 4. With slices = 2 each slice is 32 columns wide, so consecutive
// indices alternate between two 16-column strides before the row advances.
// Confirm this difference is intentional.
855
+ DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 8 , 16 , 1 )
856
+ DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 16 , 16 , 1 )
857
+ DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 32 , 64 , 2 )
858
+ DEFINE_GET_COORD (Accumulator , _SG16 , 16 , 1 , 64 , 4 )
824
859
825
860
// --------- XMX8 ------------
826
861
//int8
827
- DEFINE_GET_COORD (PackedA , , 8 , 32 , 8 , 32 , 1 )
828
- DEFINE_GET_COORD (PackedB , , 8 , 32 , 32 , 8 , 4 )
862
+ DEFINE_GET_COORD_ROWPACKED (PackedA , , 8 , 32 , 8 , 32 , 1 )
863
+ DEFINE_GET_COORD_ROWPACKED (PackedB , , 8 , 32 , 32 , 8 , 4 )
829
864
830
865
//bfloat16
831
- DEFINE_GET_COORD (PackedA , , 16 , 32 , 8 , 16 , 1 )
832
- DEFINE_GET_COORD (PackedB , , 16 , 32 , 16 , 8 , 2 )
866
+ DEFINE_GET_COORD_ROWPACKED (PackedA , , 16 , 32 , 8 , 16 , 1 )
867
+ DEFINE_GET_COORD_ROWPACKED (PackedB , , 16 , 32 , 16 , 8 , 2 )
833
868
834
869
// Accumulator
835
- DEFINE_GET_COORD (Accumulator , , 32 , 32 , 8 , 8 , 1 )
870
+ DEFINE_GET_COORD_ROWPACKED (Accumulator , , 32 , 32 , 8 , 8 , 1 )
836
871
837
872
/* experimental large slice support: */
838
873
0 commit comments