[MLIR][AMDGPU] Redirect transfer read to masked load lowering #146705
Conversation
Signed-off-by: jerryyin <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-amdgpu

Author: Zhuoran Yin (jerryyin)

Changes

This PR reworks #131803. Instead of applying the optimization on the transfer_read op, which is too high level, it redirects the pre-existing pattern onto the maskedload op. This allows a simplified lowering pattern. It also allows moving the use of the pass into a target-dependent pipeline.

Patch is 29.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146705.diff

6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index a52ee2ee89caf..cc2f543e79f69 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -23,7 +23,7 @@ namespace amdgpu {
#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
-#define GEN_PASS_DECL_AMDGPUTRANSFERREADTOLOADPASS
+#define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
@@ -35,8 +35,8 @@ void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
-void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
} // namespace amdgpu
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 0e858108acf35..8d0e6829ab0cc 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -51,8 +51,8 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
];
}
-def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
- let summary = "Lower the operations from the vector transfer_read to vector load";
+def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
+ let summary = "Lower the operations from the vector maskedload to vector load";
let description = [{
This pass creates a transfer read op lowering optimization. The lowering
will produce a conditional check at runtime. If within bounds, a vector
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 8709a27e0168e..17bbe54ea6c0c 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
add_mlir_dialect_library(MLIRAMDGPUTransforms
EmulateAtomics.cpp
ResolveStridedMetadata.cpp
- TransferReadToLoad.cpp
+ MaskedloadToLoad.cpp
ADDITIONAL_HEADER_DIRS
{$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
new file mode 100644
index 0000000000000..9a368f372c296
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -0,0 +1,167 @@
+//===- MaskedloadToLoad.cpp - Lowers maskedload to load -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace mlir::amdgpu
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// This pattern supports lowering of: `vector.maskedload` to `vector.load`
+/// and `arith.select` if the memref is in buffer address space.
+static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
+ vector::MaskedLoadOp maskedOp) {
+ auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+ if (!memRefType)
+ return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+
+ Attribute addrSpace = memRefType.getMemorySpace();
+ if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
+ return rewriter.notifyMatchFailure(maskedOp, "no address space");
+
+ if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
+ amdgpu::AddressSpace::FatRawBuffer)
+ return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+
+ return success();
+}
+
+static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
+ vector::MaskedLoadOp maskedOp) {
+ VectorType vectorType = maskedOp.getVectorType();
+ Value load = builder.create<vector::LoadOp>(
+ loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
+ Value res = builder.create<arith::SelectOp>(
+ loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru());
+ return res;
+}
+
+static constexpr char kMaskedloadNeedsMask[] =
+ "amdgpu.buffer_maskedload_needs_mask";
+
+namespace {
+
+struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
+ PatternRewriter &rewriter) const override {
+ if (maskedOp->hasAttr(kMaskedloadNeedsMask))
+ return failure();
+
+ if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
+ return failure();
+ }
+
+ Location loc = maskedOp.getLoc();
+ Value src = maskedOp.getBase();
+
+ VectorType vectorType = maskedOp.getVectorType();
+ int64_t vectorSize = vectorType.getNumElements();
+ int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+ SmallVector<OpFoldResult> indices = maskedOp.getIndices();
+
+ auto stridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+ SmallVector<OpFoldResult> strides =
+ stridedMetadata.getConstifiedMixedStrides();
+ SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
+ OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
+ memref::LinearizedMemRefInfo linearizedInfo;
+ OpFoldResult linearizedIndices;
+ std::tie(linearizedInfo, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+ elementBitWidth, offset, sizes,
+ strides, indices);
+
+ // delta = bufferSize - linearizedOffset
+ Value vectorSizeOffset =
+ rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+ Value linearIndex =
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+ Value totalSize = getValueOrCreateConstantIndexOp(
+ rewriter, loc, linearizedInfo.linearizedSize);
+ Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+
+ // 1) check if delta < vectorSize
+ Value isOutofBounds = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
+
+ // 2) check if (delta % elements_per_word != 0)
+ Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
+ loc, llvm::divideCeil(32, elementBitWidth));
+ Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne,
+ rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
+ rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+ // We take the fallback of the maskedload default lowering only if it is both
+ // out-of-bounds and not word aligned. The fallback ensures correct results
+ // when loading at the boundary of the buffer, since a buffer load returns
+ // inconsistent zeros for the whole word when the boundary is crossed.
+ Value ifCondition =
+ rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
+
+ auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+ Operation *read = builder.clone(*maskedOp.getOperation());
+ read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
+ Value readResult = read->getResult(0);
+ builder.create<scf::YieldOp>(loc, readResult);
+ };
+
+ auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+ Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp);
+ rewriter.create<scf::YieldOp>(loc, res);
+ };
+
+ auto ifOp =
+ rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
+
+ rewriter.replaceOp(maskedOp, ifOp);
+
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<MaskedLoadLowering>(patterns.getContext(), benefit);
+}
+
+struct AmdgpuMaskedloadToLoadPass final
+ : amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> {
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateAmdgpuMaskedloadToLoadPatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
deleted file mode 100644
index f5b12a9524cc9..0000000000000
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ /dev/null
@@ -1,239 +0,0 @@
-//===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
-
-#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/MathExtras.h"
-
-namespace mlir::amdgpu {
-#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
-#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
-} // namespace mlir::amdgpu
-
-using namespace mlir;
-using namespace mlir::amdgpu;
-
-/// This pattern supports lowering of:
-/// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and
-/// `vector.broadcast` if all of the following hold:
-/// - The transfer op is masked.
-/// - The memref is in buffer address space.
-/// - Stride of most minor memref dimension must be 1.
-/// - Out-of-bounds masking is not required.
-/// - If the memref's element type is a vector type then it coincides with the
-/// result type.
-/// - The permutation map doesn't perform permutation (broadcasting is allowed).
-/// Note: those conditions mostly come from TransferReadToVectorLoadLowering
-/// pass.
-static LogicalResult transferPreconditions(
- PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
- bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
- if (!xferOp.getMask())
- return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
-
- // Permutations are handled by VectorToSCF or
- // populateVectorTransferPermutationMapLoweringPatterns.
- // We let the 0-d corner case pass-through as it is supported.
- SmallVector<unsigned> broadcastedDims;
- if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
- &broadcastedDims))
- return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
-
- auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
- if (!memRefType)
- return rewriter.notifyMatchFailure(xferOp, "not a memref source");
-
- Attribute addrSpace = memRefType.getMemorySpace();
- if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
- return rewriter.notifyMatchFailure(xferOp, "no address space");
-
- if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
- amdgpu::AddressSpace::FatRawBuffer)
- return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
-
- // Non-unit strides are handled by VectorToSCF.
- if (!memRefType.isLastDimUnitStride())
- return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
-
- if (memRefType.getElementTypeBitWidth() < 8)
- return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");
-
- // If there is broadcasting involved then we first load the unbroadcasted
- // vector, and then broadcast it with `vector.broadcast`.
- ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
- SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
- for (unsigned i : broadcastedDims)
- unbroadcastedVectorShape[i] = 1;
- unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
- unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
- requiresBroadcasting = !broadcastedDims.empty();
-
- // `vector.load` supports vector types as memref's elements only when the
- // resulting vector type is the same as the element type.
- auto memrefElTy = memRefType.getElementType();
- if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
- return rewriter.notifyMatchFailure(xferOp, "incompatible element type");
-
- // Otherwise, element types of the memref and the vector must match.
- if (!isa<VectorType>(memrefElTy) &&
- memrefElTy != xferOp.getVectorType().getElementType())
- return rewriter.notifyMatchFailure(xferOp, "non-matching element type");
-
- // Out-of-bounds dims are handled by MaterializeTransferMask.
- if (xferOp.hasOutOfBoundsDim())
- return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");
-
- if (xferOp.getVectorType().getRank() != 1)
- // vector.maskedload operates on 1-D vectors.
- return rewriter.notifyMatchFailure(
- xferOp, "vector type is not rank 1, can't create masked load, needs "
- "VectorToSCF");
-
- return success();
-}
-
-static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
- vector::TransferReadOp readOp,
- bool requiresBroadcasting,
- VectorType unbroadcastedVectorType) {
- Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
- readOp.getPadding());
- Value load = builder.create<vector::LoadOp>(
- loc, unbroadcastedVectorType, readOp.getBase(), readOp.getIndices());
- Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
- readOp.getMask(), load, fill);
- // Insert a broadcasting op if required.
- if (requiresBroadcasting) {
- res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
- }
- return res;
-}
-
-static constexpr char kTransferReadNeedsMask[] =
- "amdgpu.buffer_transfer_read_needs_mask";
-
-namespace {
-
-struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
- PatternRewriter &rewriter) const override {
- if (readOp->hasAttr(kTransferReadNeedsMask))
- return failure();
-
- bool requiresBroadcasting = false;
- VectorType unbroadcastedVectorType;
- if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
- unbroadcastedVectorType))) {
- return failure();
- }
-
- Location loc = readOp.getLoc();
- Value src = readOp.getBase();
-
- VectorType vectorType = readOp.getVectorType();
- int64_t vectorSize = vectorType.getNumElements();
- int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
- SmallVector<OpFoldResult> indices = readOp.getIndices();
-
- auto stridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
- SmallVector<OpFoldResult> strides =
- stridedMetadata.getConstifiedMixedStrides();
- SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
- OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
- memref::LinearizedMemRefInfo linearizedInfo;
- OpFoldResult linearizedIndices;
- std::tie(linearizedInfo, linearizedIndices) =
- memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
- elementBitWidth, offset, sizes,
- strides, indices);
-
- // delta = bufferSize - linearizedOffset
- Value vectorSizeOffset =
- rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
- Value linearIndex =
- getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
- Value totalSize = getValueOrCreateConstantIndexOp(
- rewriter, loc, linearizedInfo.linearizedSize);
- Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
-
- // 1) check if delta < vectorSize
- Value isOutofBounds = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
-
- // 2) check if (detla % elements_per_word != 0)
- Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
- loc, llvm::divideCeil(32, elementBitWidth));
- Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne,
- rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
- // We take the fallback of transfer_read default lowering only it is both
- // out-of-bounds and not word aligned. The fallback ensures correct results
- // when loading at the boundary of the buffer since buffer load returns
- // inconsistent zeros for the whole word when boundary is crossed.
- Value ifCondition =
- rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
-
- auto thenBuilder = [&](OpBuilder &builder, Location loc) {
- Operation *read = builder.clone(*readOp.getOperation());
- read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
- Value readResult = read->getResult(0);
- builder.create<scf::YieldOp>(loc, readResult);
- };
-
- auto elseBuilder = [&](OpBuilder &builder, Location loc) {
- Value res = createVectorLoadForMaskedLoad(
- builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
- rewriter.create<scf::YieldOp>(loc, res);
- };
-
- auto ifOp =
- rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
-
- rewriter.replaceOp(readOp, ifOp);
-
- return success();
- }
-};
-
-} // namespace
-
-void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<TransferReadLowering>(patterns.getContext(), benefit);
-}
-
-struct AmdgpuTransferReadToLoadPass final
- : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
- AmdgpuTransferReadToLoadPass> {
- void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- populateAmdgpuTransferReadToLoadPatterns(patterns);
- if (failed(applyPatter...
[truncated]
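The core of the new pattern is the runtime condition built in MaskedloadToLoad.cpp above: the masked fallback is taken only when the access runs past the end of the linearized buffer and the remaining element count is not a multiple of a 32-bit word. Below is a minimal standalone sketch of that arithmetic; the buffer sizes and element width are illustrative only and do not come from the patch.

#include <cstdint>
#include <iostream>

// Mirrors the condition built by MaskedLoadLowering: fall back to the masked
// lowering only when the access is out-of-bounds AND not word aligned.
bool needsMaskedFallback(int64_t bufferSize, int64_t linearIndex,
                         int64_t vectorSize, int64_t elementBitWidth) {
  int64_t delta = bufferSize - linearIndex;                  // elements left
  int64_t elementsPerWord = (32 + elementBitWidth - 1) / elementBitWidth;
  bool isOutOfBounds = delta < vectorSize;                   // 1) delta < vectorSize
  bool isNotWordAligned = (delta % elementsPerWord) != 0;    // 2) delta % epw != 0
  return isOutOfBounds && isNotWordAligned;
}

int main() {
  // f16 buffer of 7 elements, loading a vector<4xf16> at linear index 5:
  // delta = 2, elementsPerWord = 2, so the access is out-of-bounds but word
  // aligned and the cheaper vector.load + arith.select path is still taken.
  std::cout << needsMaskedFallback(7, 5, 4, 16) << "\n"; // prints 0
  // Same buffer, loading at index 4: delta = 3 is not a multiple of 2, so the
  // pattern clones the maskedload, tags it, and leaves it for default lowering.
  std::cout << needsMaskedFallback(7, 4, 4, 16) << "\n"; // prints 1
  return 0;
}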
Groverkss left a comment:
This is much nicer, thanks!
def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
  let summary = "Lower the operations from the vector maskedload to vector load";
  let description = [{
    This pass creates a transfer read op lowering optimization. The lowering
the description still refers to transfer reads
Hmmm, good catch... Unfortunately the PR just got merged. I'll leave a note to myself to fix this next time
the test names also seem like they may need updating
Yep agreed!
This PR reworks #131803. Instead of applying the optimization on the transfer_read op, which is too high level, it redirects the pre-existing pattern onto the maskedload op. This simplifies the implementation of the lowering pattern. It also allows moving the use of the pass into a target-dependent pipeline.
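For context, the renamed entry point can now be pulled into a target-dependent pipeline directly. The sketch below shows one way a downstream pass might do that; the wrapper function itself is hypothetical, only populateAmdgpuMaskedloadToLoadPatterns and applyPatternsGreedily come from the patch.

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper a target-specific pipeline could invoke once its buffer
// views have been rewritten into the amdgpu fat_raw_buffer address space.
static LogicalResult lowerAmdgpuMaskedLoads(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  amdgpu::populateAmdgpuMaskedloadToLoadPatterns(patterns);
  return applyPatternsGreedily(root, std::move(patterns));
}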