
Commit ab814f2

YuriPlyakhin authored and igcbot committed
SYCL Joint Matrix Col Major int8 load A 8x32 B 32x16
Implement the SYCL Joint Matrix column-major int8 load for A (8x32) and B (32x16) for sub-group size 32.
1 parent 472af0a commit ab814f2
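
For context, the functionality this commit wires up is exercised from SYCL roughly as below. This is a minimal sketch assuming the sycl_ext_oneapi_matrix extension syntax; the function name mad_col_major, the pointer setup, and the leading dimensions are illustrative and not part of this commit.

    #include <sycl/sycl.hpp>
    using namespace sycl;
    using namespace sycl::ext::oneapi::experimental::matrix;

    // Sketch of a kernel body run with sub-group size 32.
    void mad_col_major(sub_group sg, int8_t *A, int8_t *B, int32_t *C,
                       size_t lda, size_t ldb, size_t ldc) {
      // A is 8x32 int8 and B is 32x16 int8, both column-major in memory.
      joint_matrix<sub_group, int8_t, use::a, 8, 32, layout::col_major> tA;
      joint_matrix<sub_group, int8_t, use::b, 32, 16, layout::col_major> tB;
      joint_matrix<sub_group, int32_t, use::accumulator, 8, 16> tC;

      auto pA = address_space_cast<access::address_space::global_space,
                                   access::decorated::no>(A);
      auto pB = address_space_cast<access::address_space::global_space,
                                   access::decorated::no>(B);
      auto pC = address_space_cast<access::address_space::global_space,
                                   access::decorated::no>(C);

      joint_matrix_load(sg, tA, pA, lda);    // layout comes from the type
      joint_matrix_load(sg, tB, pB, ldb);
      joint_matrix_load(sg, tC, pC, ldc, layout::row_major);
      joint_matrix_mad(sg, tC, tA, tB, tC);  // C += A * B
      joint_matrix_store(sg, tC, pC, ldc, layout::row_major);
    }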

File tree

1 file changed (+19, -8)


IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl

@@ -690,12 +690,17 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
         return; \
     }
 
-#define DEFINE_LOAD_SCALAR_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, order, WI_rows) \
-    contrib_type *ptr = (contrib_type *)mem; \
+#define DEFINE_LOAD_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, order, WI_rows) \
     int slid = get_sub_group_local_id(); \
     int pack_factor = sizeof (contrib_type) / sizeof (element_type); \
-    long packed_stride = stride / pack_factor; \
     int sg_cols = K / pack_factor; \
+    if (_##layout == _PackedA_ColumnMajor && elem_bitwidth == 8 && contrib_bitwidth == 16) { \
+        for (int i = 0; i < M; i++) \
+            dst[i] = mem[((slid * pack_factor) % sg_size) * stride + (slid / sg_cols) + (i & ~1) + (i % pack_factor) * stride]; \
+        return; \
+    } \
+    contrib_type *ptr = (contrib_type *)mem; \
+    long packed_stride = stride / pack_factor; \
     __private contrib_type *wi_contrib = (__private contrib_type *)dst; \
     if (order == _VNNI_TX) { \
         GATHER_LOAD_PACK_##elem_bitwidth(element_type, M, wi_contrib, slid, stride) \
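
The new branch above handles the column-major int8 case (8-bit elements with a 16-bit contrib type) with per-element scalar loads instead of the packed path. Its index arithmetic can be sanity-checked on the host; the standalone sketch below is hypothetical and simply mirrors the macro body for the A 8x32 int8 case at sub-group size 32 (pack_factor = sizeof(short) / sizeof(char) = 2), with an illustrative stride.

    #include <cstdio>

    int main() {
      const int M = 8, K = 32;              // A tile: 8 rows x 32 int8 columns
      const int pack_factor = 2;            // sizeof(short) / sizeof(char)
      const int sg_size = 32;               // sub-group size
      const int sg_cols = K / pack_factor;  // 16
      const long stride = 8;                // illustrative column stride in elements

      for (int slid = 0; slid < sg_size; ++slid)  // sub-group local id
        for (int i = 0; i < M; ++i) {             // dst element owned by this WI
          long idx = ((slid * pack_factor) % sg_size) * stride
                   + (slid / sg_cols) + (i & ~1) + (i % pack_factor) * stride;
          printf("slid=%2d i=%d -> mem[%ld]\n", slid, i, idx);
        }
      return 0;
    }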
@@ -727,20 +732,20 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
     } else { \
         DEFINE_LOAD_VECTORS_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, AS_LOCAL) \
     } \
-    DEFINE_LOAD_SCALAR_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
+    DEFINE_LOAD_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
 }
 #define DEFINE_LOAD_IMPL_AS_LOCAL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
 INLINE void MANGLE_LOAD_NAME_AS_LOCAL(layout, sg, elem_bitwidth, shape, WI_rows) (__private char *dst, char *mem, long stride, int cacheOpt) { \
     int sg_size = get_sub_group_size(); \
     DEFINE_LOAD_VECTORS_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, AS_LOCAL) \
-    DEFINE_LOAD_SCALAR_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
+    DEFINE_LOAD_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
 }
 #define DEFINE_LOAD_IMPL_AS_GLOBAL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
 INLINE void MANGLE_LOAD_NAME_AS_GLOBAL(layout, sg, elem_bitwidth, shape, WI_rows) (__private char *dst, char *mem, long stride, int cacheOpt) { \
     int sg_size = get_sub_group_size(); \
     DEFINE_LOAD_BLOCK2D_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows) \
     DEFINE_LOAD_VECTORS_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, WI_rows, AS_GLOBAL) \
-    DEFINE_LOAD_SCALAR_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
+    DEFINE_LOAD_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
 }
 
 #define DEFINE_LOAD_CHECKED_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, M, K, shape, order, WI_rows) \
@@ -836,13 +841,18 @@ DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 16, ROW_MAJOR,
 
 DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 32, ROW_MAJOR, _us, 2)
 
-// This matrix is represented as <8xi16> in LLVM IR, but to be able to read it with 2d block load we have to use i32
+// PackedA Column Major -----
+//
+// These matrices are represented as <8xi16> in LLVM IR, but to be able to read it with 2d block load we have to use i32
 // so, contrib type is `int` here and we read <4xi32> from memory, but then we use it as <8xi16>
 DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, short, int, 8, 16, COL_MAJOR, , 8)
 DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, char, int, 8, 32, COL_MAJOR, , 8)
 
-// PackedA load i16 SG16 subgroup size 32
+// Sub-group size 32
 DEFINE_LOAD(PackedA_ColumnMajor, _SG16, short, short, 8, 16, COL_MAJOR, , 4)
+DEFINE_LOAD(PackedA_ColumnMajor, _SG16, char, short, 8, 32, COL_MAJOR, , 4)
+//
+// end of PackedA Column Major -----
 
 /* PackedA load i16 SG16 for sub group size = 32*/
 DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 8, 16, ROW_MAJOR, _us, 4)
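
The contrib-type comment in the hunk above boils down to type punning: a 2d block load fetches 32-bit lanes, so four i32 loads deliver the same bytes as eight i16 elements. A hypothetical host-side illustration:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
      // Four i32 words, as a 2d block load would fetch them...
      int32_t packed[4] = {0x00020001, 0x00040003, 0x00060005, 0x00080007};
      int16_t elems[8];
      // ...carry the same bytes as eight i16 elements (<8 x i16> in LLVM IR).
      std::memcpy(elems, packed, sizeof(packed));
      for (int16_t e : elems) printf("%d ", e);  // 1 2 3 4 5 6 7 8 on little-endian
      printf("\n");
      return 0;
    }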
@@ -910,6 +920,7 @@ DEFINE_LOAD_AND_CHECKED(PackedB_PackedB, _SG16, char, int, 8, 64, ROW_MAJOR,
 DEFINE_LOAD_AND_CHECKED(PackedB_RowMajor, _SG16, char, int, 8, 64, VNNI_TX, , 8)
 
 /* PackedB load i8 SG16 for sub group size 32*/
+DEFINE_LOAD(PackedB_ColumnMajor, _SG16, char, int, 8, 64, COL_MAJOR, , 4)
 DEFINE_LOAD(PackedB_PackedB, _SG16, char, int, 8, 64, ROW_MAJOR, , 4)
 
 /* B load tf32 SG16 */
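
For reference, the PackedB variants read as an 8x64 char shape with an int contrib type because the hardware consumes int8 B tiles in VNNI layout: four consecutive rows of a column share one 32-bit lane, so a 32x16 int8 B tile becomes 8x16 dwords. The packer below is a hypothetical host-side sketch of that layout, not the IGC code path:

    #include <cstdint>
    #include <cstring>

    // VNNI-pack a 32x16 int8 B tile: element (r, c) lands in
    // dword (r / 4, c), at byte r % 4 within that dword.
    void vnni_pack_b(const int8_t src[32][16], uint32_t dst[8][16]) {
      std::memset(dst, 0, 8 * 16 * sizeof(uint32_t));
      for (int r = 0; r < 32; ++r)
        for (int c = 0; c < 16; ++c)
          dst[r / 4][c] |= uint32_t(uint8_t(src[r][c])) << (8 * (r % 4));
    }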
