Skip to content

Commit ed88b63

Browse files
ggojskaigcbot
authored and committed
Add get_coord formula for 16-bit and 32-bit datatypes.
Platforms: All Keywords: Feature Related-to: GSD-11139 Resolves:
1 parent d85d0be commit ed88b63

File tree

2 files changed

+58
-18
lines changed

2 files changed

+58
-18
lines changed

IGC/BiFModule/Languages/OpenCL/PreRelease/Matrix/IBiF_matrix.cl

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,7 @@ The storage in memory will be: 0 0 1 1 2 2 ... 7 7
794794
// R - number of rows
795795
// C - number of columns
796796
// VF - VNNI Factor
797-
#define DEFINE_GET_COORD(layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF) \
797+
#define DEFINE_GET_COORD_ROWPACKED(layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF) \
798798
INLINE int2 MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C) (int index) { \
799799
int sg_size = get_sub_group_size(); \
800800
int wi_id = get_sub_group_local_id(); \
@@ -807,32 +807,67 @@ The storage in memory will be: 0 0 1 1 2 2 ... 7 7
807807
return result; \
808808
}
809809

810-
// ------ PVC -------
811-
// layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF
812-
//int8
813-
DEFINE_GET_COORD(PackedA, _SG16, 8, 16, 8, 32, 1)
814-
DEFINE_GET_COORD(PackedB, _SG16, 8, 32, 32, 16, 4)
810+
// Defines a get_coord function (named via MANGLE_GETCOORD_NAME) that maps a
// work-item's element index to a (row, col) pair, returned as an int2.
// layout, sg, elem_bitwidth - forwarded to MANGLE_GETCOORD_NAME only
// R - number of rows
// C - number of columns
// slices - number of equal-width column slices the C columns are split into;
//          each slice is (C / slices) columns wide
// Each work-item holds (R * C / sg_size) / slices elements per slice. Within a
// slice, consecutive indices step through columns that are sg_size apart
// (starting at the work-item's sub-group local id) before moving to the next
// row; (index / elems_per_slice) selects the slice and shifts the column by
// that many slice widths.
// NOTE(review): (C / slices) / sg_size must be non-zero, otherwise the row
// computation divides by zero - presumably guaranteed by the instantiations
// and the sub-group size they run with; confirm.
#define DEFINE_GET_COORD(layout, sg, elem_bitwidth, R, C, slices) \
811+
INLINE int2 MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C) (int index) { \
812+
int sg_size = get_sub_group_size(); \
813+
int wi_id = get_sub_group_local_id(); \
814+
int elems_per_slice = (R * C / sg_size) / slices; \
815+
int slice_cols = (C / slices); \
816+
int sg_cols_per_wi = slice_cols / sg_size; \
817+
int row = (index % elems_per_slice) / sg_cols_per_wi; \
818+
int col = wi_id + ((index % elems_per_slice) % sg_cols_per_wi) * sg_size + (index / elems_per_slice * slice_cols); \
819+
int2 result = (int2)(row, col); \
820+
return result; \
821+
}
815822

816-
//bfloat16
817-
DEFINE_GET_COORD(PackedA, _SG16, 16, 16, 8, 16, 1)
818-
DEFINE_GET_COORD(PackedA, _SG16, 16, 16, 16, 16, 1)
819-
DEFINE_GET_COORD(PackedB, _SG16, 16, 32, 16, 16, 2)
823+
// ------ PVC -------
824+
// DEFINE_GET_COORD_ROWPACKED(layout, sg, elem_bitwidth, contrib_bitwidth, R, C, VF)
825+
// DEFINE_GET_COORD(layout, sg, elem_bitwidth, R, C, slices)
826+
// int8
827+
DEFINE_GET_COORD_ROWPACKED(PackedA, _SG16, 8, 16, 8, 32, 1)
828+
DEFINE_GET_COORD(PackedB, _SG16, 8, 32, 16, 1)
829+
830+
// 16bit A
831+
DEFINE_GET_COORD(PackedA, _SG16, 16, 1, 16, 1)
832+
DEFINE_GET_COORD(PackedA, _SG16, 16, 8, 16, 1)
833+
DEFINE_GET_COORD(PackedA, _SG16, 16, 16, 16, 1)
834+
DEFINE_GET_COORD(PackedA, _SG16, 16, 1, 32, 1)
835+
DEFINE_GET_COORD(PackedA, _SG16, 16, 32, 16, 1)
836+
DEFINE_GET_COORD(PackedA, _SG16, 16, 32, 32, 2)
837+
838+
// 16bit PackedB
839+
DEFINE_GET_COORD(PackedB, _SG16, 16, 16, 16, 1)
840+
DEFINE_GET_COORD(PackedB, _SG16, 16, 16, 64, 4)
841+
DEFINE_GET_COORD(PackedB, _SG16, 16, 32, 64, 4)
842+
843+
// 16bit Row_major B
844+
DEFINE_GET_COORD(PackedB_RowMajor, _SG16, 16, 16, 16, 1)
845+
DEFINE_GET_COORD(PackedB_RowMajor, _SG16, 16, 16, 64, 4)
846+
DEFINE_GET_COORD(PackedB_RowMajor, _SG16, 16, 32, 64, 4)
820847

