38
38
#include < cstdint>
39
39
#include < optional>
40
40
41
+ #include " mlir/Dialect/MemRef/Transforms/Transforms.h"
42
+
41
43
using namespace mlir ;
42
44
43
45
#define DEBUG_TYPE " vector-narrow-type-emulation"
@@ -556,7 +558,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
556
558
matchAndRewrite (vector::StoreOp op, OpAdaptor adaptor,
557
559
ConversionPatternRewriter &rewriter) const override {
558
560
559
- // See #115653
560
561
if (op.getValueToStore ().getType ().getRank () != 1 )
561
562
return rewriter.notifyMatchFailure (op,
562
563
" only 1-D vectors are supported ATM" );
@@ -817,7 +818,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
817
818
// ConvertVectorMaskedStore
818
819
// ===----------------------------------------------------------------------===//
819
820
820
- // TODO: Document-me
821
+ // / Converts `vector.maskedstore` operations on narrow element types to work
822
+ // / with wider, byte-aligned container types by adjusting the mask and using
823
+ // / bitcasting.
824
+ // /
825
+ // / Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
826
+ // / (each `i8` container element holds two `i4` values) and storing with an
827
+ // / adjusted mask .
821
828
struct ConvertVectorMaskedStore final
822
829
: OpConversionPattern<vector::MaskedStoreOp> {
823
830
using OpConversionPattern::OpConversionPattern;
@@ -826,10 +833,10 @@ struct ConvertVectorMaskedStore final
826
833
matchAndRewrite (vector::MaskedStoreOp op, OpAdaptor adaptor,
827
834
ConversionPatternRewriter &rewriter) const override {
828
835
829
- // See #115653
836
+ // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
830
837
if (op.getValueToStore ().getType ().getRank () != 1 )
831
- return rewriter.notifyMatchFailure (op,
832
- " only 1-D vectors are supported ATM " );
838
+ return rewriter.notifyMatchFailure (
839
+ op, " Memref in vector.maskedstore op must be flattened beforehand. " );
833
840
834
841
auto loc = op.getLoc ();
835
842
auto containerElemTy =
@@ -931,18 +938,27 @@ struct ConvertVectorMaskedStore final
931
938
// ConvertVectorLoad
932
939
// ===----------------------------------------------------------------------===//
933
940
934
- // TODO: Document-me
941
+ // / Converts `vector.load` on narrow element types to work with
942
+ // / wider, byte-aligned container types by adjusting load sizes and using
943
+ // / bitcasting.
944
+ // /
945
+ // / Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
946
+ // / by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
947
+ // / container holds two `i4` values) and bitcasting back.
948
+ // /
949
+ // / There are cases where the number of elements to load is not byte-aligned. In
950
+ // / those cases, loads are converted to byte-aligned, byte-sized loads and the
951
+ // / target vector is extracted from the loaded vector.
935
952
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
936
953
using OpConversionPattern::OpConversionPattern;
937
954
938
955
LogicalResult
939
956
matchAndRewrite (vector::LoadOp op, OpAdaptor adaptor,
940
957
ConversionPatternRewriter &rewriter) const override {
941
-
942
- // See #115653
958
+ // Prerequisite: memref in the vector.load op is flattened into 1-D.
943
959
if (op.getVectorType ().getRank () != 1 )
944
- return rewriter.notifyMatchFailure (op,
945
- " only 1-D vectors are supported ATM " );
960
+ return rewriter.notifyMatchFailure (
961
+ op, " Memref in emulated vector ops must be flattened beforehand. " );
946
962
947
963
auto loc = op.getLoc ();
948
964
auto containerElemTy =
@@ -961,8 +977,6 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
961
977
962
978
// Adjust the number of elements to load when emulating narrow types,
963
979
// and then cast back to the original type with vector.bitcast op.
964
- // Here only the 1-D vector load is considered, and the N-D memref types
965
- // should be linearized.
966
980
// For example, to emulate i4 to i8, the following op:
967
981
//
968
982
// %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
@@ -1037,18 +1051,22 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
1037
1051
// ConvertVectorMaskedLoad
1038
1052
// ===----------------------------------------------------------------------===//
1039
1053
1040
- // TODO: Document-me
1054
+ // / Converts `vector.maskedload` operations on narrow element types to work with
1055
+ // / wider, byte-aligned container types by adjusting the mask and using
1056
+ // / bitcasting.
1057
+ // /
1058
+ // / Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
1059
+ // / bitcasting, since each `i8` container element holds two `i4` values.
1041
1060
struct ConvertVectorMaskedLoad final
1042
1061
: OpConversionPattern<vector::MaskedLoadOp> {
1043
1062
using OpConversionPattern::OpConversionPattern;
1044
1063
1045
1064
LogicalResult
1046
1065
matchAndRewrite (vector::MaskedLoadOp op, OpAdaptor adaptor,
1047
1066
ConversionPatternRewriter &rewriter) const override {
1048
- // See #115653
1049
1067
if (op.getVectorType ().getRank () != 1 )
1050
- return rewriter.notifyMatchFailure (op,
1051
- " only 1-D vectors are supported ATM " );
1068
+ return rewriter.notifyMatchFailure (
1069
+ op, " Memref in emulated vector ops must be flattened beforehand. " );
1052
1070
1053
1071
auto loc = op.getLoc ();
1054
1072
@@ -1229,7 +1247,6 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
1229
1247
1230
1248
int elemsPerMultiByte = multiByteBits / subByteBits;
1231
1249
1232
- // TODO: This is a bit too restrictive for vectors rank > 1.
1233
1250
return subByteVecTy.getShape ().back () % elemsPerMultiByte == 0 ;
1234
1251
}
1235
1252
@@ -1246,10 +1263,11 @@ struct ConvertVectorTransferRead final
1246
1263
matchAndRewrite (vector::TransferReadOp op, OpAdaptor adaptor,
1247
1264
ConversionPatternRewriter &rewriter) const override {
1248
1265
1249
- // See #115653
1266
+ // Prerequisites: memref in the vector.transfer_read op is flattened into
1267
+ // 1-D.
1250
1268
if (op.getVectorType ().getRank () != 1 )
1251
- return rewriter.notifyMatchFailure (op,
1252
- " only 1-D vectors are supported ATM " );
1269
+ return rewriter.notifyMatchFailure (
1270
+ op, " Memref in emulated vector ops must be flattened beforehand. " );
1253
1271
1254
1272
auto loc = op.getLoc ();
1255
1273
auto containerElemTy =
@@ -2227,7 +2245,6 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
2227
2245
void vector::populateVectorNarrowTypeEmulationPatterns (
2228
2246
const arith::NarrowTypeEmulationConverter &typeConverter,
2229
2247
RewritePatternSet &patterns, bool disableAtomicRMW) {
2230
-
2231
2248
// Populate `vector.*` conversion patterns.
2232
2249
// TODO: #119553 support atomicity
2233
2250
patterns.add <ConvertVectorLoad, ConvertVectorMaskedLoad,
@@ -2266,3 +2283,10 @@ void vector::populateVectorTransposeNarrowTypeRewritePatterns(
2266
2283
RewritePatternSet &patterns, PatternBenefit benefit) {
2267
2284
patterns.add <RewriteVectorTranspose>(patterns.getContext (), benefit);
2268
2285
}
2286
+
2287
+ void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns (
2288
+ arith::NarrowTypeEmulationConverter &typeConverter,
2289
+ RewritePatternSet &patterns) {
2290
+ memref::populateFlattenVectorOpsOnMemrefPatterns (patterns);
2291
+ vector::populateVectorNarrowTypeEmulationPatterns (typeConverter, patterns);
2292
+ }
0 commit comments