Conversation

@krzysz00
Contributor

The two patterns for handling vector.maskedload on AMD GPUs had an overlap: both the "scalar mask becomes an if statement" pattern and the "masked load becomes a normal load + a select on buffers" pattern could handle a load with a broadcast mask on a fat buffer resource.

This commit adds checks to resolve the overlap.

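For concreteness, the overlapping case is a masked load whose mask is a broadcast of a single scalar and whose base memref lives in the fat raw buffer address space. A minimal illustrative MLIR sketch of such an input (function and value names are invented, not taken from the patch):

func.func @broadcast_mask_on_buffer(
    %buf: memref<64xf32, #amdgpu.address_space<fat_raw_buffer>>,
    %idx: index, %cond: i1, %passthru: vector<4xf32>) -> vector<4xf32> {
  // Every lane of the mask shares one scalar condition.
  %mask = vector.broadcast %cond : i1 to vector<4xi1>
  // Before this change, both rewrite patterns could match this op.
  %v = vector.maskedload %buf[%idx], %mask, %passthru
      : memref<64xf32, #amdgpu.address_space<fat_raw_buffer>>,
        vector<4xi1>, vector<4xf32> into vector<4xf32>
  return %v : vector<4xf32>
}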
@llvmbot
Member

llvmbot commented Sep 19, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-mlir-amdgpu

Author: Krzysztof Drewniak (krzysz00)

Changes

The two patterns for handling vector.maskedload on AMD GPUs had an overlap: both the "scalar mask becomes an if statement" pattern and the "masked load becomes a normal load + a select on buffers" pattern could handle a load with a broadcast mask on a fat buffer resource.

This commit adds checks to resolve the overlap.


Full diff: https://github.com/llvm/llvm-project/pull/159805.diff

1 File Affected:

  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp (+24-10)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c166e0a..89ef51f922cad 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
 
 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();
 
   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();
 
   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();
 
   return success();
 }
@@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
   LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
                                 PatternRewriter &rewriter) const override {
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
-      return failure();
+      return rewriter.notifyMatchFailure(maskedOp, "already rewritten");
 
-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
-      return failure();
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
+      return rewriter.notifyMatchFailure(
+          maskedOp, "isn't a load from a fat buffer resource");
     }
 
     // Check if this is either a full inbounds load or an empty, oob load. If
@@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad
 
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
-      return failure();
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "isn't loading a broadcasted scalar");
     }
 
     Value cond = maybeCond.value();
@@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore
 
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
+    // A condition-free implementation of fully masked stores requires
+    // 1) an accessor for the num_records field on buffer resources/fat pointers
+    // 2) knowledge that said field will always be set accurately - that is,
+    // that writes to x < num_records of offset wouldn't trap, which is
+    // something a pattern user would need to assert or we'd need to prove.
+    //
+    // Therefore, conditional stores to buffers still go down this path at
+    // present.
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();

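For contrast with the buffer path, the "scalar mask becomes an if statement" lowering roughly turns the broadcast condition into control flow around an ordinary load, and with this change it only fires for non-buffer memrefs. A hedged sketch of that shape, assuming scf.if is used for the conditional and reusing the invented names from the sketch above; the exact IR the pattern emits may differ:

// Sketch of the conditional-load form on a plain (non-buffer) memref.
%v = scf.if %cond -> (vector<4xf32>) {
  // Mask is all-true: do an ordinary vector load.
  %loaded = vector.load %base[%idx] : memref<64xf32>, vector<4xf32>
  scf.yield %loaded : vector<4xf32>
} else {
  // Mask is all-false: the passthrough value is the result.
  scf.yield %passthru : vector<4xf32>
}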
@krzysz00
Contributor Author

@Groverkss Ping

@krzysz00
Contributor Author

krzysz00 commented Oct 9, 2025

Ping

