@@ -152,7 +152,7 @@ extern __constant int __JointMatrixLoadStoreOpt;
152152 int width = (sizeof (element_type )) * stride - 1 ; /* in bytes */ \
153153 int pitch = width ; /* JointMatrices are expected to be contigunous in memory, without padding at the end of a row */ \
154154 int height = M - 1 ; /* row count */ \
155- long x = (offset - baseoffset ) / (sizeof (element_type )); /* in elements */ \
155+ long x = (offset - baseoffset ) / (sizeof (contrib_type )); /* in elements */ \
156156 int2 coords = (int2 )(x , 0 ); \
157157 OUT_VEC ##M (u##contrib_type) DEFINE_BLOCK2D_RW_NAME(read, contrib_bitwidth, M, K)(long, int, int, int, int2); \
158158 OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_RW_NAME(read, contrib_bitwidth, M, K)(baseoffset, width, height, pitch, coords); \
@@ -164,7 +164,7 @@ extern __constant int __JointMatrixLoadStoreOpt;
164164 int width = (sizeof (element_type )) * stride - 1 ; /* in bytes */ \
165165 int pitch = width ; /* JointMatrices are expected to be contigunous in memory, without padding at the end of a row */ \
166166 int height = M - 1 ; /* row count */ \
167- long x = (offset - baseoffset ) / (sizeof (element_type )); /* in elements */ \
167+ long x = (offset - baseoffset ) / (sizeof (contrib_type )); /* in elements */ \
168168 int2 coords = (int2 )(x , 0 ); \
169169 void DEFINE_BLOCK2D_RW_NAME (write , contrib_bitwidth , M , K )(long , int , int , int , int2 , OUT_VEC ##M (u##contrib_type)); \
170170 OUT_VEC##M(u##contrib_type) val = VEC_TO_VEC##M(u##contrib_type, vec); \
@@ -304,7 +304,7 @@ DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, ROW_MAJO
304304// set block_opt to false to disable block non-continous optimization per one built-in as a workaround
305305#define DEFINE_STORE (layout , sg , element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , shape , order , us , stride_opt , block_opt ) \
306306 INLINE void MANGLE_STORE_NAME(layout, sg, elem_bitwidth, shape) (char *mem, OUT_VEC##M(contrib_type) vec, int stride) { \
307- if (__JointMatrixLoadStoreOpt >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) && order == ROW_MAJOR) { \
307+ if (__JointMatrixLoadStoreOpt >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) && order == ROW_MAJOR && elem_bitwidth > 8 ) { \
308308 IMPLEMENT_BLOCK2D_STORE##sg(element_type, contrib_type, contrib_bitwidth, M, K, vec) \
309309 } \
310310 if (__JointMatrixLoadStoreOpt >= VECTOR_CONT_IMPL && stride == stride_opt \
@@ -330,22 +330,19 @@ DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, ROW_MAJO
330330 }
331331
332332// TODO: investigate why intel_sub_group_block_write causes an assertion and enable blocked non-continuous optimization
333- DEFINE_STORE (PackedA_RowMajor , , char , 8 , int , 32 , 8 , 32 , 8 x32 , ROW_MAJOR , , 32 , false)
334-
335- // TODO: investigate why intel_sub_group_block_write causes an assertion and enable blocked non-continuous optimization
336- DEFINE_STORE (PackedA_RowMajor , , short , 16 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 16 , false)
337-
338- // TODO: investigate why intel_sub_group_block_write_us causes an assertion and enable blocked non-continuous optimization
339- DEFINE_STORE (PackedA_RowMajor , _SG16 , char , 8 , short , 16 , 8 , 32 , 8 x32 , ROW_MAJOR , _us , 32 , false)
340-
341- // TODO: investigate why intel_sub_group_block_write_us causes an assertion and enable blocked non-continuous optimization
333+ DEFINE_STORE (PackedA_RowMajor , , char , 8 , int , 32 , 8 , 32 , 8 x32 , ROW_MAJOR , , 32 , false)
334+ DEFINE_STORE (PackedA_RowMajor , , short , 16 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 16 , false)
335+ DEFINE_STORE (PackedA_RowMajor , _SG16 , char , 8 , short , 16 , 8 , 32 , 8 x32 , ROW_MAJOR , _us , 32 , false)
342336DEFINE_STORE (PackedA_RowMajor , _SG16 , short , 16 , short , 16 , 8 , 16 , 8 x16 , ROW_MAJOR , _us , 16 , false)
343337
344- DEFINE_STORE (PackedB_PackedB , , short , 16 , int , 32 , 8 , 16 , 16 x8 , ROW_MAJOR , , 16 , true)
345- DEFINE_STORE (PackedB_PackedB , , short , 16 , int , 32 , 8 , 16 , 16 x16 , ROW_MAJOR , , 32 , true)
338+ DEFINE_STORE (PackedB_PackedB , , short , 16 , int , 32 , 8 , 16 , 16 x8 , ROW_MAJOR , , 16 , true)
339+ DEFINE_STORE (PackedB_PackedB , , short , 16 , int , 32 , 8 , 16 , 16 x16 , ROW_MAJOR , , 32 , true)
340+ DEFINE_STORE (PackedB_PackedB , _SG16 , short , 16 , int , 32 , 8 , 16 , 16 x8 , ROW_MAJOR , , 16 , true)
341+ DEFINE_STORE (PackedB_PackedB , _SG16 , short , 16 , int , 32 , 8 , 16 , 16 x16 , ROW_MAJOR , , 32 , true)
346342
347343// TODO: investigate why intel_sub_group_block_write causes an assertion and enable blocked non-continuous optimization
348- DEFINE_STORE (PackedB_PackedB , , char , 8 , int , 32 , 8 , 32 , 32 x8 , ROW_MAJOR , , 16 , false)
344+ DEFINE_STORE (PackedB_PackedB , , char , 8 , int , 32 , 8 , 32 , 32 x8 , ROW_MAJOR , , 16 , false)
345+ DEFINE_STORE (PackedB_PackedB , _SG16 , char , 8 , int , 32 , 8 , 32 , 32 x8 , ROW_MAJOR , , 16 , false)
349346
350347DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 8 , 8 , 8 x8 , ROW_MAJOR , , 8 , true)
351348DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 7 , 8 , 7 x8 , ROW_MAJOR , , 8 , true)
@@ -356,7 +353,7 @@ DEFINE_STORE(Accumulator_RowMajor, , int, 32, int, 32, 3, 8, 3x8, ROW_MAJOR, , 8
356353DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 2 , 8 , 2 x8 , ROW_MAJOR , , 8 , true)
357354DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 1 , 8 , 1 x8 , ROW_MAJOR , , 8 , true)
358355
359- DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 16 , true)
356+ DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 16 , true)
360357
361358DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 16 , true)
362359DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 7 , 16 , 7 x16 , ROW_MAJOR , , 16 , true)
@@ -367,4 +364,5 @@ DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, ROW_MAJ
367364DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 2 , 16 , 2 x16 , ROW_MAJOR , , 16 , true)
368365DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 1 , 16 , 1 x16 , ROW_MAJOR , , 16 , true)
369366
370- DEFINE_STORE (Accumulator_ColumnMajor , , int , 32 , int , 32 , 8 , 8 , 8 x8 , COL_MAJOR , , -1 , false)
367+ DEFINE_STORE (Accumulator_ColumnMajor , , int , 32 , int , 32 , 8 , 8 , 8 x8 , COL_MAJOR , , -1 , false)
368+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 8 , 8 , 8 x8 , COL_MAJOR , , -1 , false)
0 commit comments