@@ -922,6 +922,7 @@ DEFINE_LOAD_AND_CHECKED(PackedB_RowMajor, _SG16, char, int, 8, 64, VNNI_TX,
922922/* PackedB load i8 SG16 for sub group size 32*/
923923DEFINE_LOAD (PackedB_ColumnMajor , _SG16 , char , int , 8 , 64 , COL_MAJOR , , 4 )
924924DEFINE_LOAD (PackedB_PackedB , _SG16 , char , int , 8 , 64 , ROW_MAJOR , , 4 )
925+ DEFINE_LOAD (PackedB_RowMajor , _SG16 , char , int , 8 , 64 , VNNI_TX , , 4 )
925926
926927/* B load tf32 SG16 */
927928DEFINE_LOAD_AND_CHECKED (PackedB_RowMajor , _SG16 , int , int , 8 , 16 , ROW_MAJOR , , 8 )
@@ -1041,15 +1042,21 @@ DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, int, 2, 16, COL_MAJOR, , 1)
10411042 return; \
10421043 }
10431044
1044- #define DEFINE_STORE_SCALAR_IMPL (element_type , contrib_type , M , K , order , WI_rows ) \
1045- contrib_type *ptr = (contrib_type *)mem; \
1045+ #define DEFINE_STORE_SCALAR_IMPL (layout , element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , order , WI_rows ) \
10461046 int slid = get_sub_group_local_id(); \
10471047 int pack_factor = sizeof (contrib_type) / sizeof (element_type); \
1048- stride = stride / pack_factor ; \
1048+ int elem_num = (M * K) / sg_size ; \
10491049 int sg_cols = K / pack_factor; \
1050+ int skip_factor = sg_size / sg_cols; \
1051+ if (_##layout == _PackedA_ColumnMajor && elem_bitwidth == 8 && contrib_bitwidth == 16) { \
1052+ for (int i = 0; i < elem_num; i++) \
1053+ mem[(i % pack_factor) * stride + ((slid * pack_factor) % K) * stride + (i / pack_factor) * skip_factor + (slid * pack_factor) / K] = src[i]; \
1054+ return; \
1055+ } \
1056+ contrib_type *ptr = (contrib_type *)mem; \
1057+ stride = stride / pack_factor; \
10501058 __private contrib_type *slice = (__private contrib_type *)src; \
10511059 if(sg_size >= sg_cols) { \
1052- int skip_factor = sg_size / sg_cols; \
10531060 for (int i = 0; i < WI_rows; i++) { \
10541061 if ( (i*skip_factor + slid/sg_cols) < M ) \
10551062 ptr[IND##order(slid, stride, skip_factor, i, sg_cols)] = slice[i]; \
@@ -1074,22 +1081,22 @@ DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, int, 2, 16, COL_MAJOR, , 1)
10741081 } else { \
10751082 DEFINE_STORE_VECTORS_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, us, WI_rows, block_opt, AS_LOCAL) \
10761083 } \
1077- DEFINE_STORE_SCALAR_IMPL(element_type, contrib_type, M, K, _##order, WI_rows) \
1084+ DEFINE_STORE_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth , M, K, _##order, WI_rows) \
10781085 }
10791086
/* DEFINE_STORE_IMPL_AS_LOCAL: expands to the definition of one INLINE void store
 * function whose name is produced by
 * MANGLE_STORE_NAME(layout, sg, elem_bitwidth, shape, WI_rows, local).
 *
 * Generated function parameters:
 *   mem      - destination buffer, passed as char * and forwarded to the impl macros.
 *   src      - __private source buffer holding this work-item's data.
 *   stride   - stride of the destination; the scalar impl uses it for indexing.
 *   cacheOpt - not referenced in this macro body itself; presumably consumed by
 *              one of the expanded impl macros (TODO confirm).
 *
 * The body reads the sub-group size into sg_size, which the expanded impl macros
 * refer to by name, then expands DEFINE_STORE_VECTORS_IMPL (with AS_LOCAL as its
 * last argument) followed by DEFINE_STORE_SCALAR_IMPL.
 * NOTE(review): presumably the vector impl returns early when it applies, so the
 * scalar impl acts as the fallback - confirm against DEFINE_STORE_VECTORS_IMPL.
 *
 * This listing is a diff: the '-' line is the previous DEFINE_STORE_SCALAR_IMPL
 * invocation and the '+' line is its replacement, which additionally forwards
 * layout, elem_bitwidth and contrib_bitwidth. */
