Skip to content

Commit ef066c3

Browse files
YuriPlyakhin and igcbot
authored and committed
SYCL Joint Matrix add transpose A and B (8x16x16 for PVC) load
Add support for transpose A 8x16 bfloat16 and B 16x16 bfloat16
1 parent 23d5ca9 commit ef066c3

File tree

3 files changed

+99
-48
lines changed

3 files changed

+99
-48
lines changed

IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl

Lines changed: 66 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ SPDX-License-Identifier: MIT
4646
#define _PackedB_PackedB 3
4747
#define _Accumulator_RowMajor 4
4848
#define _Accumulator_ColumnMajor 5
49+
#define _PackedA_ColumnMajor 6
4950

5051
#define ATTRIBUTE_AS_GENERIC __global /* the branch using this will be dead,
5152
however we still need a valid address
@@ -179,9 +180,13 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
179180
#define MATH_8_DIV_4 2
180181
#define MATH_8_DIV_2 4
181182
#define MATH_8_DIV_1 8
183+
#define MATH_7_DIV_1 7
184+
#define MATH_6_DIV_1 6
185+
#define MATH_5_DIV_1 5
182186
#define MATH_4_DIV_4 1
183187
#define MATH_4_DIV_2 2
184188
#define MATH_4_DIV_1 4
189+
#define MATH_3_DIV_1 3
185190
#define MATH_2_DIV_2 1
186191
#define MATH_2_DIV_1 2
187192
#define MATH_1_DIV_1 1
@@ -225,6 +230,7 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
225230
#define SHAPE_CONCAT_VNNI(M, K, vnni_factor) SHAPE_CONCAT_VNNI__(MATH_MUL(M, vnni_factor), MATH_DIV(K, vnni_factor))
226231

227232
#define SHAPE_PackedA_RowMajor( M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT(M, K)
233+
#define SHAPE_PackedA_ColumnMajor( M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT(M, K)
228234
#define SHAPE_PackedB_RowMajor( M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
229235
#define SHAPE_PackedB_ColumnMajor( M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
230236
#define SHAPE_PackedB_PackedB( M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
@@ -407,13 +413,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
407413

408414
/* For platforms without SG16 JointMatrix support block2d is not available. The
409415
* implementation remains empty, will fallthrough to vector implementation. */
410-
#define IMPLEMENT_BLOCK2D_LOAD_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
416+
#define IMPLEMENT_BLOCK2D_LOAD_ROW_MAJOR_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
411417
/* not supported, fallthrough */
412-
#define IMPLEMENT_BLOCK2D_LOAD_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
418+
#define IMPLEMENT_BLOCK2D_LOAD_COL_MAJOR_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
413419
/* not supported, fallthrough */
414-
#define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
420+
#define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
415421
/* not supported, fallthrough */
416-
#define IMPLEMENT_BLOCK2D_STORE(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
422+
#define IMPLEMENT_BLOCK2D_STORE(p1, p2, p3, p4, p5, p6, p7, p8) \
417423
/* not supported, fallthrough */
418424

419425
// contrib_K - calculated in BLOCK2D loads; contrib_K = K/(contrib_bitwidth/elem_bitwidth);
@@ -422,7 +428,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
422428

423429
#define MAX_ROW_BYTES_2D_BLOCK_LOAD 64 // maximum per row size in bytes supported by 2D block load
424430

