Skip to content

Commit 390c40b

Browse files
YuriPlyakhin authored and igcbot committed
SYCL Joint Matrix: OOB support for col major A, B
SYCL Joint Matrix: OOB support for col major A, B
1 parent e09eed4 commit 390c40b

File tree

6 files changed

+115
-76
lines changed

6 files changed

+115
-76
lines changed

IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl

Lines changed: 83 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,6 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
357357
#define MANGLE_PREFETCH_NAME(sg, elem_bitwidth, shape) \
358358
__builtin_spriv_OpJointMatrixPrefetchINTEL##sg##_##shape##_i##elem_bitwidth
359359

360-
#define MANGLE_FILLCHECKED_NAME(elem_bitwidth, WI_rows) \
361-
__builtin_spriv_OpJointMatrixFillCheckedINTEL_i##elem_bitwidth##_##WI_rows
362-
363360
#define SUB_GROUP_LOAD(readop, M, WI_rows, src, dst, stride, contrib_type) \
364361
__private contrib_type *wi_contrib = (__private contrib_type *)dst; \
365362
int ratio = WI_rows / M; \
@@ -499,22 +496,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
499496
#define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_Accumulator_ColumnMajor(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, ret_num, orig_M, orig_K, contrib_M, contrib_K) \
500497
IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR(element_type, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, orig_K)
501498

502-
// ret_num is used only for this load implementation. It is the number of elements in the return vector from block2d call.
503-
// In other cases it is equal to WI_rows but in this case, because we use 32-bit data size to load data,
504-
// while "contrib type" is still 16-bit, we need to use different return type.
505-
#define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_PackedA_ColumnMajor(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, ret_num, orig_M, orig_K, contrib_M, contrib_K) \
506-
if (M*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
507-
long offset = as_long(mem); \
508-
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
509-
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
510-
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
511-
int height = K - 1; \
512-
long x = (offset - baseoffset) / (sizeof(int)); /* in elements */ \
513-
int2 coords = (int2)(x, 0); \
514-
\
515-
OUT_VEC##ret_num(uint) DEFINE_BLOCK2D_TRANSPOSE_NAME(32, ret_num, K, contrib_M)(long, int, int, int, int2, int); \
516-
OUT_VEC##ret_num(uint) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(32, ret_num, K, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
517-
\
499+
#define HELPER_PACKEDA_COL_MAJOR_SHUFFLE(contrib_type, elem_bitwidth, M, K, WI_rows) \
518500
int slid = get_sub_group_local_id(); \
519501
int pack_factor = 2; /* either sizeof(short)/sizeof(char) or sizeof(int)/sizeof(short) */ \
520502
\
@@ -536,7 +518,6 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
536518
return; \
537519
} \
538520
if (elem_bitwidth == 8) { \
539-
int load_pack_factor = sizeof(int) / sizeof(char); \
540521
if (M == WI_rows) { \
541522
char16 *data = (char16*)&res; \
542523
char16 tdata; \
@@ -560,7 +541,25 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
560541
} \
561542
*(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = *(__private OUT_VEC##WI_rows(u##contrib_type) *)&tdata; \
562543
return; \
563-
} \
544+
}
545+
546+
// ret_num is used only for this load implementation. It is the number of elements in the return vector from block2d call.
547+
// In other cases it is equal to WI_rows but in this case, because we use 32-bit data size to load data,
548+
// while "contrib type" is still 16-bit, we need to use different return type.
549+
#define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_PackedA_ColumnMajor(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, ret_num, orig_M, orig_K, contrib_M, contrib_K) \
550+
if (M*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
551+
long offset = as_long(mem); \
552+
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
553+
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
554+
int pitch = width; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
555+
int height = K - 1; \
556+
long x = (offset - baseoffset) / (sizeof(int)); /* in elements */ \
557+
int2 coords = (int2)(x, 0); \
558+
\
559+
OUT_VEC##ret_num(uint) DEFINE_BLOCK2D_TRANSPOSE_NAME(32, ret_num, K, contrib_M)(long, int, int, int, int2, int); \
560+
OUT_VEC##ret_num(uint) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(32, ret_num, K, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
561+
int load_pack_factor = sizeof(int) / sizeof(element_type); \
562+
HELPER_PACKEDA_COL_MAJOR_SHUFFLE(contrib_type, elem_bitwidth, M, K, WI_rows) \
564563
}
565564

566565
#define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX_PackedB_RowMajor(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M_VNNI, K_VNNI, WI_rows, ret_num, orig_M, orig_K, contrib_M, contrib_K) \
@@ -599,20 +598,32 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
599598
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR_Accumulator_RowMajor(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, ret_num, orig_M, orig_K, contrib_M, contrib_K) \
600599
IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR(element_type, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K)
601600

601+
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_PackedA_ColumnMajor(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, ret_num, orig_M, orig_K, contrib_M, contrib_K) \
602+
long offset = as_long(mem); \
603+
int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
604+
int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
605+
int height_size = height - 1; \
606+
int load_pack_factor = sizeof (int) / sizeof (element_type); \
607+
int2 coords = (int2)(x/load_pack_factor, y); \
608+
/* 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose */ \
609+
OUT_VEC##ret_num(uint) DEFINE_BLOCK2D_TRANSPOSE_NAME(32, ret_num, K, contrib_M)(long, int, int, int, int2, int); \
610+
OUT_VEC##ret_num(uint) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(32, ret_num, K, contrib_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
611+
int sg_size = get_sub_group_size(); \
612+
HELPER_PACKEDA_COL_MAJOR_SHUFFLE(contrib_type, elem_bitwidth, M, K, WI_rows)
613+
602614
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, orig_K) \
603615
long offset = as_long(mem); \
604616
int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
605617
int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
606618
int height_size = height - 1; \
607-
int2 coords = (int2)(x, y); \
619+
int pack_factor = sizeof (contrib_type) / sizeof (element_type); \
620+
int2 coords = (int2)(x/pack_factor, y); \
608621
/* 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose */ \
609622
OUT_VEC##WI_rows(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, WI_rows, orig_K, M)(long, int, int, int, int2, int); \
610623
OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, WI_rows, orig_K, M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
611624
*(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = *(__private OUT_VEC##WI_rows(u##contrib_type) *)&res; \
612625
return;
613626

614-
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_PackedA_ColumnMajor(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, ret_num, orig_M, orig_K, contrib_M, contrib_K) \
615-
IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, orig_K)
616627
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_PackedB_ColumnMajor(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, ret_num, orig_M, orig_K, contrib_M, contrib_K) \
617628
IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, orig_K)
618629
#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_Accumulator_ColumnMajor(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, ret_num, orig_M, orig_K, contrib_M, contrib_K) \
@@ -663,16 +674,16 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
663674
DEFINE_BLOCK2D_RW_NAME(write, , contrib_bitwidth, WI_rows, M, contrib_K)(baseoffset, width, height, pitch, coords, val, cacheOpt); \
664675
return;
665676