10801087#define DEFINE_STORE_IMPL_AS_LOCAL (layout , sg , element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , shape , order , us , WI_rows , block_opt ) \
10811088 INLINE void MANGLE_STORE_NAME(layout, sg, elem_bitwidth, shape, WI_rows, local) (char *mem, __private char *src, long stride, int cacheOpt) { \
10821089 int sg_size = get_sub_group_size(); \
10831090 DEFINE_STORE_VECTORS_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, us, WI_rows, block_opt, AS_LOCAL) \
1084- DEFINE_STORE_SCALAR_IMPL(element_type, contrib_type, M, K, _##order, WI_rows) \
1091+ DEFINE_STORE_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth , M, K, _##order, WI_rows) \
10851092 }
10861093
/* DEFINE_STORE_IMPL_AS_GLOBAL: expands to the definition of one INLINE void store
 * function whose name is produced by
 * MANGLE_STORE_NAME(layout, sg, elem_bitwidth, shape, WI_rows, global).
 *
 * Generated function parameters:
 *   mem      - destination buffer, passed as char * and forwarded to the impl macros.
 *   src      - __private source buffer holding this work-item's data.
 *   stride   - stride of the destination; the scalar impl uses it for indexing.
 *   cacheOpt - not referenced in this macro body itself; presumably consumed by
 *              one of the expanded impl macros (TODO confirm).
 *
 * The body reads the sub-group size into sg_size, which the expanded impl macros
 * refer to by name, then expands three implementations in this order:
 * DEFINE_STORE_BLOCK2D_IMPL, DEFINE_STORE_VECTORS_IMPL (with AS_GLOBAL as its last
 * argument) and DEFINE_STORE_SCALAR_IMPL.
 * NOTE(review): presumably each earlier impl returns when it handles the store,
 * leaving the later ones as fallbacks - confirm against their definitions.
 *
 * This listing is a diff: the '-' line is the previous DEFINE_STORE_SCALAR_IMPL
 * invocation and the '+' line is its replacement, which additionally forwards
 * layout, elem_bitwidth and contrib_bitwidth. */
