Skip to content

Commit aea1f76

Browse files
sys-igcigcbot
authored and committed
[Autobackout][FunctionalRegression]Revert of change: ef066c3: SYCL Joint Matrix add transpose A and B (8x16x16 for PVC) load
Add support for transpose A 8x16 bfloat16 and B 16x16 bfloat16
1 parent 0e8b020 commit aea1f76

File tree

3 files changed

+48
-99
lines changed

3 files changed

+48
-99
lines changed

IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl

Lines changed: 29 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ SPDX-License-Identifier: MIT
4646
#define _PackedB_PackedB 3
4747
#define _Accumulator_RowMajor 4
4848
#define _Accumulator_ColumnMajor 5
49-
#define _PackedA_ColumnMajor 6
5049

5150
#define ATTRIBUTE_AS_GENERIC __global /* the branch using this will be dead,
5251
however we still need a valid address
@@ -180,13 +179,9 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
180179
#define MATH_8_DIV_4 2
181180
#define MATH_8_DIV_2 4
182181
#define MATH_8_DIV_1 8
183-
#define MATH_7_DIV_1 7
184-
#define MATH_6_DIV_1 6
185-
#define MATH_5_DIV_1 5
186182
#define MATH_4_DIV_4 1
187183
#define MATH_4_DIV_2 2
188184
#define MATH_4_DIV_1 4
189-
#define MATH_3_DIV_1 3
190185
#define MATH_2_DIV_2 1
191186
#define MATH_2_DIV_1 2
192187
#define MATH_1_DIV_1 1
@@ -230,7 +225,6 @@ typedef uint __attribute__((ext_vector_type(32))) uint32;
230225
#define SHAPE_CONCAT_VNNI(M, K, vnni_factor) SHAPE_CONCAT_VNNI__(MATH_MUL(M, vnni_factor), MATH_DIV(K, vnni_factor))
231226

232227
#define SHAPE_PackedA_RowMajor( M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT(M, K)
233-
#define SHAPE_PackedA_ColumnMajor( M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT(M, K)
234228
#define SHAPE_PackedB_RowMajor( M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
235229
#define SHAPE_PackedB_ColumnMajor( M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
236230
#define SHAPE_PackedB_PackedB( M, K, elem_bitwidth, contrib_bitwidth) SHAPE_CONCAT_VNNI(M, K, MATH_DIV(contrib_bitwidth, elem_bitwidth))
@@ -413,13 +407,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
413407

414408
/* For platforms without SG16 JointMatrix support block2d is not available. The
415409
* implementation remains empty, will fallthrough to vector implementation. */
416-
#define IMPLEMENT_BLOCK2D_LOAD_ROW_MAJOR_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
410+
#define IMPLEMENT_BLOCK2D_LOAD_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
417411
/* not supported, fallthrough */
418-
#define IMPLEMENT_BLOCK2D_LOAD_COL_MAJOR_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
412+
#define IMPLEMENT_BLOCK2D_LOAD_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
419413
/* not supported, fallthrough */
420-
#define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX_(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11) \
414+
#define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
421415
/* not supported, fallthrough */
422-
#define IMPLEMENT_BLOCK2D_STORE(p1, p2, p3, p4, p5, p6, p7, p8) \
416+
#define IMPLEMENT_BLOCK2D_STORE(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
423417
/* not supported, fallthrough */
424418

425419
// contrib_K - calculated in BLOCK2D loads; contrib_K = K/(contrib_bitwidth/elem_bitwidth);
@@ -428,13 +422,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
428422

429423
#define MAX_ROW_BYTES_2D_BLOCK_LOAD 64 // maximum per row size in bytes supported by 2D block load
430424

431-
// M and K are the original tile dimensions for A and C.
432-
// M and K are the VNNI-ed dimensions for matrix B (hence the are also aliased as M_VNNI and K_VNNI in some places)
433-
// Hence, we also have orig_M, which can be used as non-VNNI'ed M dimension of matrix B
434-
// When smaller elements are packed into bigger types, contrib_M and contrib_K represent M and K dimensions of the packed matrix
435-
// Not all values (orig_M, contrib_M, contrib_K) are valid for all different layouts and macros, so before using any of it, please, check,
436-
// if it's valid for the current configuration
437-
#define IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
425+
#define IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
438426
if (contrib_K*sizeof(contrib_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
439427
long offset = as_long(mem); \
440428
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
@@ -449,8 +437,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
449437
return; \
450438
}
451439

452-
// 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose
453-
#define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
440+
#define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
454441
if (M*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
455442
long offset = as_long(mem); \
456443
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
@@ -459,45 +446,29 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
459446
int height = 16 - 1; /* taken from SG16 */ \
460447
long x = (offset - baseoffset) / (sizeof (contrib_type)); /* in elements */ \
461448
int2 coords = (int2)(x, 0); \
462-
\
463-
if (elem_bitwidth == 32) { \
464-
OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(long, int, int, int, int2, int); \
465-
OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
466-
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
467-
return; \
468-
} \
469-
\
470-
if (elem_bitwidth == 16 && _##layout == _PackedA_ColumnMajor) { \
471-
OUT_VEC##contrib_M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, contrib_M)(long, int, int, int, int2, int); \
472-
OUT_VEC##contrib_M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
473-
*(__private OUT_VEC##WI_rows(u##element_type) *)dst = *(__private OUT_VEC##WI_rows(u##element_type) *)&res; \
474-
return; \
475-
} \
476-
\
477-
if (elem_bitwidth == 16 && _##layout == _PackedB_ColumnMajor) { \
478-
OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(long, int, int, int, int2, int); \
479-
OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
480-
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
481-
return; \
482-
} \
449+
/* 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose */ \
450+
OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, M)(long, int, int, int, int2, int); \
451+
OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, M)(baseoffset, width, height, pitch, coords, cacheOpt); \
452+
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
453+
return; \
483454
}
484455

