@@ -840,6 +840,31 @@ class LoadNdPattern : public OpConversionPattern<LoadNdOp> {
840
840
op, " Only global access supported for block load." );
841
841
auto payload = adaptor.getTensorDesc ();
842
842
auto retTy = op.getType ();
843
+ auto bitWidth = elemTy.getIntOrFloatBitWidth ();
844
+ if (bitWidth < 8 ) {
845
+ if (8 % bitWidth != 0 )
846
+ return rewriter.notifyMatchFailure (
847
+ op, " Only sub byte types with bit-width 1, 2, or 4 are "
848
+ " supported for block load." );
849
+ auto subByteFactor = 8 / bitWidth;
850
+ // For supported sub byte type,
851
+ // fake element type to i8 and update elemTy, retTy and tdescTy
852
+ // accordingly. Add cast before and after intrinsic call to ensure the
853
+ // type matches the original type.
854
+ elemTy = rewriter.getI8Type ();
855
+ auto shape = tdescTy.getShape ().vec ();
856
+ auto lastDim = shape.size () - 1 ;
857
+ if (shape[lastDim] % subByteFactor != 0 ) {
858
+ return rewriter.notifyMatchFailure (
859
op, " The last dimension must be a multiple of (8 / bitWidth) for "
860
+ " sub byte types." );
861
+ }
862
+ shape[lastDim] = shape[lastDim] / subByteFactor;
863
+ tdescTy = TensorDescType::get (tdescTy.getContext (), shape, elemTy,
864
+ tdescTy.getEncoding (),
865
+ /* sg_map*/ nullptr );
866
+ retTy = VectorType::get (tdescTy.getShape (), elemTy);
867
+ }
843
868
844
869
// TODO: remove this after moving transposeBitWidth into a standalone
845
870
// pass. update the width and pitch of the payload when transposeBitWidth
@@ -908,6 +933,8 @@ class LoadNdPattern : public OpConversionPattern<LoadNdOp> {
908
933
909
934
// TODO: remove this after moving transposeBitWidth into a standalone
910
935
// pass.
936
+ // NOTE: sub byte type handling also needs the bitcast to the original
937
+ // type after the intrinsic call.
911
938
if (retTy != op.getType ()) {
912
939
auto targetTy = convertVectorType (op.getType ()).second ;
913
940
callOp = rewriter.create <vector::BitCastOp>(loc, targetTy, callOp);
@@ -959,6 +986,31 @@ class PrefetchNdPattern : public OpConversionPattern<PrefetchNdOp> {
959
986
if (scope != xegpu::MemorySpace::Global)
960
987
return rewriter.notifyMatchFailure (
961
988
op, " Only global access supported for block prefetch." );
989
+ auto elemTy = tdescTy.getElementType ();
990
+ auto bitWidth = elemTy.getIntOrFloatBitWidth ();
991
+ if (bitWidth < 8 ) {
992
+ if (8 % bitWidth != 0 )
993
+ return rewriter.notifyMatchFailure (
994
+ op, " Only sub byte types with bit-width 1, 2, or 4 are "
995
+ " supported for block prefetch." );
996
+ auto subByteFactor = 8 / bitWidth;
997
+ // For supported sub byte type,
998
+ // fake element type to i8 and update elemTy, retTy and tdescTy
999
+ // accordingly. Add cast before and after intrinsic call to ensure the
1000
+ // type matches the original type.
1001
+ elemTy = rewriter.getI8Type ();
1002
+ auto shape = tdescTy.getShape ().vec ();
1003
+ auto lastDim = shape.size () - 1 ;
1004
+ if (shape[lastDim] % subByteFactor != 0 ) {
1005
+ return rewriter.notifyMatchFailure (
1006
+ op, " The last dimension must be a multiple of (8 / bitWidth) for "
1007
+ " sub byte types." );
1008
+ }
1009
+ shape[lastDim] = shape[lastDim] / subByteFactor;
1010
+ tdescTy = TensorDescType::get (tdescTy.getContext (), shape, elemTy,
1011
+ tdescTy.getEncoding (),
1012
+ /* sg_map*/ nullptr );
1013
+ }
962
1014
auto callOp = gen2DPrefetchIntrinsicCall (
963
1015
rewriter, loc, l1hint, l3hint, tdescTy, adaptor.getTensorDesc ());
964
1016
rewriter.replaceOp (op, callOp);
@@ -1010,7 +1062,33 @@ class StoreNdPattern : public OpConversionPattern<StoreNdOp> {
1010
1062
if (scope != xegpu::MemorySpace::Global)
1011
1063
return rewriter.notifyMatchFailure (
1012
1064
op, " Only global access supported for block store." );
1013
-
1065
+ auto elemTy = tdescTy.getElementType ();
1066
+ auto bitWidth = elemTy.getIntOrFloatBitWidth ();
1067
+ if (bitWidth < 8 ) {
1068
+ if (8 % bitWidth != 0 )
1069
+ return rewriter.notifyMatchFailure (
1070
+ op, " Only sub byte types with bit-width 1, 2, or 4 are "
1071
+ " supported for block store." );
1072
+ auto subByteFactor = 8 / bitWidth;
1073
+ // For supported sub byte type,
1074
+ // fake element type to i8 and update elemTy, retTy and tdescTy
1075
+ // accordingly. Add cast before and after intrinsic call to ensure the
1076
+ // type matches the original type.
1077
+ elemTy = rewriter.getI8Type ();
1078
+ auto shape = tdescTy.getShape ().vec ();
1079
+ auto lastDim = shape.size () - 1 ;
1080
+ if (shape[lastDim] % subByteFactor != 0 ) {
1081
+ return rewriter.notifyMatchFailure (
1082
+ op, " The last dimension must be a multiple of (8 / bitWidth) for "
1083
+ " sub byte types." );
1084
+ }
1085
+ shape[lastDim] = shape[lastDim] / subByteFactor;
1086
+ tdescTy = TensorDescType::get (tdescTy.getContext (), shape, elemTy,
1087
+ tdescTy.getEncoding (),
1088
+ /* sg_map*/ nullptr );
1089
+ auto dataTy = VectorType::get ({tdescTy.getNumElements ()}, elemTy);
1090
+ data = rewriter.create <vector::BitCastOp>(loc, dataTy, data);
1091
+ }
1014
1092
auto callOp =
1015
1093
gen2DStoreIntrinsicCall (rewriter, loc, l1hint, l3hint, tdescTy,
1016
1094
adaptor.getTensorDesc (), data);
0 commit comments