666-
#define IMPLEMENT_BLOCK2D_STORE_CHECKED_SG16(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, contrib_K) \
677+
#define IMPLEMENT_BLOCK2D_STORE_CHECKED_SG16(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K) \
667678
long offset = as_long(mem); \
668679
int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
669680
int pitch = sizeof (element_type) * stride - 1; /* in bytes */ \
670681
int height_size = height - 1; \
671682
int pack_factor = sizeof (contrib_type) / sizeof (element_type); \
672683
int2 coords = (int2)(x / pack_factor, y); \
673-
void DEFINE_BLOCK2D_RW_NAME(write, , contrib_bitwidth, M, M, contrib_K)(long, int, int, int, int2, OUT_VEC##M(u##contrib_type), int); \
674-
OUT_VEC##M(u##contrib_type) val = *(OUT_VEC##M(u##contrib_type) *)src; \
675-
DEFINE_BLOCK2D_RW_NAME(write, , contrib_bitwidth, M, M, contrib_K)(offset, width_size, height_size, pitch, coords, val, cacheOpt); \
684+
void DEFINE_BLOCK2D_RW_NAME(write, , contrib_bitwidth, WI_rows, M, contrib_K)(long, int, int, int, int2, OUT_VEC##WI_rows(u##contrib_type), int); \
685+
OUT_VEC##WI_rows(u##contrib_type) val = *(OUT_VEC##WI_rows(u##contrib_type) *)src; \
686+
DEFINE_BLOCK2D_RW_NAME(write, , contrib_bitwidth, WI_rows, M, contrib_K)(offset, width_size, height_size, pitch, coords, val, cacheOpt); \
676687
return;
677688

678689
// layout can be PackedA_RowMajor, PackedB_ColumnMajor, PackedB_PackedB, etc.
@@ -878,7 +889,7 @@ DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 32, ROW_MAJOR,
878889
DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, short, short, 8, 16, COL_MAJOR, , 8)
879890

880891
/* PackedA load i16 SG16 for sub group size = 32*/
881-
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 8, 16, ROW_MAJOR, _us, 4)
892+
DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 8, 16, ROW_MAJOR, _us, 4)
882893
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 7, 16, ROW_MAJOR, _us, 4)
883894
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 6, 16, ROW_MAJOR, _us, 3)
884895
DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 5, 16, ROW_MAJOR, _us, 3)
@@ -888,7 +899,7 @@ DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 2, 16, ROW_MAJOR, _us, 1)
888899
// DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 1, 16, ROW_MAJOR, _us, 1) same as for subgroup 16
889900

