3838#include < cstdint>
3939#include < optional>
4040
41+ #include " mlir/Dialect/MemRef/Transforms/Transforms.h"
42+
4143using namespace mlir ;
4244
4345#define DEBUG_TYPE " vector-narrow-type-emulation"
@@ -556,7 +558,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
556558 matchAndRewrite (vector::StoreOp op, OpAdaptor adaptor,
557559 ConversionPatternRewriter &rewriter) const override {
558560
559- // See #115653
560561 if (op.getValueToStore ().getType ().getRank () != 1 )
561562 return rewriter.notifyMatchFailure (op,
562563 " only 1-D vectors are supported ATM" );
@@ -817,7 +818,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
817818// ConvertVectorMaskedStore
818819// ===----------------------------------------------------------------------===//
819820
820- // TODO: Document-me
821+ /// Converts `vector.maskedstore` operations on narrow element types to work
822+ /// with wider, byte-aligned container types by adjusting the mask and using
823+ /// bitcasting.
824+ ///
825+ /// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
826+ /// (each `i8` container element holds two `i4` values) and storing with an
827+ /// adjusted mask.
821828struct ConvertVectorMaskedStore final
822829 : OpConversionPattern<vector::MaskedStoreOp> {
823830 using OpConversionPattern::OpConversionPattern;
@@ -826,10 +833,10 @@ struct ConvertVectorMaskedStore final
826833 matchAndRewrite (vector::MaskedStoreOp op, OpAdaptor adaptor,
827834 ConversionPatternRewriter &rewriter) const override {
828835
829- // See #115653
836+ // Only 1-D vectors are handled: the memref must be flattened to 1-D beforehand.
830837 if (op.getValueToStore ().getType ().getRank () != 1 )
831- return rewriter.notifyMatchFailure (op,
832- " only 1-D vectors are supported ATM " );
838+ return rewriter.notifyMatchFailure (
839+ op, " Memref in vector.maskedstore op must be flattened beforehand. " );
833840
834841 auto loc = op.getLoc ();
835842 auto containerElemTy =
@@ -931,18 +938,27 @@ struct ConvertVectorMaskedStore final
931938// ConvertVectorLoad
932939// ===----------------------------------------------------------------------===//
933940
934- // TODO: Document-me
941+ /// Converts `vector.load` on narrow element types to work with
942+ /// wider, byte-aligned container types by adjusting load sizes and using
943+ /// bitcasting.
944+ ///
945+ /// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
946+ /// by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
947+ /// container holds two `i4` values) and bitcasting back.
948+ ///
949+ /// There are cases where the number of elements to load is not byte-aligned. In
950+ /// those cases, loads are converted to byte-aligned, byte-sized loads and the
951+ /// target vector is extracted from the loaded vector.
935952struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
936953 using OpConversionPattern::OpConversionPattern;
937954
938955 LogicalResult
939956 matchAndRewrite (vector::LoadOp op, OpAdaptor adaptor,
940957 ConversionPatternRewriter &rewriter) const override {
941-
942- // See #115653
958+ // Only 1-D vectors are handled: the memref must be flattened to 1-D beforehand.
943959 if (op.getVectorType ().getRank () != 1 )
944- return rewriter.notifyMatchFailure (op,
945- " only 1-D vectors are supported ATM " );
960+ return rewriter.notifyMatchFailure (
961+ op, " Memref in emulated vector ops must be flattened beforehand. " );
946962
947963 auto loc = op.getLoc ();
948964 auto containerElemTy =
@@ -961,8 +977,6 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
961977
962978 // Adjust the number of elements to load when emulating narrow types,
963979 // and then cast back to the original type with vector.bitcast op.
964- // Here only the 1-D vector load is considered, and the N-D memref types
965- // should be linearized.
966980 // For example, to emulate i4 to i8, the following op:
967981 //
968982 // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
@@ -1037,18 +1051,22 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
10371051// ConvertVectorMaskedLoad
10381052// ===----------------------------------------------------------------------===//
10391053
1040- // TODO: Document-me
1054+ /// Converts `vector.maskedload` operations on narrow element types to work with
1055+ /// wider, byte-aligned container types by adjusting the mask and using
1056+ /// bitcasting.
1057+ ///
1058+ /// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
1059+ /// bitcasting, since each `i8` container element holds two `i4` values.
10411060struct ConvertVectorMaskedLoad final
10421061 : OpConversionPattern<vector::MaskedLoadOp> {
10431062 using OpConversionPattern::OpConversionPattern;
10441063
10451064 LogicalResult
10461065 matchAndRewrite (vector::MaskedLoadOp op, OpAdaptor adaptor,
10471066 ConversionPatternRewriter &rewriter) const override {
1048- // See #115653
10491067 if (op.getVectorType ().getRank () != 1 )
1050- return rewriter.notifyMatchFailure (op,
1051- " only 1-D vectors are supported ATM " );
1068+ return rewriter.notifyMatchFailure (
1069+ op, " Memref in emulated vector ops must be flattened beforehand. " );
10521070
10531071 auto loc = op.getLoc ();
10541072
@@ -1229,7 +1247,6 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
12291247
12301248 int elemsPerMultiByte = multiByteBits / subByteBits;
12311249
1232- // TODO: This is a bit too restrictive for vectors rank > 1.
12331250 return subByteVecTy.getShape ().back () % elemsPerMultiByte == 0 ;
12341251}
12351252
@@ -1246,10 +1263,11 @@ struct ConvertVectorTransferRead final
12461263 matchAndRewrite (vector::TransferReadOp op, OpAdaptor adaptor,
12471264 ConversionPatternRewriter &rewriter) const override {
12481265
1249- // See #115653
1266+ // Only 1-D vectors are handled: the memref in the vector.transfer_read op
1267+ // must be flattened to 1-D beforehand.
12501268 if (op.getVectorType ().getRank () != 1 )
1251- return rewriter.notifyMatchFailure (op,
1252- " only 1-D vectors are supported ATM " );
1269+ return rewriter.notifyMatchFailure (
1270+ op, " Memref in emulated vector ops must be flattened beforehand. " );
12531271
12541272 auto loc = op.getLoc ();
12551273 auto containerElemTy =
@@ -2227,7 +2245,6 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
22272245void vector::populateVectorNarrowTypeEmulationPatterns (
22282246 const arith::NarrowTypeEmulationConverter &typeConverter,
22292247 RewritePatternSet &patterns, bool disableAtomicRMW) {
2230-
22312248 // Populate `vector.*` conversion patterns.
22322249 // TODO: #119553 support atomicity
22332250 patterns.add <ConvertVectorLoad, ConvertVectorMaskedLoad,
@@ -2266,3 +2283,10 @@ void vector::populateVectorTransposeNarrowTypeRewritePatterns(
22662283 RewritePatternSet &patterns, PatternBenefit benefit) {
22672284 patterns.add <RewriteVectorTranspose>(patterns.getContext (), benefit);
22682285}
2286+
2287+ void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns (
2288+ arith::NarrowTypeEmulationConverter &typeConverter,
2289+ RewritePatternSet &patterns) {
2290+ memref::populateFlattenVectorOpsOnMemrefPatterns (patterns);
2291+ vector::populateVectorNarrowTypeEmulationPatterns (typeConverter, patterns);
2292+ }
0 commit comments