821848
// Accumulator
822-
DEFINE_GET_COORD(Accumulator, _SG16, 32, 32, 8, 16, 1)
823-
DEFINE_GET_COORD(Accumulator, _SG16, 32, 32, 16, 16, 1)
849+
DEFINE_GET_COORD(Accumulator, _SG16, 32, 8, 16, 1)
850+
DEFINE_GET_COORD(Accumulator, _SG16, 32, 16, 16, 1)
851+
DEFINE_GET_COORD(Accumulator, _SG16, 32, 32, 64, 4)
852+
DEFINE_GET_COORD(Accumulator, _SG16, 32, 1, 64, 4)
853+
854+
// Accumulator 16bit
855+
DEFINE_GET_COORD(Accumulator, _SG16, 16, 8, 16, 1)
856+
DEFINE_GET_COORD(Accumulator, _SG16, 16, 16, 16, 1)
857+
DEFINE_GET_COORD(Accumulator, _SG16, 16, 32, 64, 2)
858+
DEFINE_GET_COORD(Accumulator, _SG16, 16, 1, 64, 4)
824859

825860
// --------- XMX8 ------------
826861
//int8
827-
DEFINE_GET_COORD(PackedA, , 8, 32, 8, 32, 1)
828-
DEFINE_GET_COORD(PackedB, , 8, 32, 32, 8, 4)
862+
DEFINE_GET_COORD_ROWPACKED(PackedA, , 8, 32, 8, 32, 1)
863+
DEFINE_GET_COORD_ROWPACKED(PackedB, , 8, 32, 32, 8, 4)
829864

830865
//bfloat16
831-
DEFINE_GET_COORD(PackedA, , 16, 32, 8, 16, 1)
832-
DEFINE_GET_COORD(PackedB, , 16, 32, 16, 8, 2)
866+
DEFINE_GET_COORD_ROWPACKED(PackedA, , 16, 32, 8, 16, 1)
867+
DEFINE_GET_COORD_ROWPACKED(PackedB, , 16, 32, 16, 8, 2)
833868

834869
// Accumulator
835-
DEFINE_GET_COORD(Accumulator, , 32, 32, 8, 8, 1)
870+
DEFINE_GET_COORD_ROWPACKED(Accumulator, , 32, 32, 8, 8, 1)
836871

837872
/* experimental large slice support: */
838873

IGC/BiFModule/Languages/OpenCL/PreRelease/Matrix/IBiF_matrix_generator.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1058,7 +1058,12 @@ ImplementLargeLoadBase(MatrixSpec spec, AddrSpace addr, int numLoads, bool isChe
10581058
s += "__private char *dst1 = dst + WiRowsPerLoad * ContribByteWidth;\n";
10591059
// Prepare mem (source) pointers
10601060
s += "char *mem0 = mem;\n";
1061-
s += "char *mem1 = mem + 16 * ElemByteWidth;\n";
1061+
// A 32-bit Accumulator 32x64 is sliced into 4 slices of 32x16
1062+
// A 16-bit Accumulator 32x64 is sliced into 2 slices of 32x32
1063+
if (spec.BitWidth == 16 && spec.Layout == Layout_Accumulator_RowMajor)
1064+
s += "char *mem1 = mem + 32 * ElemByteWidth;\n";
1065+
else
1066+
s += "char *mem1 = mem + 16 * ElemByteWidth;\n";
10621067
// Call load sub-functions
10631068
s += "LoadFunc(dst0, mem0, stride, cacheOpt);\n";
10641069
s += "LoadFunc(dst1, mem1, stride, cacheOpt);\n";

0 commit comments

Comments
 (0)