890901
/* PackedA load i16 SG16 Col Major for sub group size = 32*/
891-
DEFINE_LOAD(PackedA_ColumnMajor, _SG16, short, short, 8, 16, COL_MAJOR, , 4)
902+
DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, short, short, 8, 16, COL_MAJOR, , 4)
892903

893904
/* PackedA load i8 SG16 */
894905
DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, char, short, 8, 32, ROW_MAJOR, _us, 8)
@@ -904,7 +915,7 @@ DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, char, short, 1, 32, ROW_MAJOR,
904915
DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, char, short, 8, 32, COL_MAJOR, , 8)
905916

906917
/* PackedA load i8 SG16 for sub group size 32*/
907-
DEFINE_LOAD(PackedA_RowMajor, _SG16, char, short, 8, 32, ROW_MAJOR, _us, 4)
918+
DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, char, short, 8, 32, ROW_MAJOR, _us, 4)
908919
DEFINE_LOAD(PackedA_RowMajor, _SG16, char, short, 7, 32, ROW_MAJOR, _us, 4)
909920
DEFINE_LOAD(PackedA_RowMajor, _SG16, char, short, 6, 32, ROW_MAJOR, _us, 3)
910921
DEFINE_LOAD(PackedA_RowMajor, _SG16, char, short, 5, 32, ROW_MAJOR, _us, 3)
@@ -914,7 +925,7 @@ DEFINE_LOAD(PackedA_RowMajor, _SG16, char, short, 2, 32, ROW_MAJOR, _us, 1)
914925
// DEFINE_LOAD(PackedA_RowMajor, _SG16, char, short, 1, 32, ROW_MAJOR, _us, 1) same as for subgroup 16
915926

916927
/* PackedA load i8 SG16 Col Major for sub group size 32*/
917-
DEFINE_LOAD(PackedA_ColumnMajor, _SG16, char, short, 8, 32, COL_MAJOR, , 4)
928+
DEFINE_LOAD_AND_CHECKED(PackedA_ColumnMajor, _SG16, char, short, 8, 32, COL_MAJOR, , 4)
918929