425-
#define IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
431+
// M and K are the original tile dimensions for A and C.
432+
// M and K are the VNNI-ed dimensions for matrix B (hence the are also aliased as M_VNNI and K_VNNI in some places)
433+
// Hence, we also have orig_M, which can be used as non-VNNI'ed M dimension of matrix B
434+
// When smaller elements are packed into bigger types, contrib_M and contrib_K represent M and K dimensions of the packed matrix
435+
// Not all values (orig_M, contrib_M, contrib_K) are valid for all different layouts and macros, so before using any of it, please, check,
436+
// if it's valid for the current configuration
437+
#define IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
426438
if (contrib_K*sizeof(contrib_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
427439
long offset = as_long(mem); \
428440
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
@@ -437,7 +449,8 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
437449
return; \
438450
}
439451

440-
#define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
452+
// 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose
453+
#define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
441454
if (M*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
442455
long offset = as_long(mem); \
443456
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
@@ -446,29 +459,45 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
446459
int height = 16 - 1; /* taken from SG16 */ \
447460
long x = (offset - baseoffset) / (sizeof (contrib_type)); /* in elements */ \
448461
int2 coords = (int2)(x, 0); \
449-
/* 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose */ \
450-
OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, M)(long, int, int, int, int2, int); \
451-
OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
452-
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
453-
return; \
462+
\
463+
if (elem_bitwidth == 32) { \
464+
OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(long, int, int, int, int2, int); \
465+
OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
466+
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
467+
return; \
468+
} \
469+
\
470+
if (elem_bitwidth == 16 && _##layout == _PackedA_ColumnMajor) { \
471+
OUT_VEC##contrib_M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, contrib_M)(long, int, int, int, int2, int); \
472+
OUT_VEC##contrib_M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
473+
*(__private OUT_VEC##WI_rows(u##element_type) *)dst = *(__private OUT_VEC##WI_rows(u##element_type) *)&res; \
474+
return; \
475+
} \
476+
\
477+
if (elem_bitwidth == 16 && _##layout == _PackedB_ColumnMajor) { \
478+
OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(long, int, int, int, int2, int); \
479+
OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
480+
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
481+
return; \
482+
} \
454483
}
455484

456-
#define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
485+
#define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M_VNNI, K_VNNI, WI_rows, orig_M, contrib_M, contrib_K) \
457486
if (contrib_K*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
458487
long offset = as_long(mem); \
459488
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
460489
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
461490
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
462-
int height = contrib_M - 1; /* row count */ \
491+
int height = orig_M - 1; /* row count */ \
463492
long x = (offset - baseoffset) / (sizeof (element_type)); /* in elements */ \
464493
int2 coords = (int2)(x, 0); \
465-
OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(long, int, int, int, int2, int); \
466-
OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
467-
*(__private OUT_VEC##M(u##contrib_type) *)dst = res; \
494+
OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
495+
OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
496+
*(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
468497
return; \
469498
}
470499

471-
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
500+
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
472501
long offset = as_long(mem); \
473502
int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
474503
int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
@@ -480,7 +509,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
480509
*(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
481510
return;
482511

483-
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
512+
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
484513
long offset = as_long(mem); \
485514
int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
486515
int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
@@ -492,25 +521,29 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
492521
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
493522
return;
494523

495-
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
524+
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_VNNI_TX_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M_VNNI, K_VNNI, WI_rows, orig_M, contrib_M, contrib_K) \
496525
long offset = as_long(mem); \
497526
int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
498527
int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
499528
int height_size = height - 1; \
500529
int2 coords = (int2)(x, y); \
501-
OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(long, int, int, int, int2, int); \
502-
OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
530+
OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
531+
OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
503532
*(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
504533
return;
505534

506-
#define IMPLEMENT_BLOCK2D_LOAD__(checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows) \
507-
IMPLEMENT_BLOCK2D_LOAD##checked##sg##order(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, \
508-
M, K, WI_rows, \
535+
#define IMPLEMENT_BLOCK2D_LOAD___(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
536+
IMPLEMENT_BLOCK2D_LOAD##checked##sg##order(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, \
537+
M, K, WI_rows, orig_M, contrib_M, contrib_K)
538+
539+
#define IMPLEMENT_BLOCK2D_LOAD__(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows) \
540+
IMPLEMENT_BLOCK2D_LOAD___(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, \
509541
MATH_MUL(M, MATH_DIV(contrib_bitwidth, elem_bitwidth)), \
542+
MATH_DIV(M, MATH_DIV(contrib_bitwidth, elem_bitwidth)), \
510543
MATH_DIV(K, MATH_DIV(contrib_bitwidth, elem_bitwidth)))
511544

512-
#define IMPLEMENT_BLOCK2D_LOAD(checked, sg, order, element_type, contrib_type, M, K, WI_rows) \
513-
IMPLEMENT_BLOCK2D_LOAD__(checked, sg, order, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
545+
#define IMPLEMENT_BLOCK2D_LOAD(layout, checked, sg, order, element_type, contrib_type, M, K, WI_rows) \
546+
IMPLEMENT_BLOCK2D_LOAD__(layout, checked, sg, order, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
514547
M, K, WI_rows)
515548

516549
#define IMPLEMENT_BLOCK2D_STORE_SG16(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
@@ -557,13 +590,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
557590
#define DEFINE_LOAD_BLOCK2D_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
558591
if (BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL \
559592
&& (M == 1 || M == 2 || M == 4 || M == 8 || M == 16 || M == 32) \
560-
&& (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && contrib_bitwidth == 32 && WI_rows == M)) \
593+
&& (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && WI_rows == M)) \
561594
) { \
562-
IMPLEMENT_BLOCK2D_LOAD(, sg, order##_, element_type, contrib_type, M, K, WI_rows) \
595+
IMPLEMENT_BLOCK2D_LOAD(layout, , sg, order##_, element_type, contrib_type, M, K, WI_rows) \
563596
}
564597

565598
#define DEFINE_LOAD_CHECKED_BLOCK2D_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
566-
IMPLEMENT_BLOCK2D_LOAD(_CHECKED, sg, order##_, element_type, contrib_type, M, K, WI_rows)
599+
IMPLEMENT_BLOCK2D_LOAD(layout, _CHECKED, sg, order##_, element_type, contrib_type, M, K, WI_rows)
567600

568601
#define DEFINE_LOAD_VECTORS_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, address_space) \
569602
if (WI_rows >= M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL \
@@ -735,6 +768,10 @@ DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 16, ROW_MAJOR,
735768

736769
DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 32, ROW_MAJOR, _us, 2)
737770

771+
// This matrix is represented as <8xi16> in LLVM IR, but to be able to read it with 2d block load we have to use i32
772+
// so, contrib type is `int` here and we read <4xi32> from memory, but then we use it as <8xi16>
773+
DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, short, int, 8, 16, COL_MAJOR, , 8)
774+
738775
/* PackedA load i16 SG16 for sub group size = 32*/
739776
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 8, 16, ROW_MAJOR, _us, 4)
740777
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 7, 16, ROW_MAJOR, _us, 4)

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -617,6 +617,7 @@ static SupportedParams getSupportedParams(const JointMatrixTypeDescription *desc
617617
params.columns = maxSliceBitWidth / desc->bitWidth;
618618
params.bitWidth |= 16;
619619
params.layouts = 1 << LayoutRowMajor;
620+
params.layouts |= 1 << LayoutColumnMajor;
620621
} else if (desc->layout == LayoutPackedB) {
621622
params.rows = maxSliceBitWidth / desc->bitWidth;
622623
params.columns = useSG16 ? 16 : 8;

0 commit comments

Comments (0)