Skip to content

Commit 284040c

Browse files
YuriPlyakhin authored and igcbot committed
SYCL Joint Matrix transpose A/B store 8bit, 16bit
SYCL Joint Matrix transpose A/B store 8bit, 16bit
1 parent d625c19 commit 284040c

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl

Lines changed: 28 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -922,6 +922,7 @@ DEFINE_LOAD_AND_CHECKED(PackedB_RowMajor, _SG16, char, int, 8, 64, VNNI_TX,
922922
/* PackedB load i8 SG16 for sub group size 32*/
923923
DEFINE_LOAD(PackedB_ColumnMajor, _SG16, char, int, 8, 64, COL_MAJOR, , 4)
924924
DEFINE_LOAD(PackedB_PackedB, _SG16, char, int, 8, 64, ROW_MAJOR, , 4)
925+
DEFINE_LOAD(PackedB_RowMajor, _SG16, char, int, 8, 64, VNNI_TX, , 4)
925926

926927
/* B load tf32 SG16 */
927928
DEFINE_LOAD_AND_CHECKED(PackedB_RowMajor, _SG16, int, int, 8, 16, ROW_MAJOR, , 8)
@@ -1041,15 +1042,21 @@ DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, int, 2, 16, COL_MAJOR, , 1)
10411042
return; \
10421043
}
10431044

1044-
#define DEFINE_STORE_SCALAR_IMPL(element_type, contrib_type, M, K, order, WI_rows) \
1045-
contrib_type *ptr = (contrib_type *)mem; \
1045+
#define DEFINE_STORE_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, order, WI_rows) \
10461046
int slid = get_sub_group_local_id(); \
10471047
int pack_factor = sizeof (contrib_type) / sizeof (element_type); \
1048-
stride = stride / pack_factor; \
1048+
int elem_num = (M * K) / sg_size; \
10491049
int sg_cols = K / pack_factor; \
1050+
int skip_factor = sg_size / sg_cols; \
1051+
if (_##layout == _PackedA_ColumnMajor && elem_bitwidth == 8 && contrib_bitwidth == 16) { \
1052+
for (int i = 0; i < elem_num; i++) \
1053+
mem[(i % pack_factor) * stride + ((slid * pack_factor) % K) * stride + (i / pack_factor) * skip_factor + (slid * pack_factor) / K] = src[i]; \
1054+
return; \
1055+
} \
1056+
contrib_type *ptr = (contrib_type *)mem; \
1057+
stride = stride / pack_factor; \
10501058
__private contrib_type *slice = (__private contrib_type *)src; \
10511059
if(sg_size >= sg_cols) { \
1052-
int skip_factor = sg_size / sg_cols; \
10531060
for (int i = 0; i < WI_rows; i++) { \
10541061
if ( (i*skip_factor + slid/sg_cols) < M ) \
10551062
ptr[IND##order(slid, stride, skip_factor, i, sg_cols)] = slice[i]; \
@@ -1074,22 +1081,22 @@ DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, int, 2, 16, COL_MAJOR, , 1)
10741081
} else { \
10751082
DEFINE_STORE_VECTORS_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, us, WI_rows, block_opt, AS_LOCAL) \
10761083
} \
1077-
DEFINE_STORE_SCALAR_IMPL(element_type, contrib_type, M, K, _##order, WI_rows) \
1084+
DEFINE_STORE_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
10781085
}
10791086

10801087
#define DEFINE_STORE_IMPL_AS_LOCAL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, block_opt) \
10811088
INLINE void MANGLE_STORE_NAME(layout, sg, elem_bitwidth, shape, WI_rows, local) (char *mem, __private char *src, long stride, int cacheOpt) { \
10821089
int sg_size = get_sub_group_size(); \
10831090
DEFINE_STORE_VECTORS_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, us, WI_rows, block_opt, AS_LOCAL) \
1084-
DEFINE_STORE_SCALAR_IMPL(element_type, contrib_type, M, K, _##order, WI_rows) \
1091+
DEFINE_STORE_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
10851092
}
10861093

