Skip to content

Commit 9239d02

Browse files
YuriPlyakhin authored and igcbot committed
SYCL Joint Matrix bfloat16 B row major load SG32
Enables SYCL Joint Matrix bfloat16 B row major load for 16x16 for sub-group size 32
1 parent 8e40147 commit 9239d02

File tree

2 files changed

+21 -7 lines changed

IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -421,7 +421,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
421421

422422
#define DEFINE_BLOCK2D_RW_NAME(rw, tx, contrib_bitwidth, WI_rows, tile_height, tile_width) __builtin_IB_subgroup_block_##rw##_flat_cacheopts##tx##_u##contrib_bitwidth##_wi##WI_rows##_m##tile_height##k##tile_width##v1
423423
#define DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, WI_rows, tile_height, tile_width) __builtin_IB_subgroup_block_read_flat_cacheopts_transpose_u##contrib_bitwidth##_wi##WI_rows##_m##tile_height##_k##tile_width
424-
#define DEFINE_BLOCK2D_VNNI_NAME(contrib_bitwidth, tile_height) __builtin_IB_subgroup_block_read_flat_cacheopts_transform_u##contrib_bitwidth##_k##tile_height // tile_width = sub group size (16)
424+
#define DEFINE_BLOCK2D_VNNI_NAME(contrib_bitwidth, WI_rows, tile_height, tile_width) __builtin_IB_subgroup_block_read_flat_cacheopts_transform_u##contrib_bitwidth##_wi##WI_rows##_##k##tile_height##n##tile_width
425425

426426
/* For platforms without SG16 JointMatrix support block2d is not available. The
427427
* implementation remains empty, will fallthrough to vector implementation. */
@@ -539,8 +539,8 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
539539
int height = orig_M - 1; /* row count */ \
540540
long x = (offset - baseoffset) / (sizeof (element_type)); /* in elements */ \
541541
int2 coords = (int2)(x, 0); \
542-
OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
543-
OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
542+
OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, WI_rows, orig_M, orig_K)(long, int, int, int, int2, int); \
543+
OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, WI_rows, orig_M, orig_K)(baseoffset, width, height, pitch, coords, cacheOpt); \
544544
*(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
545545
return; \
546546
}
@@ -591,8 +591,8 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
591591
int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
592592
int height_size = height - 1; \
593593
int2 coords = (int2)(x, y); \
594-
OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
595-
OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
594+
OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, WI_rows, orig_M, orig_K)(long, int, int, int, int2, int); \
595+
OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, WI_rows, orig_M, orig_K)(offset, width_size, height_size, pitch, coords, cacheOpt); \
596596
*(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
597597
return;
598598

@@ -902,6 +902,7 @@ DEFINE_LOAD_AND_CHECKED(PackedB_RowMajor, _SG16, short, int, 16, 32, VNNI_TX,
902902
/* PackedB load i16 SG16 for sub group size = 32*/
903903
DEFINE_LOAD(PackedB_ColumnMajor, _SG16, short, int, 8, 32, COL_MAJOR, , 4)
904904
DEFINE_LOAD(PackedB_PackedB, _SG16, short, int, 8, 32, ROW_MAJOR, , 4)
905+
DEFINE_LOAD(PackedB_RowMajor, _SG16, short, int, 8, 32, VNNI_TX, , 4)
905906

906907
/* PackedB load i8 SG16*/
907908
DEFINE_LOAD_AND_CHECKED(PackedB_ColumnMajor, _SG16, char, int, 8, 64, COL_MAJOR, , 8)

IGC/Compiler/Optimizer/OpenCLPasses/LSCFuncs/LSCFuncsResolution.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,12 +1032,20 @@ Instruction* LSCFuncsResolution::CreateSubGroup2DBlockOperation(llvm::CallInst&
10321032
else if (isVnniTransform && !isTranspose)
10331033
{
10341034
numBlocksV = 1;
1035+
tileWidth = subGrpSize;
10351036

10361037
if (elemSize == 8)
10371038
{
10381039
bool is32Height = funcName.consume_front("_k32");
10391040
IGC_ASSERT_MESSAGE(is32Height, "Only k32 is supported for 8 bit element size, at the moment.");
10401041

1042+
// If sub-group size is 32, we still may want to use width = 16
1043+
// __builtin_IB_subgroup_block_read_flat_cacheopts_transform_u8_wi8_k32n16
1044+
if (funcName.consume_front("n16"))
1045+
{
1046+
tileWidth = 16;
1047+
}
1048+
10411049
// __builtin_IB_subgroup_block_read_flat_transform_u8_k32v2
10421050
if (funcName.consume_front("v2"))
10431051
{
@@ -1070,15 +1078,20 @@ Instruction* LSCFuncsResolution::CreateSubGroup2DBlockOperation(llvm::CallInst&
10701078
return nullptr;
10711079
}
10721080

1081+
// If sub-group size is 32, we still may want to use width = 16
1082+
// __builtin_IB_subgroup_block_read_flat_transform_u16_k16n16
1083+
if (funcName.consume_front("n16"))
1084+
{
1085+
tileWidth = 16;
1086+
}
1087+
10731088
// __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2
10741089
// __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2
10751090
if (funcName.consume_front("v2"))
10761091
{
10771092
numBlocksV = 2;
10781093
}
10791094
}
1080-
1081-
tileWidth = subGrpSize;
10821095
}
10831096
else
10841097
{

0 commit comments

Comments
 (0)