
Conversation

@jerryyin
Member

@jerryyin jerryyin commented Jul 2, 2025

This PR reworks #131803. Instead of applying the optimization on the transfer_read op, which is too high level, it redirects the pre-existing pattern onto the maskedload op. This simplifies the implementation of the lowering pattern and also allows moving the use of the pass into a target-dependent pipeline.
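
For illustration only (not part of the patch): a rough before/after sketch of the rewrite this pattern performs on a 1-D fat-raw-buffer memref. The function, buffer size, and value names below are made up, and the bound computation is simplified to the zero-offset, unit-stride case (the real pattern derives it from memref.extract_strided_metadata).

#buf = #amdgpu.address_space<fat_raw_buffer>

// Input: a masked load from a fat raw buffer (hypothetical example).
func.func @before(%src: memref<8192xf16, #buf>, %idx: index,
                  %mask: vector<4xi1>, %pass: vector<4xf16>) -> vector<4xf16> {
  %r = vector.maskedload %src[%idx], %mask, %pass
         : memref<8192xf16, #buf>, vector<4xi1>, vector<4xf16> into vector<4xf16>
  return %r : vector<4xf16>
}

// Roughly what the pattern produces: fall back to the tagged maskedload only
// when the read crosses the buffer end and is not word aligned; otherwise use
// a plain vector.load plus arith.select on the mask.
func.func @after(%src: memref<8192xf16, #buf>, %idx: index,
                 %mask: vector<4xi1>, %pass: vector<4xf16>) -> vector<4xf16> {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index        // f16 elements per 32-bit word
  %c4 = arith.constant 4 : index        // vector size
  %c8192 = arith.constant 8192 : index  // linearized buffer size
  %delta = arith.subi %c8192, %idx : index
  %oob = arith.cmpi ult, %delta, %c4 : index
  %rem = arith.remui %delta, %c2 : index
  %unaligned = arith.cmpi ne, %rem, %c0 : index
  %cond = arith.andi %oob, %unaligned : i1
  %r = scf.if %cond -> (vector<4xf16>) {
    %m = vector.maskedload %src[%idx], %mask, %pass
           {amdgpu.buffer_maskedload_needs_mask}
           : memref<8192xf16, #buf>, vector<4xi1>, vector<4xf16> into vector<4xf16>
    scf.yield %m : vector<4xf16>
  } else {
    %l = vector.load %src[%idx] : memref<8192xf16, #buf>, vector<4xf16>
    %s = arith.select %mask, %l, %pass : vector<4xi1>, vector<4xf16>
    scf.yield %s : vector<4xf16>
  }
  return %r : vector<4xf16>
}

The renamed test maskedload-to-load.mlir exercises this rewrite via the amdgpu-maskedload-to-load pass.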

@llvmbot
Member

llvmbot commented Jul 2, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-amdgpu

Author: Zhuoran Yin (jerryyin)

Changes

This PR reworks #131803. Instead of applying the optimization on the transfer_read op, which is too high level, it redirects the pre-existing pattern onto the maskedload op. This allows a simplified lowering pattern. It also allows moving the use of the pass into a target-dependent pipeline.


Patch is 29.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146705.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h (+3-3)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td (+2-2)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt (+1-1)
  • (added) mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp (+167)
  • (removed) mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp (-239)
  • (renamed) mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir (+26-52)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index a52ee2ee89caf..cc2f543e79f69 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -23,7 +23,7 @@ namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
 #define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
-#define GEN_PASS_DECL_AMDGPUTRANSFERREADTOLOADPASS
+#define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
@@ -35,8 +35,8 @@ void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
 void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit = 1);
 
-void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns,
-                                              PatternBenefit benefit = 1);
+void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
+                                            PatternBenefit benefit = 1);
 
 } // namespace amdgpu
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 0e858108acf35..8d0e6829ab0cc 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -51,8 +51,8 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
   ];
 }
 
-def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
-  let summary = "Lower the operations from the vector transfer_read to vector load";
+def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
+  let summary = "Lower the operations from the vector maskedload to vector load";
   let description = [{
     This pass creates a transfer read op lowering optimization. The lowering
     will produce a conditional check at runtime. If within bounds, a vector
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 8709a27e0168e..17bbe54ea6c0c 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
   ResolveStridedMetadata.cpp
-  TransferReadToLoad.cpp
+  MaskedloadToLoad.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
new file mode 100644
index 0000000000000..9a368f372c296
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -0,0 +1,167 @@
+//===- MaskedloadToLoad.cpp - Lowers maskedload to load -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace mlir::amdgpu
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// This pattern supports lowering of: `vector.maskedload` to `vector.load`
+/// and `arith.select` if the memref is in buffer address space.
+static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
+                                           vector::MaskedLoadOp maskedOp) {
+  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+  if (!memRefType)
+    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+
+  Attribute addrSpace = memRefType.getMemorySpace();
+  if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
+    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+
+  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
+      amdgpu::AddressSpace::FatRawBuffer)
+    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+
+  return success();
+}
+
+static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
+                                           vector::MaskedLoadOp maskedOp) {
+  VectorType vectorType = maskedOp.getVectorType();
+  Value load = builder.create<vector::LoadOp>(
+      loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
+  Value res = builder.create<arith::SelectOp>(
+      loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru());
+  return res;
+}
+
+static constexpr char kMaskedloadNeedsMask[] =
+    "amdgpu.buffer_maskedload_needs_mask";
+
+namespace {
+
+struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
+                                PatternRewriter &rewriter) const override {
+    if (maskedOp->hasAttr(kMaskedloadNeedsMask))
+      return failure();
+
+    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
+      return failure();
+    }
+
+    Location loc = maskedOp.getLoc();
+    Value src = maskedOp.getBase();
+
+    VectorType vectorType = maskedOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+    SmallVector<OpFoldResult> indices = maskedOp.getIndices();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    SmallVector<OpFoldResult> strides =
+        stridedMetadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
+    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
+    memref::LinearizedMemRefInfo linearizedInfo;
+    OpFoldResult linearizedIndices;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+                                                 elementBitWidth, offset, sizes,
+                                                 strides, indices);
+
+    // delta = bufferSize - linearizedOffset
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    Value totalSize = getValueOrCreateConstantIndexOp(
+        rewriter, loc, linearizedInfo.linearizedSize);
+    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+
+    // 1) check if delta < vectorSize
+    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
+
+    // 2) check if (delta % elements_per_word != 0)
+    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
+        loc, llvm::divideCeil(32, elementBitWidth));
+    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne,
+        rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
+        rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+    // We take the fallback of the maskedload default lowering only if it is
+    // both out-of-bounds and not word aligned. The fallback ensures correct
+    // results when loading at the boundary of the buffer, since the buffer
+    // load returns inconsistent zeros for the whole word when the boundary is
+    // crossed.
+    Value ifCondition =
+        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
+
+    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+      Operation *read = builder.clone(*maskedOp.getOperation());
+      read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
+      Value readResult = read->getResult(0);
+      builder.create<scf::YieldOp>(loc, readResult);
+    };
+
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp);
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
+
+    rewriter.replaceOp(maskedOp, ifOp);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<MaskedLoadLowering>(patterns.getContext(), benefit);
+}
+
+struct AmdgpuMaskedloadToLoadPass final
+    : amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAmdgpuMaskedloadToLoadPatterns(patterns);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
deleted file mode 100644
index f5b12a9524cc9..0000000000000
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ /dev/null
@@ -1,239 +0,0 @@
-//===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
-
-#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/MathExtras.h"
-
-namespace mlir::amdgpu {
-#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
-#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
-} // namespace mlir::amdgpu
-
-using namespace mlir;
-using namespace mlir::amdgpu;
-
-/// This pattern supports lowering of:
-/// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and
-/// `vector.broadcast` if all of the following hold:
-/// - The transfer op is masked.
-/// - The memref is in buffer address space.
-/// - Stride of most minor memref dimension must be 1.
-/// - Out-of-bounds masking is not required.
-/// - If the memref's element type is a vector type then it coincides with the
-///   result type.
-/// - The permutation map doesn't perform permutation (broadcasting is allowed).
-/// Note: those conditions mostly come from TransferReadToVectorLoadLowering
-/// pass.
-static LogicalResult transferPreconditions(
-    PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
-    bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
-  if (!xferOp.getMask())
-    return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
-
-  // Permutations are handled by VectorToSCF or
-  // populateVectorTransferPermutationMapLoweringPatterns.
-  // We let the 0-d corner case pass-through as it is supported.
-  SmallVector<unsigned> broadcastedDims;
-  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
-          &broadcastedDims))
-    return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
-
-  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
-  if (!memRefType)
-    return rewriter.notifyMatchFailure(xferOp, "not a memref source");
-
-  Attribute addrSpace = memRefType.getMemorySpace();
-  if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(xferOp, "no address space");
-
-  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
-      amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
-
-  // Non-unit strides are handled by VectorToSCF.
-  if (!memRefType.isLastDimUnitStride())
-    return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
-
-  if (memRefType.getElementTypeBitWidth() < 8)
-    return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");
-
-  // If there is broadcasting involved then we first load the unbroadcasted
-  // vector, and then broadcast it with `vector.broadcast`.
-  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
-  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
-  for (unsigned i : broadcastedDims)
-    unbroadcastedVectorShape[i] = 1;
-  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
-      unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
-  requiresBroadcasting = !broadcastedDims.empty();
-
-  // `vector.load` supports vector types as memref's elements only when the
-  // resulting vector type is the same as the element type.
-  auto memrefElTy = memRefType.getElementType();
-  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
-    return rewriter.notifyMatchFailure(xferOp, "incompatible element type");
-
-  // Otherwise, element types of the memref and the vector must match.
-  if (!isa<VectorType>(memrefElTy) &&
-      memrefElTy != xferOp.getVectorType().getElementType())
-    return rewriter.notifyMatchFailure(xferOp, "non-matching element type");
-
-  // Out-of-bounds dims are handled by MaterializeTransferMask.
-  if (xferOp.hasOutOfBoundsDim())
-    return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");
-
-  if (xferOp.getVectorType().getRank() != 1)
-    // vector.maskedload operates on 1-D vectors.
-    return rewriter.notifyMatchFailure(
-        xferOp, "vector type is not rank 1, can't create masked load, needs "
-                "VectorToSCF");
-
-  return success();
-}
-
-static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
-                                           vector::TransferReadOp readOp,
-                                           bool requiresBroadcasting,
-                                           VectorType unbroadcastedVectorType) {
-  Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                               readOp.getPadding());
-  Value load = builder.create<vector::LoadOp>(
-      loc, unbroadcastedVectorType, readOp.getBase(), readOp.getIndices());
-  Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                              readOp.getMask(), load, fill);
-  // Insert a broadcasting op if required.
-  if (requiresBroadcasting) {
-    res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
-  }
-  return res;
-}
-
-static constexpr char kTransferReadNeedsMask[] =
-    "amdgpu.buffer_transfer_read_needs_mask";
-
-namespace {
-
-struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
-                                PatternRewriter &rewriter) const override {
-    if (readOp->hasAttr(kTransferReadNeedsMask))
-      return failure();
-
-    bool requiresBroadcasting = false;
-    VectorType unbroadcastedVectorType;
-    if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
-                                     unbroadcastedVectorType))) {
-      return failure();
-    }
-
-    Location loc = readOp.getLoc();
-    Value src = readOp.getBase();
-
-    VectorType vectorType = readOp.getVectorType();
-    int64_t vectorSize = vectorType.getNumElements();
-    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
-    SmallVector<OpFoldResult> indices = readOp.getIndices();
-
-    auto stridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
-    SmallVector<OpFoldResult> strides =
-        stridedMetadata.getConstifiedMixedStrides();
-    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
-    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
-    memref::LinearizedMemRefInfo linearizedInfo;
-    OpFoldResult linearizedIndices;
-    std::tie(linearizedInfo, linearizedIndices) =
-        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
-                                                 elementBitWidth, offset, sizes,
-                                                 strides, indices);
-
-    // delta = bufferSize - linearizedOffset
-    Value vectorSizeOffset =
-        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
-    Value linearIndex =
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
-    Value totalSize = getValueOrCreateConstantIndexOp(
-        rewriter, loc, linearizedInfo.linearizedSize);
-    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
-
-    // 1) check if delta < vectorSize
-    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
-
-    // 2) check if (detla % elements_per_word != 0)
-    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
-        loc, llvm::divideCeil(32, elementBitWidth));
-    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ne,
-        rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
-        rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
-    // We take the fallback of transfer_read default lowering only it is both
-    // out-of-bounds and not word aligned. The fallback ensures correct results
-    // when loading at the boundary of the buffer since buffer load returns
-    // inconsistent zeros for the whole word when boundary is crossed.
-    Value ifCondition =
-        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
-
-    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
-      Operation *read = builder.clone(*readOp.getOperation());
-      read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
-      Value readResult = read->getResult(0);
-      builder.create<scf::YieldOp>(loc, readResult);
-    };
-
-    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
-      Value res = createVectorLoadForMaskedLoad(
-          builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
-      rewriter.create<scf::YieldOp>(loc, res);
-    };
-
-    auto ifOp =
-        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
-
-    rewriter.replaceOp(readOp, ifOp);
-
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<TransferReadLowering>(patterns.getContext(), benefit);
-}
-
-struct AmdgpuTransferReadToLoadPass final
-    : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
-          AmdgpuTransferReadToLoadPass> {
-  void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateAmdgpuTransferReadToLoadPatterns(patterns);
-    if (failed(applyPatter...
[truncated]

@llvmbot
Member

llvmbot commented Jul 2, 2025

@llvm/pr-subscribers-backend-amdgpu


@llvmbot
Member

llvmbot commented Jul 2, 2025

@llvm/pr-subscribers-mlir-gpu

-  return res;
-}
-
-static constexpr char kTransferReadNeedsMask[] =
-    "amdgpu.buffer_transfer_read_needs_mask";
-
-namespace {
-
-struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
-                                PatternRewriter &rewriter) const override {
-    if (readOp->hasAttr(kTransferReadNeedsMask))
-      return failure();
-
-    bool requiresBroadcasting = false;
-    VectorType unbroadcastedVectorType;
-    if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
-                                     unbroadcastedVectorType))) {
-      return failure();
-    }
-
-    Location loc = readOp.getLoc();
-    Value src = readOp.getBase();
-
-    VectorType vectorType = readOp.getVectorType();
-    int64_t vectorSize = vectorType.getNumElements();
-    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
-    SmallVector<OpFoldResult> indices = readOp.getIndices();
-
-    auto stridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
-    SmallVector<OpFoldResult> strides =
-        stridedMetadata.getConstifiedMixedStrides();
-    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
-    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
-    memref::LinearizedMemRefInfo linearizedInfo;
-    OpFoldResult linearizedIndices;
-    std::tie(linearizedInfo, linearizedIndices) =
-        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
-                                                 elementBitWidth, offset, sizes,
-                                                 strides, indices);
-
-    // delta = bufferSize - linearizedOffset
-    Value vectorSizeOffset =
-        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
-    Value linearIndex =
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
-    Value totalSize = getValueOrCreateConstantIndexOp(
-        rewriter, loc, linearizedInfo.linearizedSize);
-    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
-
-    // 1) check if delta < vectorSize
-    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
-
-    // 2) check if (detla % elements_per_word != 0)
-    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
-        loc, llvm::divideCeil(32, elementBitWidth));
-    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ne,
-        rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
-        rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
-    // We take the fallback of transfer_read default lowering only it is both
-    // out-of-bounds and not word aligned. The fallback ensures correct results
-    // when loading at the boundary of the buffer since buffer load returns
-    // inconsistent zeros for the whole word when boundary is crossed.
-    Value ifCondition =
-        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
-
-    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
-      Operation *read = builder.clone(*readOp.getOperation());
-      read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
-      Value readResult = read->getResult(0);
-      builder.create<scf::YieldOp>(loc, readResult);
-    };
-
-    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
-      Value res = createVectorLoadForMaskedLoad(
-          builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
-      rewriter.create<scf::YieldOp>(loc, res);
-    };
-
-    auto ifOp =
-        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
-
-    rewriter.replaceOp(readOp, ifOp);
-
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<TransferReadLowering>(patterns.getContext(), benefit);
-}
-
-struct AmdgpuTransferReadToLoadPass final
-    : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
-          AmdgpuTransferReadToLoadPass> {
-  void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateAmdgpuTransferReadToLoadPatterns(patterns);
-    if (failed(applyPatter...
[truncated]
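For orientation, the patch keys the pre-existing rewrite on vector.maskedload directly. Below is a hand-written before/after sketch of the shape the MaskedLoadLowering pattern produces; it is not taken from the patch or its test file. The function names, the f16 element type, and the folded constants are illustrative: the real pass derives the bounds check from the buffer's strided metadata, and the maskedload kept on the slow path is marked so the pattern does not match it again.

// Before: a masked load from a fat raw buffer.
func.func @masked_to_load(%buf: memref<8xf16, #amdgpu.address_space<fat_raw_buffer>>,
                          %idx: index, %mask: vector<4xi1>,
                          %pass: vector<4xf16>) -> vector<4xf16> {
  %v = vector.maskedload %buf[%idx], %mask, %pass
      : memref<8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16>
        into vector<4xf16>
  return %v : vector<4xf16>
}

// After (roughly): fall back to the masked load only when the access is both
// out of bounds and not 32-bit-word aligned; otherwise use a plain load.
func.func @masked_to_load_lowered(%buf: memref<8xf16, #amdgpu.address_space<fat_raw_buffer>>,
                                  %idx: index, %mask: vector<4xi1>,
                                  %pass: vector<4xf16>) -> vector<4xf16> {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index  // f16 elements per 32-bit word
  %c4 = arith.constant 4 : index  // vector size
  %c8 = arith.constant 8 : index  // linearized buffer size
  %delta = arith.subi %c8, %idx : index
  %oob = arith.cmpi ult, %delta, %c4 : index
  %rem = arith.remui %delta, %c2 : index
  %unaligned = arith.cmpi ne, %rem, %c0 : index
  %cond = arith.andi %oob, %unaligned : i1
  %v = scf.if %cond -> (vector<4xf16>) {
    // Slow path: keep the masked load for the unaligned boundary case.
    %m = vector.maskedload %buf[%idx], %mask, %pass
        : memref<8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16>
          into vector<4xf16>
    scf.yield %m : vector<4xf16>
  } else {
    // Fast path: plain vector.load; the mask is applied with a select against %pass.
    %l = vector.load %buf[%idx]
        : memref<8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
    %s = arith.select %mask, %l, %pass : vector<4xi1>, vector<4xf16>
    scf.yield %s : vector<4xf16>
  }
  return %v : vector<4xf16>
}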

Member

@Groverkss Groverkss left a comment

This is much nicer, thanks!

@Groverkss Groverkss merged commit 6a97b56 into main Jul 2, 2025
12 checks passed
@Groverkss Groverkss deleted the users/zyin/redirect-masked-load-to-load branch July 2, 2025 17:24
def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
  let summary = "Lower the operations from the vector maskedload to vector load";
  let description = [{
    This pass creates a transfer read op lowering optimization. The lowering
Member

the description still refers to transfer reads

Member Author

Hmm, good catch... Unfortunately the PR just got merged. I'll leave a note to myself to fix this next time.

Member

the test names also seem like they may need updating

Member Author

Yep agreed!