10871094
#define DEFINE_STORE_IMPL_AS_GLOBAL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, block_opt) \
10881095
INLINE void MANGLE_STORE_NAME(layout, sg, elem_bitwidth, shape, WI_rows, global) (char *mem, __private char *src, long stride, int cacheOpt) { \
10891096
int sg_size = get_sub_group_size(); \
10901097
DEFINE_STORE_BLOCK2D_IMPL(sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
10911098
DEFINE_STORE_VECTORS_IMPL(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, us, WI_rows, block_opt, AS_GLOBAL) \
1092-
DEFINE_STORE_SCALAR_IMPL(element_type, contrib_type, M, K, _##order, WI_rows) \
1099+
DEFINE_STORE_SCALAR_IMPL(layout, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, _##order, WI_rows) \
10931100
}
10941101

10951102
#define DEFINE_STORE_CHECKED_IMPL(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, WI_rows, block_opt) \
@@ -1149,6 +1156,9 @@ DEFINE_STORE(PackedA_RowMajor, _SG16, char, short, 6, 32, ROW_MAJOR, _us, 6, fa
11491156
DEFINE_STORE(PackedA_RowMajor, _SG16, char, short, 7, 32, ROW_MAJOR, _us, 7, false)
11501157
DEFINE_STORE_AND_CHECKED(PackedA_RowMajor, _SG16, char, short, 8, 32, ROW_MAJOR, _us, 8, false)
11511158

1159+
/* PackedA store i8 SG16 Col Major*/
1160+
DEFINE_STORE(PackedA_ColumnMajor, _SG16, char, short, 8, 32, COL_MAJOR, _us, 8, true)
1161+
11521162
/* PackedA store i8 SG16 for subgroup 32*/
11531163
// DEFINE_STORE(PackedA_RowMajor, _SG16, char, short, 1, 32, ROW_MAJOR, _us, 1, false) same as for subgroup 16
11541164
DEFINE_STORE(PackedA_RowMajor, _SG16, char, short, 2, 32, ROW_MAJOR, _us, 1, false)
@@ -1159,6 +1169,9 @@ DEFINE_STORE(PackedA_RowMajor, _SG16, char, short, 6, 32, ROW_MAJOR, _us, 3, fa
11591169
DEFINE_STORE(PackedA_RowMajor, _SG16, char, short, 7, 32, ROW_MAJOR, _us, 4, false)
11601170
DEFINE_STORE(PackedA_RowMajor, _SG16, char, short, 8, 32, ROW_MAJOR, _us, 4, false)
11611171

1172+
/* PackedA store i8 SG16 Col Major for sg 32*/
1173+
DEFINE_STORE(PackedA_ColumnMajor, _SG16, char, short, 8, 32, COL_MAJOR, _us, 4, true)
1174+
11621175
/* PackedA store i16 SG16 */
11631176
DEFINE_STORE_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 16, ROW_MAJOR, _us, 1, false)
11641177
DEFINE_STORE_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 2, 16, ROW_MAJOR, _us, 2, false)
@@ -1171,6 +1184,12 @@ DEFINE_STORE_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 8, 16, ROW_MAJOR
11711184

11721185
DEFINE_STORE_AND_CHECKED(PackedA_RowMajor, _SG16, short, short, 1, 32, ROW_MAJOR, _us, 2, true)
11731186

1187+
/* PackedA store i16 SG16 Col Major*/
1188+
DEFINE_STORE(PackedA_ColumnMajor, _SG16, short, short, 8, 16, COL_MAJOR, _us, 8, true)
1189+
1190+
/* PackedA store i16 SG16 Col Major for sg size 32*/
1191+
DEFINE_STORE(PackedA_ColumnMajor, _SG16, short, short, 8, 16, COL_MAJOR, _us, 4, true)
1192+
11741193
/* PackedA store i16 SG16 for sub group size 32 */
11751194
// DEFINE_STORE(PackedA_RowMajor, _SG16, short, short, 1, 16, ROW_MAJOR, _us, 1, false) same as for subgroup 16
11761195
DEFINE_STORE(PackedA_RowMajor, _SG16, short, short, 2, 16, ROW_MAJOR, _us, 1, false)
@@ -1197,6 +1216,7 @@ DEFINE_STORE_AND_CHECKED(PackedB_PackedB, _SG16, short, int, 8, 32, ROW_MAJO
11971216
DEFINE_STORE(PackedB_RowMajor, _SG16, short, int, 8, 32, VNNI_TX, , 8, true)
11981217

11991218
/* PackedB store i16 SG16 for subgroup 32*/
1219+
DEFINE_STORE(PackedB_ColumnMajor, _SG16, short, int, 8, 32, COL_MAJOR, , 4, false)
12001220
DEFINE_STORE(PackedB_PackedB, _SG16, short, int, 8, 32, ROW_MAJOR, , 4, true)
12011221

12021222
// TODO: investigate why intel_sub_group_block_write causes an assertion and enable blocked non-continuous optimization
@@ -1209,6 +1229,7 @@ DEFINE_STORE(PackedB_ColumnMajor, _SG16, char, int, 8, 64, COL_MAJOR, , 8, false
12091229
DEFINE_STORE_AND_CHECKED(PackedB_PackedB, _SG16, char, int, 8, 64, ROW_MAJOR, , 8, false)
12101230

12111231
/* PackedB store i8 SG16 for subgroup 32*/
1232+
DEFINE_STORE(PackedB_ColumnMajor, _SG16, char, int, 8, 64, COL_MAJOR, , 4, false)
12121233
DEFINE_STORE(PackedB_PackedB, _SG16, char, int, 8, 64, ROW_MAJOR, , 4, true)
12131234

12141235
/* B store tf32 SG16 */

0 commit comments

Comments
 (0)