485-
#define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M_VNNI, K_VNNI, WI_rows, orig_M, contrib_M, contrib_K) \
456+
#define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
486457
if (contrib_K*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
487458
long offset = as_long(mem); \
488459
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
489460
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
490461
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
491-
int height = orig_M - 1; /* row count */ \
462+
int height = contrib_M - 1; /* row count */ \
492463
long x = (offset - baseoffset) / (sizeof (element_type)); /* in elements */ \
493464
int2 coords = (int2)(x, 0); \
494-
OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
495-
OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
496-
*(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
465+
OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(long, int, int, int, int2, int); \
466+
OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
467+
*(__private OUT_VEC##M(u##contrib_type) *)dst = res; \
497468
return; \
498469
}
499470

500-
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
471+
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
501472
long offset = as_long(mem); \
502473
int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
503474
int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
@@ -509,7 +480,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
509480
*(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
510481
return;
511482

512-
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
483+
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
513484
long offset = as_long(mem); \
514485
int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
515486
int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
@@ -521,29 +492,25 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
521492
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
522493
return;
523494

524-
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_VNNI_TX_(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M_VNNI, K_VNNI, WI_rows, orig_M, contrib_M, contrib_K) \
495+
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_VNNI_TX_(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_M, contrib_K) \
525496
long offset = as_long(mem); \
526497
int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
527498
int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
528499
int height_size = height - 1; \
529500
int2 coords = (int2)(x, y); \
530-
OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
531-
OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
501+
OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(long, int, int, int, int2, int); \
502+
OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, contrib_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
532503
*(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
533504
return;
534505

535-
#define IMPLEMENT_BLOCK2D_LOAD___(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, contrib_M, contrib_K) \
536-
IMPLEMENT_BLOCK2D_LOAD##checked##sg##order(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, \
537-
M, K, WI_rows, orig_M, contrib_M, contrib_K)
538-
539-
#define IMPLEMENT_BLOCK2D_LOAD__(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows) \
540-
IMPLEMENT_BLOCK2D_LOAD___(layout, checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, \
506+
#define IMPLEMENT_BLOCK2D_LOAD__(checked, sg, order, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows) \
507+
IMPLEMENT_BLOCK2D_LOAD##checked##sg##order(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, \
508+
M, K, WI_rows, \
541509
MATH_MUL(M, MATH_DIV(contrib_bitwidth, elem_bitwidth)), \
542-
MATH_DIV(M, MATH_DIV(contrib_bitwidth, elem_bitwidth)), \
543510
MATH_DIV(K, MATH_DIV(contrib_bitwidth, elem_bitwidth)))
544511

545-
#define IMPLEMENT_BLOCK2D_LOAD(layout, checked, sg, order, element_type, contrib_type, M, K, WI_rows) \
546-
IMPLEMENT_BLOCK2D_LOAD__(layout, checked, sg, order, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
512+
#define IMPLEMENT_BLOCK2D_LOAD(checked, sg, order, element_type, contrib_type, M, K, WI_rows) \
513+
IMPLEMENT_BLOCK2D_LOAD__(checked, sg, order, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
547514
M, K, WI_rows)
548515

549516
#define IMPLEMENT_BLOCK2D_STORE_SG16(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
@@ -590,13 +557,13 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
590557
#define DEFINE_LOAD_BLOCK2D_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
591558
if (BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL \
592559
&& (M == 1 || M == 2 || M == 4 || M == 8 || M == 16 || M == 32) \
593-
&& (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && WI_rows == M)) \
560+
&& (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && contrib_bitwidth == 32 && WI_rows == M)) \
594561
) { \
595-
IMPLEMENT_BLOCK2D_LOAD(layout, , sg, order##_, element_type, contrib_type, M, K, WI_rows) \
562+
IMPLEMENT_BLOCK2D_LOAD(, sg, order##_, element_type, contrib_type, M, K, WI_rows) \
596563
}
597564

598565
#define DEFINE_LOAD_CHECKED_BLOCK2D_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows) \
599-
IMPLEMENT_BLOCK2D_LOAD(layout, _CHECKED, sg, order##_, element_type, contrib_type, M, K, WI_rows)
566+
IMPLEMENT_BLOCK2D_LOAD(_CHECKED, sg, order##_, element_type, contrib_type, M, K, WI_rows)
600567

601568
#define DEFINE_LOAD_VECTORS_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, address_space) \
602569
if (WI_rows >= M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL \
@@ -768,10 +735,6 @@ DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 16, ROW_MAJOR,
768735

769736
DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 32, ROW_MAJOR, _us, 2)
770737

771-
// This matrix is represented as <8xi16> in LLVM IR, but to be able to read it with 2d block load we have to use i32
772-
// so, contrib type is `int` here and we read <4xi32> from memory, but then we use it as <8xi16>
773-
DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, short, int, 8, 16, COL_MAJOR, , 8)
774-
775738
/* PackedA load i16 SG16 for sub group size = 32*/
776739
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 8, 16, ROW_MAJOR, _us, 4)
777740
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 7, 16, ROW_MAJOR, _us, 4)

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,6 @@ static SupportedParams getSupportedParams(const JointMatrixTypeDescription *desc
617617
params.columns = maxSliceBitWidth / desc->bitWidth;
618618
params.bitWidth |= 16;
619619
params.layouts = 1 << LayoutRowMajor;
620-
params.layouts |= 1 << LayoutColumnMajor;
621620
} else if (desc->layout == LayoutPackedB) {
622621
params.rows = maxSliceBitWidth / desc->bitWidth;
623622
params.columns = useSG16 ? 16 : 8;

0 commit comments

Comments
 (0)