919930
/* A load tf32 SG16 */
920931
DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, int, int, 8, 8, ROW_MAJOR, , 4)
@@ -942,19 +953,19 @@ DEFINE_LOAD_AND_CHECKED(PackedB_PackedB, _SG16, short, int, 16, 32, ROW_MAJO
942953
DEFINE_LOAD_AND_CHECKED(PackedB_RowMajor, _SG16, short, int, 16, 32, VNNI_TX, , 16)
943954

944955
/* PackedB load i16 SG16 for sub group size = 32*/
945-
DEFINE_LOAD(PackedB_ColumnMajor, _SG16, short, int, 8, 32, COL_MAJOR, , 4)
946-
DEFINE_LOAD(PackedB_PackedB, _SG16, short, int, 8, 32, ROW_MAJOR, , 4)
947-
DEFINE_LOAD(PackedB_RowMajor, _SG16, short, int, 8, 32, VNNI_TX, , 4)
956+
DEFINE_LOAD_AND_CHECKED(PackedB_ColumnMajor, _SG16, short, int, 8, 32, COL_MAJOR, , 4)
957+
DEFINE_LOAD_AND_CHECKED(PackedB_PackedB, _SG16, short, int, 8, 32, ROW_MAJOR, , 4)
958+
DEFINE_LOAD_AND_CHECKED(PackedB_RowMajor, _SG16, short, int, 8, 32, VNNI_TX, , 4)
948959

949960
/* PackedB load i8 SG16*/
950961
DEFINE_LOAD_AND_CHECKED(PackedB_ColumnMajor, _SG16, char, int, 8, 64, COL_MAJOR, , 8)
951962
DEFINE_LOAD_AND_CHECKED(PackedB_PackedB, _SG16, char, int, 8, 64, ROW_MAJOR, , 8)
952963
DEFINE_LOAD_AND_CHECKED(PackedB_RowMajor, _SG16, char, int, 8, 64, VNNI_TX, , 8)
953964

954965
/* PackedB load i8 SG16 for sub group size 32*/
955-
DEFINE_LOAD(PackedB_ColumnMajor, _SG16, char, int, 8, 64, COL_MAJOR, , 4)
956-
DEFINE_LOAD(PackedB_PackedB, _SG16, char, int, 8, 64, ROW_MAJOR, , 4)
957-
DEFINE_LOAD(PackedB_RowMajor, _SG16, char, int, 8, 64, VNNI_TX, , 4)
966+
DEFINE_LOAD_AND_CHECKED(PackedB_ColumnMajor, _SG16, char, int, 8, 64, COL_MAJOR, , 4)
967+
DEFINE_LOAD_AND_CHECKED(PackedB_PackedB, _SG16, char, int, 8, 64, ROW_MAJOR, , 4)
968+
DEFINE_LOAD_AND_CHECKED(PackedB_RowMajor, _SG16, char, int, 8, 64, VNNI_TX, , 4)
958969

959970
/* B load tf32 SG16 */
960971
DEFINE_LOAD_AND_CHECKED(PackedB_RowMajor, _SG16, int, int, 8, 16, ROW_MAJOR, , 8)
@@ -1048,7 +1059,7 @@ DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, int, 2, 16, COL_MAJOR, , 1)
10481059
}
10491060

10501061
#define DEFINE_STORE_CHECKED_BLOCK2D_IMPL(sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, order, WI_rows) \
1051-
IMPLEMENT_BLOCK2D_STORE_CHECKED##sg(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, MATH_DIV(K, MATH_DIV(contrib_bitwidth, elem_bitwidth)))
1062+
IMPLEMENT_BLOCK2D_STORE_CHECKED##sg(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, MATH_DIV(K, MATH_DIV(contrib_bitwidth, elem_bitwidth)))
10521063

10531064
// set block_opt to false to disable block non-continous optimization per one built-in as a workaround
10541065
#define DEFINE_STORE_VECTORS_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, order, us, WI_rows, block_opt, address_space) \
@@ -1311,7 +1322,7 @@ DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, int, 2, 16, COL_MAJOR, , 2, tr
13111322
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, int, 1, 16, COL_MAJOR, , 1, true)
13121323

13131324
/* Acc i32 SG16 for subgroup 32*/
1314-
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, int, 8, 16, ROW_MAJOR, , 4, true)
1325+
DEFINE_STORE_AND_CHECKED(Accumulator_RowMajor, _SG16, int, int, 8, 16, ROW_MAJOR, , 4, true)
13151326
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, int, 7, 16, ROW_MAJOR, , 4, true)
13161327
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, int, 6, 16, ROW_MAJOR, , 3, true)
13171328
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, int, 5, 16, ROW_MAJOR, , 3, true)
@@ -2288,31 +2299,48 @@ DEFINE_STORE_LARGE_1(Accumulator_RowMajor, 1, 64)
22882299

22892300
DEFINE_STORE_CHECKED_LARGE_1(Accumulator_RowMajor, 1, 64)
22902301