10871094#define DEFINE_STORE_IMPL_AS_GLOBAL (layout , sg , element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , shape , order , us , WI_rows , block_opt ) \
10881095 INLINE void MANGLE_STORE_NAME(layout, sg, elem_bitwidth, shape, WI_rows, global) (char *mem, __private char *src, long stride, int cacheOpt) { \
10891096 int sg_size = get_sub_group_size(); \
10901097 DEFINE_STORE_BLOCK2D_IMPL(sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
10911098 DEFINE_STORE_VECTORS_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, us, WI_rows, block_opt, AS_GLOBAL) \
1092- DEFINE_STORE_SCALAR_IMPL(element_type, contrib_type, M, K, _##order, WI_rows) \
1099+ DEFINE_STORE_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth , M, K, _##order, WI_rows) \
10931100 }
10941101
10951102#define DEFINE_STORE_CHECKED_IMPL (layout , sg , element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , shape , order , us , WI_rows , block_opt ) \
@@ -1149,6 +1156,9 @@ DEFINE_STORE(PackedA_RowMajor, _SG16, char, short, 6, 32, ROW_MAJOR, _us, 6, fa
11491156DEFINE_STORE (PackedA_RowMajor , _SG16 , char , short , 7 , 32 , ROW_MAJOR , _us , 7 , false)
11501157DEFINE_STORE_AND_CHECKED (PackedA_RowMajor , _SG16 , char , short , 8 , 32 , ROW_MAJOR , _us , 8 , false)
11511158
1159+ /* PackedA store i8 SG16 Col Major*/
1160+ DEFINE_STORE (PackedA_ColumnMajor , _SG16 , char , short , 8 , 32 , COL_MAJOR , _us , 8 , true)
1161+
11521162/* PackedA store i8 SG16 for subgroup 32*/
11531163// DEFINE_STORE(PackedA_RowMajor, _SG16, char, short, 1, 32, ROW_MAJOR, _us, 1, false) same as for subgroup 16
11541164DEFINE_STORE (PackedA_RowMajor , _SG16 , char , short , 2 , 32 , ROW_MAJOR , _us , 1 , false)
@@ -1159,6 +1169,9 @@ DEFINE_STORE(PackedA_RowMajor, _SG16, char, short, 6, 32, ROW_MAJOR, _us, 3, fa
11591169DEFINE_STORE (PackedA_RowMajor , _SG16 , char , short , 7 , 32 , ROW_MAJOR , _us , 4 , false)
11601170DEFINE_STORE (PackedA_RowMajor , _SG16 , char , short , 8 , 32 , ROW_MAJOR , _us , 4 , false)
11611171
1172+ /* PackedA store i8 SG16 Col Major for sub group size 32 */
1173+ DEFINE_STORE (PackedA_ColumnMajor , _SG16 , char , short , 8 , 32 , COL_MAJOR , _us , 4 , true)
1174+
11621175/* PackedA store i16 SG16 */
11631176DEFINE_STORE_AND_CHECKED (PackedA_RowMajor , _SG16 , short , short , 1 , 16 , ROW_MAJOR , _us , 1 , false)
11641177DEFINE_STORE_AND_CHECKED (PackedA_RowMajor , _SG16 , short , short , 2 , 16 , ROW_MAJOR , _us , 2 , false)
@@ -1171,6 +1184,12 @@ DEFINE_STORE_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 8, 16, ROW_MAJOR
11711184
11721185DEFINE_STORE_AND_CHECKED (PackedA_RowMajor , _SG16 , short , short , 1 , 32 , ROW_MAJOR , _us , 2 , true)
11731186
1187+ /* PackedA store i16 SG16 Col Major*/
1188+ DEFINE_STORE (PackedA_ColumnMajor , _SG16 , short , short , 8 , 16 , COL_MAJOR , _us , 8 , true)
1189+
1190+ /* PackedA store i16 SG16 Col Major for sub group size 32 */
1191+ DEFINE_STORE (PackedA_ColumnMajor , _SG16 , short , short , 8 , 16 , COL_MAJOR , _us , 4 , true)
1192+
11741193/* PackedA store i16 SG16 for sub group size 32 */
11751194// DEFINE_STORE(PackedA_RowMajor, _SG16, short, short, 1, 16, ROW_MAJOR, _us, 1, false) same as for subgroup 16
11761195DEFINE_STORE (PackedA_RowMajor , _SG16 , short , short , 2 , 16 , ROW_MAJOR , _us , 1 , false)
@@ -1197,6 +1216,7 @@ DEFINE_STORE_AND_CHECKED(PackedB_PackedB, _SG16, short, int, 8, 32, ROW_MAJO
11971216DEFINE_STORE (PackedB_RowMajor , _SG16 , short , int , 8 , 32 , VNNI_TX , , 8 , true)
11981217
11991218/* PackedB store i16 SG16 for subgroup 32*/
1219+ DEFINE_STORE (PackedB_ColumnMajor , _SG16 , short , int , 8 , 32 , COL_MAJOR , , 4 , false)
12001220DEFINE_STORE (PackedB_PackedB , _SG16 , short , int , 8 , 32 , ROW_MAJOR , , 4 , true)
12011221
12021222// TODO: investigate why intel_sub_group_block_write causes an assertion and enable blocked non-continuous optimization
@@ -1209,6 +1229,7 @@ DEFINE_STORE(PackedB_ColumnMajor, _SG16, char, int, 8, 64, COL_MAJOR, , 8, false
12091229DEFINE_STORE_AND_CHECKED (PackedB_PackedB , _SG16 , char , int , 8 , 64 , ROW_MAJOR , , 8 , false)
12101230
12111231/* PackedB store i8 SG16 for subgroup 32*/
1232+ DEFINE_STORE (PackedB_ColumnMajor , _SG16 , char , int , 8 , 64 , COL_MAJOR , , 4 , false)
12121233DEFINE_STORE (PackedB_PackedB , _SG16 , char , int , 8 , 64 , ROW_MAJOR , , 4 , true)
12131234
12141235/* B store tf32 SG16 */
0 commit comments