@@ -357,9 +357,6 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
357357#define MANGLE_PREFETCH_NAME (sg , elem_bitwidth , shape ) \
358358 __builtin_spriv_OpJointMatrixPrefetchINTEL##sg##_##shape##_i##elem_bitwidth
359359
360- #define MANGLE_FILLCHECKED_NAME (elem_bitwidth , WI_rows ) \
361- __builtin_spriv_OpJointMatrixFillCheckedINTEL_i##elem_bitwidth##_##WI_rows
362-
363360#define SUB_GROUP_LOAD (readop , M , WI_rows , src , dst , stride , contrib_type ) \
364361 __private contrib_type *wi_contrib = (__private contrib_type *)dst; \
365362 int ratio = WI_rows / M; \
@@ -499,22 +496,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
499496#define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_Accumulator_ColumnMajor (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , WI_rows , ret_num , orig_M , orig_K , contrib_M , contrib_K ) \
500497 IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR(element_type, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, orig_K)
501498
502- // ret_num is used only for this load implementation. It is the number of elements in the return vector from block2d call.
503- // In other cases it is equal to WI_rows but in this case, because we use 32-bit data size to load data,
504- // while "contrib type" is still 16-bit, we need to use different return type.
505- #define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_PackedA_ColumnMajor (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , WI_rows , ret_num , orig_M , orig_K , contrib_M , contrib_K ) \
506- if (M*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
507- long offset = as_long (mem ); \
508- long baseoffset = offset & (~0x3f ); /* align to 64-byte */ \
509- int width = (sizeof (element_type )) * stride - 1 ; /* in bytes */ \
510- int pitch = width ; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
511- int height = K - 1 ; \
512- long x = (offset - baseoffset ) / (sizeof (int )); /* in elements */ \
513- int2 coords = (int2 )(x , 0 ); \
514- \
515- OUT_VEC ##ret_num (uint) DEFINE_BLOCK2D_TRANSPOSE_NAME(32, ret_num, K, contrib_M)(long, int, int, int, int2, int); \
516- OUT_VEC##ret_num(uint) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(32, ret_num, K, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
517- \
499+ #define HELPER_PACKEDA_COL_MAJOR_SHUFFLE (contrib_type , elem_bitwidth , M , K , WI_rows ) \
518500 int slid = get_sub_group_local_id(); \
519501 int pack_factor = 2; /* either sizeof(short)/sizeof(char) or sizeof(int)/sizeof(short) */ \
520502\
@@ -536,7 +518,6 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
536518 return; \
537519 } \
538520 if (elem_bitwidth == 8) { \
539- int load_pack_factor = sizeof(int) / sizeof(char); \
540521 if (M == WI_rows) { \
541522 char16 *data = (char16*)&res; \
542523 char16 tdata; \
@@ -560,7 +541,25 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
560541 } \
561542 * (__private OUT_VEC ##WI_rows (u##contrib_type) *)dst = *(__private OUT_VEC##WI_rows(u##contrib_type) *)&tdata; \
562543 return; \
563- } \
544+ }
545+
546+ // ret_num is used only for this load implementation. It is the number of elements in the return vector from block2d call.
547+ // In other cases it is equal to WI_rows but in this case, because we use 32-bit data size to load data,
548+ // while "contrib type" is still 16-bit, we need to use different return type.
549+ #define IMPLEMENT_BLOCK2D_LOAD_SG16_COL_MAJOR_PackedA_ColumnMajor (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , WI_rows , ret_num , orig_M , orig_K , contrib_M , contrib_K ) \
550+ if (M*sizeof(element_type) <= MAX_ROW_BYTES_2D_BLOCK_LOAD) { /* For 2D loads (block2d width)*(data size) must be <= MAX_ROW_BYTES_2D_BLOCK_LOAD */ \
551+ long offset = as_long (mem ); \
552+ long baseoffset = offset & (~0x3f ); /* align to 64-byte */ \
553+ int width = (sizeof (element_type )) * stride - 1 ; /* in bytes */ \
554+ int pitch = width ; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
555+ int height = K - 1 ; \
556+ long x = (offset - baseoffset ) / (sizeof (int )); /* in elements */ \
557+ int2 coords = (int2 )(x , 0 ); \
558+ \
559+ OUT_VEC ##ret_num (uint) DEFINE_BLOCK2D_TRANSPOSE_NAME(32, ret_num, K, contrib_M)(long, int, int, int, int2, int); \
560+ OUT_VEC##ret_num(uint) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(32, ret_num, K, contrib_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
561+ int load_pack_factor = sizeof(int) / sizeof(element_type); \
562+ HELPER_PACKEDA_COL_MAJOR_SHUFFLE(contrib_type, elem_bitwidth, M, K, WI_rows) \
564563 }
565564
566565#define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX_PackedB_RowMajor (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M_VNNI , K_VNNI , WI_rows , ret_num , orig_M , orig_K , contrib_M , contrib_K ) \
@@ -599,20 +598,32 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
599598#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR_Accumulator_RowMajor (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , WI_rows , ret_num , orig_M , orig_K , contrib_M , contrib_K ) \
600599 IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_ROW_MAJOR(element_type, contrib_type, contrib_bitwidth, M, K, WI_rows, contrib_K)
601600
601+ #define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_PackedA_ColumnMajor (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , WI_rows , ret_num , orig_M , orig_K , contrib_M , contrib_K ) \
602+ long offset = as_long(mem); \
603+ int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
604+ int pitch = sizeof (element_type ) * stride - 1 ; /* in bytes */ \
605+ int height_size = height - 1 ; \
606+ int load_pack_factor = sizeof (int ) / sizeof (element_type ); \
607+ int2 coords = (int2 )(x /load_pack_factor , y ); \
608+ /* 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose */ \
609+ OUT_VEC ##ret_num (uint) DEFINE_BLOCK2D_TRANSPOSE_NAME(32, ret_num, K, contrib_M)(long, int, int, int, int2, int); \
610+ OUT_VEC##ret_num(uint) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(32, ret_num, K, contrib_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
611+ int sg_size = get_sub_group_size(); \
612+ HELPER_PACKEDA_COL_MAJOR_SHUFFLE(contrib_type, elem_bitwidth, M, K, WI_rows)
613+
602614#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , WI_rows , orig_M , orig_K ) \
603615 long offset = as_long(mem); \
604616 int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
605617 int pitch = sizeof (element_type ) * stride - 1 ; /* in bytes */ \
606618 int height_size = height - 1 ; \
607- int2 coords = (int2 )(x , y ); \
619+ int pack_factor = sizeof (contrib_type ) / sizeof (element_type ); \
620+ int2 coords = (int2 )(x /pack_factor , y ); \
608621 /* 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose */ \
609622 OUT_VEC ##WI_rows (u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, WI_rows, orig_K, M)(long, int, int, int, int2, int); \
610623 OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, WI_rows, orig_K, M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
611624 *(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = *(__private OUT_VEC##WI_rows(u##contrib_type) *)&res; \
612625 return;
613626
614- #define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_PackedA_ColumnMajor (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , WI_rows , ret_num , orig_M , orig_K , contrib_M , contrib_K ) \
615- IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, orig_K)
616627#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_PackedB_ColumnMajor (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , WI_rows , ret_num , orig_M , orig_K , contrib_M , contrib_K ) \
617628 IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, orig_M, orig_K)
618629#define IMPLEMENT_BLOCK2D_LOAD_CHECKED_SG16_COL_MAJOR_Accumulator_ColumnMajor (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , WI_rows , ret_num , orig_M , orig_K , contrib_M , contrib_K ) \
@@ -663,16 +674,16 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
663674 DEFINE_BLOCK2D_RW_NAME(write, , contrib_bitwidth, WI_rows, M, contrib_K)(baseoffset, width, height, pitch, coords, val, cacheOpt); \
664675 return;
665676
666- #define IMPLEMENT_BLOCK2D_STORE_CHECKED_SG16 (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , contrib_K ) \
677+ #define IMPLEMENT_BLOCK2D_STORE_CHECKED_SG16 (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , WI_rows , contrib_K ) \
667678 long offset = as_long(mem); \
668679 int width_size = sizeof (element_type) * width - 1; /* in bytes */ \
669680 int pitch = sizeof (element_type ) * stride - 1 ; /* in bytes */ \
670681 int height_size = height - 1 ; \
671682 int pack_factor = sizeof (contrib_type ) / sizeof (element_type ); \
672683 int2 coords = (int2 )(x / pack_factor , y ); \
673- void DEFINE_BLOCK2D_RW_NAME (write , , contrib_bitwidth , M , M , contrib_K )(long , int , int , int , int2 , OUT_VEC ##M (u##contrib_type), int); \
674- OUT_VEC##M (u##contrib_type) val = *(OUT_VEC##M (u##contrib_type) *)src; \
675- DEFINE_BLOCK2D_RW_NAME(write, , contrib_bitwidth, M , M, contrib_K)(offset, width_size, height_size, pitch, coords, val, cacheOpt); \
684+ void DEFINE_BLOCK2D_RW_NAME (write , , contrib_bitwidth , WI_rows , M , contrib_K )(long , int , int , int , int2 , OUT_VEC ##WI_rows (u##contrib_type), int); \
685+ OUT_VEC##WI_rows (u##contrib_type) val = *(OUT_VEC##WI_rows (u##contrib_type) *)src; \
686+ DEFINE_BLOCK2D_RW_NAME(write, , contrib_bitwidth, WI_rows , M, contrib_K)(offset, width_size, height_size, pitch, coords, val, cacheOpt); \
676687 return;
677688
678689// layout can be PackedA_RowMajor, PackedB_ColumnMajor, PackedB_PackedB, etc.
@@ -878,7 +889,7 @@ DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 32, ROW_MAJOR,
878889DEFINE_LOAD_AND_CHECKED (PackedA_ColumnMajor , _SG16 , short , short , 8 , 16 , COL_MAJOR , , 8 )
879890
880891/* PackedA load i16 SG16 for sub group size = 32*/
881- DEFINE_LOAD (PackedA_RowMajor , _SG16 , short , short , 8 , 16 , ROW_MAJOR , _us , 4 )
892+ DEFINE_LOAD_AND_CHECKED (PackedA_RowMajor , _SG16 , short , short , 8 , 16 , ROW_MAJOR , _us , 4 )
882893DEFINE_LOAD (PackedA_RowMajor , _SG16 , short , short , 7 , 16 , ROW_MAJOR , _us , 4 )
883894DEFINE_LOAD (PackedA_RowMajor , _SG16 , short , short , 6 , 16 , ROW_MAJOR , _us , 3 )
884895DEFINE_LOAD (PackedA_RowMajor , _SG16 , short , short , 5 , 16 , ROW_MAJOR , _us , 3 )
@@ -888,7 +899,7 @@ DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 2, 16, ROW_MAJOR, _us, 1)
888899// DEFINE_LOAD(PackedA_RowMajor, _SG16, short, short, 1, 16, ROW_MAJOR, _us, 1) same as for subgroup 16
889900
890901/* PackedA load i16 SG16 Col Major for sub group size = 32*/
891- DEFINE_LOAD (PackedA_ColumnMajor , _SG16 , short , short , 8 , 16 , COL_MAJOR , , 4 )
902+ DEFINE_LOAD_AND_CHECKED (PackedA_ColumnMajor , _SG16 , short , short , 8 , 16 , COL_MAJOR , , 4 )
892903
893904/* PackedA load i8 SG16 */
894905DEFINE_LOAD_AND_CHECKED (PackedA_RowMajor , _SG16 , char , short , 8 , 32 , ROW_MAJOR , _us , 8 )
@@ -904,7 +915,7 @@ DEFINE_LOAD_AND_CHECKED(PackedA_RowMajor, _SG16, char, short, 1, 32, ROW_MAJOR,
904915DEFINE_LOAD_AND_CHECKED (PackedA_ColumnMajor , _SG16 , char , short , 8 , 32 , COL_MAJOR , , 8 )
905916
906917/* PackedA load i8 SG16 for sub group size 32*/
907- DEFINE_LOAD (PackedA_RowMajor , _SG16 , char , short , 8 , 32 , ROW_MAJOR , _us , 4 )
918+ DEFINE_LOAD_AND_CHECKED (PackedA_RowMajor , _SG16 , char , short , 8 , 32 , ROW_MAJOR , _us , 4 )
908919DEFINE_LOAD (PackedA_RowMajor , _SG16 , char , short , 7 , 32 , ROW_MAJOR , _us , 4 )
909920DEFINE_LOAD (PackedA_RowMajor , _SG16 , char , short , 6 , 32 , ROW_MAJOR , _us , 3 )
910921DEFINE_LOAD (PackedA_RowMajor , _SG16 , char , short , 5 , 32 , ROW_MAJOR , _us , 3 )
@@ -914,7 +925,7 @@ DEFINE_LOAD(PackedA_RowMajor, _SG16, char, short, 2, 32, ROW_MAJOR, _us, 1)
914925// DEFINE_LOAD(PackedA_RowMajor, _SG16, char, short, 1, 32, ROW_MAJOR, _us, 1) same as for subgroup 16
915926
916927/* PackedA load i8 SG16 Col Major for sub group size 32*/
917- DEFINE_LOAD (PackedA_ColumnMajor , _SG16 , char , short , 8 , 32 , COL_MAJOR , , 4 )
928+ DEFINE_LOAD_AND_CHECKED (PackedA_ColumnMajor , _SG16 , char , short , 8 , 32 , COL_MAJOR , , 4 )
918929
919930/* A load tf32 SG16 */
920931DEFINE_LOAD_AND_CHECKED (PackedA_RowMajor , _SG16 , int , int , 8 , 8 , ROW_MAJOR , , 4 )
@@ -942,19 +953,19 @@ DEFINE_LOAD_AND_CHECKED(PackedB_PackedB, _SG16, short, int, 16, 32, ROW_MAJO
942953DEFINE_LOAD_AND_CHECKED (PackedB_RowMajor , _SG16 , short , int , 16 , 32 , VNNI_TX , , 16 )
943954
944955/* PackedB load i16 SG16 for sub group size = 32*/
945- DEFINE_LOAD (PackedB_ColumnMajor , _SG16 , short , int , 8 , 32 , COL_MAJOR , , 4 )
946- DEFINE_LOAD (PackedB_PackedB , _SG16 , short , int , 8 , 32 , ROW_MAJOR , , 4 )
947- DEFINE_LOAD (PackedB_RowMajor , _SG16 , short , int , 8 , 32 , VNNI_TX , , 4 )
956+ DEFINE_LOAD_AND_CHECKED (PackedB_ColumnMajor , _SG16 , short , int , 8 , 32 , COL_MAJOR , , 4 )
957+ DEFINE_LOAD_AND_CHECKED (PackedB_PackedB , _SG16 , short , int , 8 , 32 , ROW_MAJOR , , 4 )
958+ DEFINE_LOAD_AND_CHECKED (PackedB_RowMajor , _SG16 , short , int , 8 , 32 , VNNI_TX , , 4 )
948959
949960/* PackedB load i8 SG16*/
950961DEFINE_LOAD_AND_CHECKED (PackedB_ColumnMajor , _SG16 , char , int , 8 , 64 , COL_MAJOR , , 8 )
951962DEFINE_LOAD_AND_CHECKED (PackedB_PackedB , _SG16 , char , int , 8 , 64 , ROW_MAJOR , , 8 )
952963DEFINE_LOAD_AND_CHECKED (PackedB_RowMajor , _SG16 , char , int , 8 , 64 , VNNI_TX , , 8 )
953964
954965/* PackedB load i8 SG16 for sub group size 32*/
955- DEFINE_LOAD (PackedB_ColumnMajor , _SG16 , char , int , 8 , 64 , COL_MAJOR , , 4 )
956- DEFINE_LOAD (PackedB_PackedB , _SG16 , char , int , 8 , 64 , ROW_MAJOR , , 4 )
957- DEFINE_LOAD (PackedB_RowMajor , _SG16 , char , int , 8 , 64 , VNNI_TX , , 4 )
966+ DEFINE_LOAD_AND_CHECKED (PackedB_ColumnMajor , _SG16 , char , int , 8 , 64 , COL_MAJOR , , 4 )
967+ DEFINE_LOAD_AND_CHECKED (PackedB_PackedB , _SG16 , char , int , 8 , 64 , ROW_MAJOR , , 4 )
968+ DEFINE_LOAD_AND_CHECKED (PackedB_RowMajor , _SG16 , char , int , 8 , 64 , VNNI_TX , , 4 )
958969
959970/* B load tf32 SG16 */
960971DEFINE_LOAD_AND_CHECKED (PackedB_RowMajor , _SG16 , int , int , 8 , 16 , ROW_MAJOR , , 8 )
@@ -1048,7 +1059,7 @@ DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, int, 2, 16, COL_MAJOR, , 1)
10481059 }
10491060
10501061#define DEFINE_STORE_CHECKED_BLOCK2D_IMPL (sg , element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , order , WI_rows ) \
1051- IMPLEMENT_BLOCK2D_STORE_CHECKED##sg(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, MATH_DIV(K, MATH_DIV(contrib_bitwidth, elem_bitwidth)))
1062+ IMPLEMENT_BLOCK2D_STORE_CHECKED##sg(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, WI_rows, MATH_DIV(K, MATH_DIV(contrib_bitwidth, elem_bitwidth)))
10521063
10531064// set block_opt to false to disable block non-continous optimization per one built-in as a workaround
10541065#define DEFINE_STORE_VECTORS_IMPL (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , order , us , WI_rows , block_opt , address_space ) \
@@ -1311,7 +1322,7 @@ DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, int, 2, 16, COL_MAJOR, , 2, tr
13111322DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , int , 1 , 16 , COL_MAJOR , , 1 , true)
13121323
13131324/* Acc i32 SG16 for subgroup 32*/
1314- DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , int , 8 , 16 , ROW_MAJOR , , 4 , true)
1325+ DEFINE_STORE_AND_CHECKED (Accumulator_RowMajor , _SG16 , int , int , 8 , 16 , ROW_MAJOR , , 4 , true)
13151326DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , int , 7 , 16 , ROW_MAJOR , , 4 , true)
13161327DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , int , 6 , 16 , ROW_MAJOR , , 3 , true)
13171328DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , int , 5 , 16 , ROW_MAJOR , , 3 , true)
@@ -2288,31 +2299,48 @@ DEFINE_STORE_LARGE_1(Accumulator_RowMajor, 1, 64)
22882299
22892300DEFINE_STORE_CHECKED_LARGE_1 (Accumulator_RowMajor , 1 , 64 )
22902301
2291- #define DEFINE_FILLCHECKED_IMPL (element_type , elem_bitwidth , WI_rows ) \
2292- INLINE void MANGLE_FILLCHECKED_NAME(elem_bitwidth, WI_rows) (__private char *dst, int y, int x, int height, int width, element_type value) { \
2302+ // FillChecked implementation
2303+
2304+ #define MANGLE_FILLCHECKED_NAME (elem_bitwidth , contrib_bitwidth , K , WI_rows ) \
2305+ __builtin_spriv_OpJointMatrixFillCheckedINTEL_i##elem_bitwidth##_i##contrib_bitwidth##_k##K##_wi##WI_rows
2306+
2307+ #define DEFINE_FILLCHECKED_IMPL (element_type , elem_bitwidth , contrib_bitwidth , K , WI_rows ) \
2308+ INLINE void MANGLE_FILLCHECKED_NAME(elem_bitwidth, contrib_bitwidth, K, WI_rows) (__private char *dst, int y, int x, int height, int width, element_type value) { \
22932309 int slid = get_sub_group_local_id(); \
2310+ int sg_size = get_sub_group_size(); \
2311+ int pack_factor = contrib_bitwidth / elem_bitwidth; \
2312+ int col_sg_ratio = (sg_size * pack_factor) / K; \
22942313 __private element_type *wi_contrib = (__private element_type *) dst; \
22952314 for (int i = 0; i < WI_rows; i++) { \
2296- element_type fill_value = (( slid < width - x) && (i < height - y)) ? value : 0; \
2315+ element_type fill_value = slid % K < width - x && i * col_sg_ratio < height - y ? value : 0; \
22972316 wi_contrib[i] = fill_value; \
22982317 } \
22992318}
23002319
2301- #define DEFINE_FILLCHECKED__ (element_type , elem_bitwidth , WI_rows ) \
2302- DEFINE_FILLCHECKED_IMPL(element_type, elem_bitwidth, WI_rows)
2320+ #define DEFINE_FILLCHECKED__ (element_type , elem_bitwidth , contrib_bitwidth , K , WI_rows ) \
2321+ DEFINE_FILLCHECKED_IMPL(element_type, elem_bitwidth, contrib_bitwidth, K, WI_rows)
2322+
2323+ #define DEFINE_FILLCHECKED (element_type , contrib_type , K , WI_rows ) \
2324+ DEFINE_FILLCHECKED__(element_type, BITWIDTH(element_type), BITWIDTH(contrib_type), K, WI_rows)
2325+
2326+ #define DEFINE_FILLCHECKED_K (element_type , contrib_type , K ) \
2327+ DEFINE_FILLCHECKED(element_type, contrib_type, K, 1) \
2328+ DEFINE_FILLCHECKED(element_type, contrib_type, K, 2) \
2329+ DEFINE_FILLCHECKED(element_type, contrib_type, K, 4) \
2330+ DEFINE_FILLCHECKED(element_type, contrib_type, K, 8) \
2331+ DEFINE_FILLCHECKED(element_type, contrib_type, K, 16) \
2332+ DEFINE_FILLCHECKED(element_type, contrib_type, K, 32) \
2333+ DEFINE_FILLCHECKED(element_type, contrib_type, K, 64) \
2334+ DEFINE_FILLCHECKED(element_type, contrib_type, K, 128)
23032335
2304- #define DEFINE_FILLCHECKED (element_type , WI_rows ) \
2305- DEFINE_FILLCHECKED__(element_type, BITWIDTH(element_type), WI_rows)
2336+ #define DEFINE_FILLCHECKED_CONTRIB (element_type , contrib_type ) \
2337+ DEFINE_FILLCHECKED_K(element_type, contrib_type, 8) \
2338+ DEFINE_FILLCHECKED_K(element_type, contrib_type, 16) \
2339+ DEFINE_FILLCHECKED_K(element_type, contrib_type, 32)
23062340
23072341#define DEFINE_FILLCHECKED_GROUP (element_type ) \
2308- DEFINE_FILLCHECKED(element_type, 1) \
2309- DEFINE_FILLCHECKED(element_type, 2) \
2310- DEFINE_FILLCHECKED(element_type, 4) \
2311- DEFINE_FILLCHECKED(element_type, 8) \
2312- DEFINE_FILLCHECKED(element_type, 16) \
2313- DEFINE_FILLCHECKED(element_type, 32) \
2314- DEFINE_FILLCHECKED(element_type, 64) \
2315- DEFINE_FILLCHECKED(element_type, 128)
2342+ DEFINE_FILLCHECKED_CONTRIB(element_type, short) \
2343+ DEFINE_FILLCHECKED_CONTRIB(element_type, int)
23162344
23172345DEFINE_FILLCHECKED_GROUP (char )
23182346DEFINE_FILLCHECKED_GROUP (short )
0 commit comments