2291-
#define DEFINE_FILLCHECKED_IMPL(element_type, elem_bitwidth, WI_rows) \
2292-
INLINE void MANGLE_FILLCHECKED_NAME(elem_bitwidth, WI_rows) (__private char *dst, int y, int x, int height, int width, element_type value) { \
2302+
// FillChecked implementation
2303+
2304+
#define MANGLE_FILLCHECKED_NAME(elem_bitwidth, contrib_bitwidth, K, WI_rows) \
2305+
__builtin_spirv_OpJointMatrixFillCheckedINTEL_i##elem_bitwidth##_i##contrib_bitwidth##_k##K##_wi##WI_rows
2306+
2307+
#define DEFINE_FILLCHECKED_IMPL(element_type, elem_bitwidth, contrib_bitwidth, K, WI_rows) \
2308+
INLINE void MANGLE_FILLCHECKED_NAME(elem_bitwidth, contrib_bitwidth, K, WI_rows) (__private char *dst, int y, int x, int height, int width, element_type value) { \
22932309
int slid = get_sub_group_local_id(); \
2310+
int sg_size = get_sub_group_size(); \
2311+
int pack_factor = contrib_bitwidth / elem_bitwidth; \
2312+
int col_sg_ratio = (sg_size * pack_factor) / K; \
22942313
__private element_type *wi_contrib = (__private element_type *) dst; \
22952314
for (int i = 0; i < WI_rows; i++) { \
2296-
element_type fill_value = ((slid < width - x) && (i < height - y)) ? value : 0; \
2315+
element_type fill_value = slid % K < width - x && i * col_sg_ratio < height - y ? value : 0; \
22972316
wi_contrib[i] = fill_value; \
22982317
} \
22992318
}
23002319

2301-
#define DEFINE_FILLCHECKED__(element_type, elem_bitwidth, WI_rows) \
2302-
DEFINE_FILLCHECKED_IMPL(element_type, elem_bitwidth, WI_rows)
2320+
#define DEFINE_FILLCHECKED__(element_type, elem_bitwidth, contrib_bitwidth, K, WI_rows) \
2321+
DEFINE_FILLCHECKED_IMPL(element_type, elem_bitwidth, contrib_bitwidth, K, WI_rows)
2322+
2323+
#define DEFINE_FILLCHECKED(element_type, contrib_type, K, WI_rows) \
2324+
DEFINE_FILLCHECKED__(element_type, BITWIDTH(element_type), BITWIDTH(contrib_type), K, WI_rows)
2325+
2326+
#define DEFINE_FILLCHECKED_K(element_type, contrib_type, K) \
2327+
DEFINE_FILLCHECKED(element_type, contrib_type, K, 1) \
2328+
DEFINE_FILLCHECKED(element_type, contrib_type, K, 2) \
2329+
DEFINE_FILLCHECKED(element_type, contrib_type, K, 4) \
2330+
DEFINE_FILLCHECKED(element_type, contrib_type, K, 8) \
2331+
DEFINE_FILLCHECKED(element_type, contrib_type, K, 16) \
2332+
DEFINE_FILLCHECKED(element_type, contrib_type, K, 32) \
2333+
DEFINE_FILLCHECKED(element_type, contrib_type, K, 64) \
2334+
DEFINE_FILLCHECKED(element_type, contrib_type, K, 128)
23032335

2304-
#define DEFINE_FILLCHECKED(element_type, WI_rows) \
2305-
DEFINE_FILLCHECKED__(element_type, BITWIDTH(element_type), WI_rows)
2336+
#define DEFINE_FILLCHECKED_CONTRIB(element_type, contrib_type) \
2337+
DEFINE_FILLCHECKED_K(element_type, contrib_type, 8) \
2338+
DEFINE_FILLCHECKED_K(element_type, contrib_type, 16) \
2339+
DEFINE_FILLCHECKED_K(element_type, contrib_type, 32) \
23062340

23072341
#define DEFINE_FILLCHECKED_GROUP(element_type) \
2308-
DEFINE_FILLCHECKED(element_type, 1) \
2309-
DEFINE_FILLCHECKED(element_type, 2) \
2310-
DEFINE_FILLCHECKED(element_type, 4) \
2311-
DEFINE_FILLCHECKED(element_type, 8) \
2312-
DEFINE_FILLCHECKED(element_type, 16) \
2313-
DEFINE_FILLCHECKED(element_type, 32) \
2314-
DEFINE_FILLCHECKED(element_type, 64) \
2315-
DEFINE_FILLCHECKED(element_type, 128)
2342+
DEFINE_FILLCHECKED_CONTRIB(element_type, short) \
2343+
DEFINE_FILLCHECKED_CONTRIB(element_type, int) \
23162344

23172345
DEFINE_FILLCHECKED_GROUP(char)
23182346
DEFINE_FILLCHECKED_GROUP(short)

0 commit comments

Comments (0)