[MLIR] Make 1-D memref flattening a prerequisite for vector narrow type emulation #157771
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir

Author: Alan Li (lialan)

Changes

Addresses: #115653

We already have utilities to flatten memrefs into 1-D. This change makes memref flattening a prerequisite for vector narrow type emulation, ensuring that emulation patterns only need to handle 1-D scenarios.

Full diff: https://github.com/llvm/llvm-project/pull/157771.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 33e3d94f02b1c..e7751df724f9c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -145,6 +145,10 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
/// ```
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
+/// Patterns for flattening multi-dimensional memref operations into
+/// one-dimensional memref operations.
+void populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns);
+void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 42be847811d52..d658d147a0a3a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -271,12 +271,8 @@ struct FlattenMemrefsPass
} // namespace
-void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
- patterns.insert<MemRefRewritePattern<memref::LoadOp>,
- MemRefRewritePattern<memref::StoreOp>,
- MemRefRewritePattern<memref::AllocOp>,
- MemRefRewritePattern<memref::AllocaOp>,
- MemRefRewritePattern<vector::LoadOp>,
+void memref::populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns) {
+ patterns.insert<MemRefRewritePattern<vector::LoadOp>,
MemRefRewritePattern<vector::StoreOp>,
MemRefRewritePattern<vector::TransferReadOp>,
MemRefRewritePattern<vector::TransferWriteOp>,
@@ -284,3 +280,16 @@ void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
MemRefRewritePattern<vector::MaskedStoreOp>>(
patterns.getContext());
}
+
+void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
+ patterns.insert<MemRefRewritePattern<memref::LoadOp>,
+ MemRefRewritePattern<memref::StoreOp>,
+ MemRefRewritePattern<memref::AllocOp>,
+ MemRefRewritePattern<memref::AllocaOp>>(
+ patterns.getContext());
+}
+
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+ populateFlattenMemrefOpsPatterns(patterns);
+ populateFlattenVectorMemrefPatterns(patterns);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index f78e579d6c099..2ea17dbe2f53e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -556,7 +556,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // See #115653
if (op.getValueToStore().getType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
@@ -817,7 +816,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// ConvertVectorMaskedStore
//===----------------------------------------------------------------------===//
-// TODO: Document-me
+/// Converts `vector.maskedstore` operations on narrow element types to work
+/// with wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
+/// and storing with an adjusted mask, since each `i8` container element holds
+/// two `i4` values.
struct ConvertVectorMaskedStore final
: OpConversionPattern<vector::MaskedStoreOp> {
using OpConversionPattern::OpConversionPattern;
@@ -826,10 +831,10 @@ struct ConvertVectorMaskedStore final
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // See #115653
+ // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
if (op.getValueToStore().getType().getRank() != 1)
- return rewriter.notifyMatchFailure(op,
- "only 1-D vectors are supported ATM");
+ return rewriter.notifyMatchFailure(
+ op, "Memref in vector.maskedstore op must be flattened beforehand.");
auto loc = op.getLoc();
auto containerElemTy =
@@ -931,18 +936,27 @@ struct ConvertVectorMaskedStore final
// ConvertVectorLoad
//===----------------------------------------------------------------------===//
-// TODO: Document-me
+/// Converts `vector.load` on narrow element types to work with
+/// wider, byte-aligned container types by adjusting load sizes and using
+/// bitcasting.
+///
+/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
+/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` and bitcasting
+/// back, since each `i8` container holds two `i4` values.
+///
+/// There are cases where the number of elements to load is not byte-aligned. In
+/// those cases, loads are converted to byte-aligned, byte-sized loads and the
+/// target vector is extracted from the loaded vector.
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-
- // See #115653
+ // Prerequisite: memref in the vector.load op is flattened into 1-D.
if (op.getVectorType().getRank() != 1)
- return rewriter.notifyMatchFailure(op,
- "only 1-D vectors are supported ATM");
+ return rewriter.notifyMatchFailure(
+ op, "Memref in emulated vector ops must be flattened beforehand.");
auto loc = op.getLoc();
auto containerElemTy =
@@ -961,8 +975,6 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
- // Here only the 1-D vector load is considered, and the N-D memref types
- // should be linearized.
// For example, to emulate i4 to i8, the following op:
//
// %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
@@ -1037,7 +1049,12 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//
-// TODO: Document-me
+/// Converts `vector.maskedload` operations on narrow element types to work with
+/// wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
+/// bitcasting, since each `i8` container element holds two `i4` values.
struct ConvertVectorMaskedLoad final
: OpConversionPattern<vector::MaskedLoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -1045,10 +1062,9 @@ struct ConvertVectorMaskedLoad final
LogicalResult
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // See #115653
if (op.getVectorType().getRank() != 1)
- return rewriter.notifyMatchFailure(op,
- "only 1-D vectors are supported ATM");
+ return rewriter.notifyMatchFailure(
+ op, "Memref in emulated vector ops must be flattened beforehand.");
auto loc = op.getLoc();
@@ -1229,7 +1245,6 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
int elemsPerMultiByte = multiByteBits / subByteBits;
- // TODO: This is a bit too restrictive for vectors rank > 1.
return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
}
@@ -1246,10 +1261,11 @@ struct ConvertVectorTransferRead final
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // See #115653
+ // Prerequisite: memref in the vector.transfer_read op is flattened into
+ // 1-D.
if (op.getVectorType().getRank() != 1)
- return rewriter.notifyMatchFailure(op,
- "only 1-D vectors are supported ATM");
+ return rewriter.notifyMatchFailure(
+ op, "Memref in emulated vector ops must be flattened beforehand.");
auto loc = op.getLoc();
auto containerElemTy =
@@ -2228,6 +2244,9 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns, bool disableAtomicRMW) {
+ // As a prerequisite, make sure memrefs in vector ops are linearized.
+ memref::populateFlattenVectorMemrefPatterns(patterns);
+
// Populate `vector.*` conversion patterns.
// TODO: #119553 support atomicity
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
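The `ConvertVectorLoad` comment in the diff says that a sub-byte load is widened to whole container bytes and that, when the requested elements are not byte-aligned, the wanted elements are extracted from the wider load. The index arithmetic behind that can be sketched in plain C++, independent of MLIR. This is a minimal sketch under stated assumptions, not the patch's implementation: `ByteAlignedLoad` and `planSubByteLoad` are hypothetical names, the container is assumed to be 8 bits wide, the sub-byte width must divide 8, and the memref is assumed to be already flattened to 1-D.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical description of a byte-aligned load covering a run of
// sub-byte elements in an already-flattened (1-D) memref.
struct ByteAlignedLoad {
  int64_t byteOffset;  // first container (i8) element to load
  int64_t numBytes;    // number of container elements to load
  int64_t intraOffset; // sub-byte elements to skip inside the loaded data
};

// Sketch only, not the MLIR implementation. `elemOffset` is the linearized
// index of the first sub-byte element, `numElems` is how many sub-byte
// elements are requested, and `subByteBits` must divide 8.
ByteAlignedLoad planSubByteLoad(int64_t elemOffset, int64_t numElems,
                                int64_t subByteBits) {
  assert(elemOffset >= 0 && numElems > 0);
  assert(subByteBits > 0 && 8 % subByteBits == 0);
  const int64_t elemsPerByte = 8 / subByteBits;
  const int64_t firstByte = elemOffset / elemsPerByte;
  const int64_t lastByte = (elemOffset + numElems - 1) / elemsPerByte;
  // Load whole bytes; the caller then extracts `numElems` elements starting
  // at `intraOffset` when the request is not byte-aligned.
  return {firstByte, lastByte - firstByte + 1,
          elemOffset - firstByte * elemsPerByte};
}
```

For the documented case, `vector<4xi4>` read at `[1, 0]` of `memref<3x4xi4>` starts at linearized element 4, so `planSubByteLoad(4, 4, 4)` gives byte offset 2, two bytes and no intra-byte offset, which matches loading `vector<2xi8>` from `memref<6xi8>`. A 3-element read starting at element 1 gives two bytes starting at byte 0 with an intra-byte offset of 1, i.e. elements 1 to 3 of the four loaded `i4` values would be extracted.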
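The `ConvertVectorMaskedLoad` and `ConvertVectorMaskedStore` comments in the diff mention adjusting the mask when `vector<6xi4>` is handled as `vector<3xi8>`. One natural rule, assumed here rather than taken from the patch, is to enable a container lane whenever any of the sub-byte lanes it holds is enabled. `compressMask` below is a hypothetical helper written in plain C++ to illustrate that rule.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not an MLIR API): compresses a per-element mask for
// sub-byte elements into a per-container mask. Assumed rule: a container
// lane is enabled if any sub-byte lane it holds is enabled. A trailing
// partial group still gets its own container lane.
std::vector<bool> compressMask(const std::vector<bool> &subByteMask,
                               std::size_t elemsPerContainer) {
  assert(elemsPerContainer > 0);
  const std::size_t numContainers =
      (subByteMask.size() + elemsPerContainer - 1) / elemsPerContainer;
  std::vector<bool> containerMask(numContainers, false);
  for (std::size_t i = 0; i < subByteMask.size(); ++i)
    if (subByteMask[i])
      containerMask[i / elemsPerContainer] = true;
  return containerMask;
}
```

With this rule, the `i4` mask `[1, 1, 1, 0, 0, 0]` becomes the `i8` mask `[1, 1, 0]`. Container lane 1 is only partly enabled (it also covers the disabled `i4` lane 3), so a faithful emulation still has to merge with the pass-through value (for loads) or with the existing memory contents (for stores) using the original, uncompressed mask.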
Nice! Maybe we can add a test for 2D sub-byte memrefs that get flattened and emulated?
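To illustrate what such a test would exercise: once a 2-D sub-byte memref is flattened, each element index maps to a linearized offset, and from there to a container byte and a slot within that byte. The sketch below assumes a row-major identity layout with a static shape and an 8-bit container; `SubBytePosition` and `locateSubByteElement` are hypothetical names, and which bits of the byte a given slot occupies is deliberately not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result: where a sub-byte element lands after flattening.
struct SubBytePosition {
  int64_t linearIndex; // offset in the flattened sub-byte memref
  int64_t byteIndex;   // offset in the i8 container memref
  int64_t slot;        // which of the elements packed in that byte (0-based)
};

// Assumes a row-major, identity-layout memref with a static shape and an
// 8-bit container; `subByteBits` must divide 8.
SubBytePosition locateSubByteElement(const std::vector<int64_t> &shape,
                                     const std::vector<int64_t> &indices,
                                     int64_t subByteBits) {
  assert(shape.size() == indices.size() && "rank mismatch");
  assert(subByteBits > 0 && 8 % subByteBits == 0);
  int64_t linear = 0;
  for (std::size_t d = 0; d < shape.size(); ++d) {
    assert(indices[d] >= 0 && indices[d] < shape[d] && "out of bounds");
    linear = linear * shape[d] + indices[d]; // row-major linearization
  }
  const int64_t elemsPerByte = 8 / subByteBits;
  return {linear, linear / elemsPerByte, linear % elemsPerByte};
}
```

For `memref<3x4xi4>`, element `[2, 1]` becomes element 9 of the flattened `memref<12xi4>`, which lives in slot 1 of byte 4 of `memref<6xi8>`.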
Thank you!
One comment - similar to what Mahesh already hinted at :)
Force-pushed from d5b85f8 to a216bec.
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from a216bec to f523d84.
A side note for Alan, since he's working on narrow type emulation. There was a commit that broke IREE a long time ago, and the conclusion was to duplicate the pattern in IREE: iree-org/iree#20981. The root cause is that IREE assumes they are always aligned. See the writeup for more details. The missing feature in the upstream is that we want to break (or provide an option) to the If we are going to improve those patterns, I think we need to do the cleanup first. Otherwise, the patterns would diverge and we'd have to maintain some of them in IREE's codebase.
@hanhanW Agreed on giving choices back to downstream so people can choose what to integrate. Given that we are going down the
Thanks for the updates - I've just noticed them and left some fly-by comments. Will take a proper look tomorrow.
Quick question - why introduce `TestMemRefFlattenAndVectorNarrowTypeEmulationPass` rather than a new Transform Dialect (TD) op? I am fine with either, but do prefer the latter - I often re-use them in e2e testing. I'm mostly curious whether this choice is your preference.
mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
`populateFlattenVectorMemrefPatterns` includes patterns for `vector.store`, `vector.transfer_read` and `vector.transfer_write`. Can you remind me - are these supported? If not, could you add a TODO here?
Also, could you add a high-level comment specifying what combination of patterns is tested and that we are merely verifying that narrow-type-emulation works for rank > 1 memrefs? Otherwise the lack of more thorough test lines feels a bit ad-hoc.
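On the rank > 1 point: the diff also drops a TODO stating that `fitsInMultiByteContainerTy` is "a bit too restrictive for vectors rank > 1". The check it performs only looks at the innermost dimension, as this plain-C++ restatement shows. `fitsInMultiByteContainer` is a hypothetical stand-in that takes a shape and bit widths instead of MLIR types, and the input-sanity asserts are additions not present in the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical plain-C++ restatement of the check in the diff's
// `fitsInMultiByteContainerTy`, taking a shape and bit widths instead of
// MLIR types. Only the innermost dimension is inspected, so for rank > 1
// the outer dimensions never affect the result.
bool fitsInMultiByteContainer(const std::vector<int64_t> &shape,
                              int64_t subByteBits, int64_t multiByteBits) {
  // Sanity checks added for this sketch.
  assert(!shape.empty() && subByteBits > 0 &&
         multiByteBits % subByteBits == 0);
  const int64_t elemsPerMultiByte = multiByteBits / subByteBits;
  return shape.back() % elemsPerMultiByte == 0;
}
```

For example, a `4x3` shape of `i4` does not fit in `i8` containers because 3 is odd, while `3x4` of `i2` does, since 4 `i2` values fill exactly one byte.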
I added new tests for these, except for `vector.transfer_write`, for which I added a TODO.
I also updated the comments.
mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
@banach-space Because a type converter needs to be supplied to it, and I had to define one? Honestly, I just looked up how existing tests are written and added one.
OK, I've left some more comments :)
I also wanted to add - this is a very welcome and much appreciated contribution, thank you Alan 🙏🏻 🙇🏻
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
Force-pushed from 2fe656d to 6edd712.
LGTM % minor suggestion
Thanks again for re-visiting this 🙏🏻
mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
Hi there, this PR causes a build failure on our staging bot: https://lab.llvm.org/staging/#/builders/105/builds/32450. Could you please take a look? Thanks!
CMake cache to reproduce: https://github.com/llvm/llvm-project/blob/main/offload/cmake/caches/AMDGPUBot.cmake
error:
LLVM Buildbot has detected a new failure on builder
Full details are available at: https://lab.llvm.org/buildbot/#/builders/204/builds/22258
Here is the relevant piece of the build log for the reference

LLVM Buildbot has detected a new failure on builder
Full details are available at: https://lab.llvm.org/buildbot/#/builders/203/builds/23446
Here is the relevant piece of the build log for the reference

LLVM Buildbot has detected a new failure on builder
Full details are available at: https://lab.llvm.org/buildbot/#/builders/205/builds/22235
Here is the relevant piece of the build log for the reference
The original PR broke pretty much all our bots.