[mlir][AMDGPU][NFC] Fix overlapping masked load refinements #159805
Conversation
The two patterns for handling vector.maskedload on AMD GPUs had an overlap: both the "scalar mask becomes an if statement" pattern and the "masked loads become a normal load + a select on buffers" pattern could handle a load with a broadcast mask on a fat buffer resource. This commit adds checks to resolve the overlap.
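For illustration, here is a minimal sketch of the kind of IR both patterns previously matched; the function name, element types, and vector shapes are invented for this example. The base memref lives in the fat raw buffer address space and the mask is a broadcast of a scalar i1, so, before this change, either rewrite could fire and the result depended on which pattern the driver tried first.

func.func @overlap_example(%base: memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>,
                           %idx: index, %cond: i1,
                           %passthru: vector<4xf32>) -> vector<4xf32> {
  // A "full" mask: every lane carries the same scalar condition.
  %mask = vector.broadcast %cond : i1 to vector<4xi1>
  // Fat-buffer base plus broadcast mask: previously matched by both patterns.
  %r = vector.maskedload %base[%idx], %mask, %passthru
         : memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>,
           vector<4xi1>, vector<4xf32> into vector<4xf32>
  return %r : vector<4xf32>
}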
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-amdgpu
Author: Krzysztof Drewniak (krzysz00)
Changes: The two patterns for handling vector.maskedload on AMD GPUs had an overlap: both the "scalar mask becomes an if statement" pattern and the "masked loads become a normal load + a select on buffers" pattern could handle a load with a broadcast mask on a fat buffer resource. This commit adds checks to resolve the overlap.
Full diff: https://github.com/llvm/llvm-project/pull/159805.diff
1 Files Affected:
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c166e0a..89ef51f922cad 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
/// This pattern supports lowering of: `vector.maskedload` to `vector.load`
/// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
- vector::MaskedLoadOp maskedOp) {
- auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+ auto memRefType = dyn_cast<MemRefType>(type);
if (!memRefType)
- return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+ return failure();
Attribute addrSpace = memRefType.getMemorySpace();
if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
- return rewriter.notifyMatchFailure(maskedOp, "no address space");
+ return failure();
if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
amdgpu::AddressSpace::FatRawBuffer)
- return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+ return failure();
return success();
}
@@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
PatternRewriter &rewriter) const override {
if (maskedOp->hasAttr(kMaskedloadNeedsMask))
- return failure();
+ return rewriter.notifyMatchFailure(maskedOp, "already rewritten");
- if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
- return failure();
+ if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
+ return rewriter.notifyMatchFailure(
+ maskedOp, "isn't a load from a fat buffer resource");
}
// Check if this is either a full inbounds load or an empty, oob load. If
@@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad
LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
PatternRewriter &rewriter) const override {
+ if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+ return rewriter.notifyMatchFailure(
+ loadOp, "buffer loads are handled by a more specialized pattern");
+
FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
if (failed(maybeCond)) {
- return failure();
+ return rewriter.notifyMatchFailure(loadOp,
+ "isn't loading a broadcasted scalar");
}
Value cond = maybeCond.value();
@@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore
LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
PatternRewriter &rewriter) const override {
+ // A condition-free implementation of fully masked stores requires
+ // 1) an accessor for the num_records field on buffer resources/fat pointers
+ // 2) knowledge that said field will always be set accurately - that is,
+ // that writes to x < num_records of offset wouldn't trap, which is
+ // something a pattern user would need to assert or we'd need to prove.
+ //
+ // Therefore, conditional stores to buffers still go down this path at
+ // present.
+
FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
if (failed(maybeCond)) {
return failure();
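As an aside on that last hunk: the IR the new comment is describing is a fully masked store such as the sketch below (function name, element types, and vector shapes are invented for illustration). Even when the base is a fat buffer resource, the pattern still rewrites it into a conditional store rather than an unconditional one, for the num_records reasons given in the comment.

func.func @store_example(%base: memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>,
                         %idx: index, %cond: i1, %value: vector<4xf32>) {
  // Fully masked store: all lanes share one scalar condition.
  %mask = vector.broadcast %cond : i1 to vector<4xi1>
  // Still lowered through the conditional-store path, even on a fat buffer,
  // because an unconditional store would need guarantees about num_records.
  vector.maskedstore %base[%idx], %mask, %value
      : memref<8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32>
  return
}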
@Groverkss Ping
Ping |