@@ -421,7 +421,7 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
421421
422422#define DEFINE_BLOCK2D_RW_NAME (rw , tx , contrib_bitwidth , WI_rows , tile_height , tile_width ) __builtin_IB_subgroup_block_##rw##_flat_cacheopts##tx##_u##contrib_bitwidth##_wi##WI_rows##_m##tile_height##k##tile_width##v1
423423#define DEFINE_BLOCK2D_TRANSPOSE_NAME (contrib_bitwidth , WI_rows , tile_height , tile_width ) __builtin_IB_subgroup_block_read_flat_cacheopts_transpose_u##contrib_bitwidth##_wi##WI_rows##_m##tile_height##_k##tile_width
424- #define DEFINE_BLOCK2D_VNNI_NAME (contrib_bitwidth , tile_height ) __builtin_IB_subgroup_block_read_flat_cacheopts_transform_u##contrib_bitwidth##_k## tile_height // tile_width = sub group size (16)
424+ #define DEFINE_BLOCK2D_VNNI_NAME (contrib_bitwidth , WI_rows , tile_height , tile_width ) __builtin_IB_subgroup_block_read_flat_cacheopts_transform_u##contrib_bitwidth##_wi##WI_rows##_##k## tile_height##n## tile_width
425425
426426/* For platforms without SG16 JointMatrix support block2d is not available. The
427427 * implementation remains empty, will fallthrough to vector implementation. */
@@ -539,8 +539,8 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
539539 int height = orig_M - 1 ; /* row count */ \
540540 long x = (offset - baseoffset ) / (sizeof (element_type )); /* in elements */ \
541541 int2 coords = (int2 )(x , 0 ); \
542- OUT_VEC ##WI_rows (u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
543- OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(baseoffset, width, height, pitch, coords, cacheOpt); \
542+ OUT_VEC ##WI_rows (u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, WI_rows, orig_M, orig_K )(long, int, int, int, int2, int); \
543+ OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, WI_rows, orig_M, orig_K )(baseoffset, width, height, pitch, coords, cacheOpt); \
544544 *(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
545545 return; \
546546 }
@@ -591,8 +591,8 @@ Each subgroup stores 16 of 8x16 slices. Hence, row_stride = R / 4 = 32 / 4 = 8 a
591591 int pitch = sizeof (element_type ) * stride - 1 ; /* in bytes */ \
592592 int height_size = height - 1 ; \
593593 int2 coords = (int2 )(x , y ); \
594- OUT_VEC ##WI_rows (u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(long, int, int, int, int2, int); \
595- OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, orig_M)(offset, width_size, height_size, pitch, coords, cacheOpt); \
594+ OUT_VEC ##WI_rows (u##contrib_type) DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, WI_rows, orig_M, orig_K )(long, int, int, int, int2, int); \
595+ OUT_VEC##WI_rows(u##contrib_type) res = DEFINE_BLOCK2D_VNNI_NAME(elem_bitwidth, WI_rows, orig_M, orig_K )(offset, width_size, height_size, pitch, coords, cacheOpt); \
596596 *(__private OUT_VEC##WI_rows(u##contrib_type) *)dst = res; \
597597 return;
598598
@@ -902,6 +902,7 @@ DEFINE_LOAD_AND_CHECKED(PackedB_RowMajor, _SG16, short, int, 16, 32, VNNI_TX,
902902/* PackedB load i16 SG16 for sub group size = 32*/
903903DEFINE_LOAD (PackedB_ColumnMajor , _SG16 , short , int , 8 , 32 , COL_MAJOR , , 4 )
904904DEFINE_LOAD (PackedB_PackedB , _SG16 , short , int , 8 , 32 , ROW_MAJOR , , 4 )
905+ DEFINE_LOAD (PackedB_RowMajor , _SG16 , short , int , 8 , 32 , VNNI_TX , , 4 )
905906
906907/* PackedB load i8 SG16*/
907908DEFINE_LOAD_AND_CHECKED (PackedB_ColumnMajor , _SG16 , char , int , 8 , 64 , COL_MAJOR , , 8 )
0 commit comments