From c86b23a547512ee27c540bcb711823e719122d6c Mon Sep 17 00:00:00 2001
From: jerryyin
Date: Tue, 1 Jul 2025 15:32:32 +0000
Subject: [PATCH] Redirect transfer read to masked load lowering

Signed-off-by: jerryyin
---
 .../mlir/Dialect/AMDGPU/Transforms/Passes.h   |   6 +-
 .../mlir/Dialect/AMDGPU/Transforms/Passes.td  |   4 +-
 .../Dialect/AMDGPU/Transforms/CMakeLists.txt  |   2 +-
 .../AMDGPU/Transforms/MaskedloadToLoad.cpp    | 167 ++++++++++++
 .../AMDGPU/Transforms/TransferReadToLoad.cpp  | 239 ------------------
 ...d-to-load.mlir => maskedload-to-load.mlir} |  78 ++----
 6 files changed, 199 insertions(+), 297 deletions(-)
 create mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
 delete mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
 rename mlir/test/Dialect/AMDGPU/{transfer-read-to-load.mlir => maskedload-to-load.mlir} (56%)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index a52ee2ee89caf..cc2f543e79f69 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -23,7 +23,7 @@ namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
 #define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
-#define GEN_PASS_DECL_AMDGPUTRANSFERREADTOLOADPASS
+#define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
@@ -35,8 +35,8 @@ void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
 void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit = 1);
 
-void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns,
-                                              PatternBenefit benefit = 1);
+void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
+                                            PatternBenefit benefit = 1);
 
 } // namespace amdgpu
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 0e858108acf35..8d0e6829ab0cc 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -51,8 +51,8 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
   ];
 }
 
-def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
-  let summary = "Lower the operations from the vector transfer_read to vector load";
+def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
+  let summary = "Lower vector.maskedload operations to vector.load";
   let description = [{
     This pass creates a transfer read op lowering optimization. The lowering
     will produce a conditional check at runtime. If within bounds, a vector
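
For reference, the rewrite this pass performs has roughly the following shape
(an illustrative sketch with invented SSA names, not the exact IR produced;
the precise output is what the updated test file below checks). A masked load
from a fat raw buffer such as

    %r = vector.maskedload %mem[%i, %j], %mask, %passthru
        : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>,
          vector<4xi1>, vector<4xf32> into vector<4xf32>

becomes, in sketch form, a runtime-guarded pair of lowerings:

    %cond = ... // out-of-bounds AND not word-aligned, computed at runtime
    %r = scf.if %cond -> (vector<4xf32>) {
      // Fallback: keep the masked load for the unaligned boundary case.
      %fb = vector.maskedload %mem[%i, %j], %mask, %passthru
          : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>,
            vector<4xi1>, vector<4xf32> into vector<4xf32>
      scf.yield %fb : vector<4xf32>
    } else {
      %ld = vector.load %mem[%i, %j]
          : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
      %sel = arith.select %mask, %ld, %passthru : vector<4xi1>, vector<4xf32>
      scf.yield %sel : vector<4xf32>
    }
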
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 8709a27e0168e..17bbe54ea6c0c 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
   ResolveStridedMetadata.cpp
-  TransferReadToLoad.cpp
+  MaskedloadToLoad.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
new file mode 100644
index 0000000000000..9a368f372c296
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -0,0 +1,167 @@
+//===- MaskedloadToLoad.cpp - Lowers maskedload to load -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace mlir::amdgpu
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// This pattern supports lowering of `vector.maskedload` to `vector.load`
+/// and `arith.select` if the memref is in buffer address space.
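+///
+/// For example (illustrative only; names and shapes are arbitrary), the
+/// pattern applies to a masked load whose base memref uses the fat raw
+/// buffer address space:
+///
+///   vector.maskedload %mem[%i, %j], %mask, %passthru
+///     : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>,
+///       vector<4xi1>, vector<4xf32> into vector<4xf32>
+///
+/// Masked loads whose base is in any other address space (for example
+/// #gpu.address_space<workgroup>) are left unchanged.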
+static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
+                                           vector::MaskedLoadOp maskedOp) {
+  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+  if (!memRefType)
+    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+
+  Attribute addrSpace = memRefType.getMemorySpace();
+  if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
+    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+
+  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
+      amdgpu::AddressSpace::FatRawBuffer)
+    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+
+  return success();
+}
+
+static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
+                                           vector::MaskedLoadOp maskedOp) {
+  VectorType vectorType = maskedOp.getVectorType();
+  Value load = builder.create<vector::LoadOp>(
+      loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
+  Value res = builder.create<arith::SelectOp>(
+      loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru());
+  return res;
+}
+
+static constexpr char kMaskedloadNeedsMask[] =
+    "amdgpu.buffer_maskedload_needs_mask";
+
+namespace {
+
+struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
+                                PatternRewriter &rewriter) const override {
+    if (maskedOp->hasAttr(kMaskedloadNeedsMask))
+      return failure();
+
+    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
+      return failure();
+    }
+
+    Location loc = maskedOp.getLoc();
+    Value src = maskedOp.getBase();
+
+    VectorType vectorType = maskedOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+    SmallVector<Value> indices = maskedOp.getIndices();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    SmallVector<OpFoldResult> strides =
+        stridedMetadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
+    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
+    memref::LinearizedMemRefInfo linearizedInfo;
+    OpFoldResult linearizedIndices;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+                                                 elementBitWidth, offset, sizes,
+                                                 strides, indices);
+
+    // delta = bufferSize - linearizedOffset
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    Value totalSize = getValueOrCreateConstantIndexOp(
+        rewriter, loc, linearizedInfo.linearizedSize);
+    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+
+    // 1) check if delta < vectorSize
+    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
+
+    // 2) check if (delta % elements_per_word != 0)
+    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
+        loc, llvm::divideCeil(32, elementBitWidth));
+    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne,
+        rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
+        rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+    // We take the fallback of the maskedload default lowering only if it is
+    // both out-of-bounds and not word aligned. The fallback ensures correct
+    // results when loading at the boundary of the buffer since buffer load
+    // returns inconsistent zeros for the whole word when the boundary is
+    // crossed.
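+    //
+    // Concretely (numbers only, not emitted IR): for a vector<4xf16> load,
+    // elementBitWidth = 16 and elementsPerWord = ceil(32 / 16) = 2. If the
+    // space left in the buffer is delta = 3 elements, then delta < 4 and
+    // delta % 2 != 0, so the cloned masked load in the "then" branch is used.
+    // With delta = 2 the access is still out of bounds but word aligned, so
+    // the vector.load + arith.select path in the "else" branch is used.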
+    Value ifCondition =
+        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
+
+    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+      Operation *read = builder.clone(*maskedOp.getOperation());
+      read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
+      Value readResult = read->getResult(0);
+      builder.create<scf::YieldOp>(loc, readResult);
+    };
+
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp);
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
+
+    rewriter.replaceOp(maskedOp, ifOp);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<MaskedLoadLowering>(patterns.getContext(), benefit);
+}
+
+struct AmdgpuMaskedloadToLoadPass final
+    : amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAmdgpuMaskedloadToLoadPatterns(patterns);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
deleted file mode 100644
index f5b12a9524cc9..0000000000000
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ /dev/null
@@ -1,239 +0,0 @@
-//===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
-
-#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/MathExtras.h"
-
-namespace mlir::amdgpu {
-#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
-#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
-} // namespace mlir::amdgpu
-
-using namespace mlir;
-using namespace mlir::amdgpu;
-
-/// This pattern supports lowering of:
-/// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and
-/// `vector.broadcast` if all of the following hold:
-/// - The transfer op is masked.
-/// - The memref is in buffer address space.
-/// - Stride of most minor memref dimension must be 1.
-/// - Out-of-bounds masking is not required.
-/// - If the memref's element type is a vector type then it coincides with the
-///   result type.
-/// - The permutation map doesn't perform permutation (broadcasting is allowed).
-/// Note: those conditions mostly come from TransferReadToVectorLoadLowering
-/// pass.
-static LogicalResult transferPreconditions( - PatternRewriter &rewriter, VectorTransferOpInterface xferOp, - bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) { - if (!xferOp.getMask()) - return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer"); - - // Permutations are handled by VectorToSCF or - // populateVectorTransferPermutationMapLoweringPatterns. - // We let the 0-d corner case pass-through as it is supported. - SmallVector broadcastedDims; - if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting( - &broadcastedDims)) - return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast"); - - auto memRefType = dyn_cast(xferOp.getShapedType()); - if (!memRefType) - return rewriter.notifyMatchFailure(xferOp, "not a memref source"); - - Attribute addrSpace = memRefType.getMemorySpace(); - if (!isa_and_nonnull(addrSpace)) - return rewriter.notifyMatchFailure(xferOp, "no address space"); - - if (dyn_cast(addrSpace).getValue() != - amdgpu::AddressSpace::FatRawBuffer) - return rewriter.notifyMatchFailure(xferOp, "not in buffer address space"); - - // Non-unit strides are handled by VectorToSCF. - if (!memRefType.isLastDimUnitStride()) - return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF"); - - if (memRefType.getElementTypeBitWidth() < 8) - return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type"); - - // If there is broadcasting involved then we first load the unbroadcasted - // vector, and then broadcast it with `vector.broadcast`. - ArrayRef vectorShape = xferOp.getVectorType().getShape(); - SmallVector unbroadcastedVectorShape(vectorShape); - for (unsigned i : broadcastedDims) - unbroadcastedVectorShape[i] = 1; - unbroadcastedVectorType = xferOp.getVectorType().cloneWith( - unbroadcastedVectorShape, xferOp.getVectorType().getElementType()); - requiresBroadcasting = !broadcastedDims.empty(); - - // `vector.load` supports vector types as memref's elements only when the - // resulting vector type is the same as the element type. - auto memrefElTy = memRefType.getElementType(); - if (isa(memrefElTy) && memrefElTy != unbroadcastedVectorType) - return rewriter.notifyMatchFailure(xferOp, "incompatible element type"); - - // Otherwise, element types of the memref and the vector must match. - if (!isa(memrefElTy) && - memrefElTy != xferOp.getVectorType().getElementType()) - return rewriter.notifyMatchFailure(xferOp, "non-matching element type"); - - // Out-of-bounds dims are handled by MaterializeTransferMask. - if (xferOp.hasOutOfBoundsDim()) - return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask"); - - if (xferOp.getVectorType().getRank() != 1) - // vector.maskedload operates on 1-D vectors. - return rewriter.notifyMatchFailure( - xferOp, "vector type is not rank 1, can't create masked load, needs " - "VectorToSCF"); - - return success(); -} - -static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, - vector::TransferReadOp readOp, - bool requiresBroadcasting, - VectorType unbroadcastedVectorType) { - Value fill = builder.create(loc, unbroadcastedVectorType, - readOp.getPadding()); - Value load = builder.create( - loc, unbroadcastedVectorType, readOp.getBase(), readOp.getIndices()); - Value res = builder.create(loc, unbroadcastedVectorType, - readOp.getMask(), load, fill); - // Insert a broadcasting op if required. 
- if (requiresBroadcasting) { - res = builder.create(loc, readOp.getVectorType(), res); - } - return res; -} - -static constexpr char kTransferReadNeedsMask[] = - "amdgpu.buffer_transfer_read_needs_mask"; - -namespace { - -struct TransferReadLowering final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::TransferReadOp readOp, - PatternRewriter &rewriter) const override { - if (readOp->hasAttr(kTransferReadNeedsMask)) - return failure(); - - bool requiresBroadcasting = false; - VectorType unbroadcastedVectorType; - if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting, - unbroadcastedVectorType))) { - return failure(); - } - - Location loc = readOp.getLoc(); - Value src = readOp.getBase(); - - VectorType vectorType = readOp.getVectorType(); - int64_t vectorSize = vectorType.getNumElements(); - int64_t elementBitWidth = vectorType.getElementTypeBitWidth(); - SmallVector indices = readOp.getIndices(); - - auto stridedMetadata = - rewriter.create(loc, src); - SmallVector strides = - stridedMetadata.getConstifiedMixedStrides(); - SmallVector sizes = stridedMetadata.getConstifiedMixedSizes(); - OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset(); - memref::LinearizedMemRefInfo linearizedInfo; - OpFoldResult linearizedIndices; - std::tie(linearizedInfo, linearizedIndices) = - memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth, - elementBitWidth, offset, sizes, - strides, indices); - - // delta = bufferSize - linearizedOffset - Value vectorSizeOffset = - rewriter.create(loc, vectorSize); - Value linearIndex = - getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices); - Value totalSize = getValueOrCreateConstantIndexOp( - rewriter, loc, linearizedInfo.linearizedSize); - Value delta = rewriter.create(loc, totalSize, linearIndex); - - // 1) check if delta < vectorSize - Value isOutofBounds = rewriter.create( - loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset); - - // 2) check if (detla % elements_per_word != 0) - Value elementsPerWord = rewriter.create( - loc, llvm::divideCeil(32, elementBitWidth)); - Value isNotWordAligned = rewriter.create( - loc, arith::CmpIPredicate::ne, - rewriter.create(loc, delta, elementsPerWord), - rewriter.create(loc, 0)); - - // We take the fallback of transfer_read default lowering only it is both - // out-of-bounds and not word aligned. The fallback ensures correct results - // when loading at the boundary of the buffer since buffer load returns - // inconsistent zeros for the whole word when boundary is crossed. 
- Value ifCondition = - rewriter.create(loc, isOutofBounds, isNotWordAligned); - - auto thenBuilder = [&](OpBuilder &builder, Location loc) { - Operation *read = builder.clone(*readOp.getOperation()); - read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr()); - Value readResult = read->getResult(0); - builder.create(loc, readResult); - }; - - auto elseBuilder = [&](OpBuilder &builder, Location loc) { - Value res = createVectorLoadForMaskedLoad( - builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType); - rewriter.create(loc, res); - }; - - auto ifOp = - rewriter.create(loc, ifCondition, thenBuilder, elseBuilder); - - rewriter.replaceOp(readOp, ifOp); - - return success(); - } -}; - -} // namespace - -void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); -} - -struct AmdgpuTransferReadToLoadPass final - : amdgpu::impl::AmdgpuTransferReadToLoadPassBase< - AmdgpuTransferReadToLoadPass> { - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateAmdgpuTransferReadToLoadPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { - return signalPassFailure(); - } - } -}; diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir similarity index 56% rename from mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir rename to mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir index 20999af10553e..febe46bf7a759 100644 --- a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir +++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir @@ -1,17 +1,17 @@ -// RUN: mlir-opt %s --amdgpu-transfer-read-to-load --split-input-file | FileCheck %s +// RUN: mlir-opt %s --amdgpu-maskedload-to-load --split-input-file | FileCheck %s // CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer( // CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space> // CHECK-SAME: %[[ARG1:.*]]: index // CHECK-SAME: %[[ARG2:.*]]: vector<4xi1> -func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> { - %cf0 = arith.constant 0.0 : f32 - %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space>, vector<4xf32> +// CHECK-SAME: %[[ARG3:.*]]: vector<4xf32> +func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> { + %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32, #amdgpu.address_space>, vector<4xi1>, vector<4xf32> into vector<4xf32> return %res : vector<4xf32> } // CHECK: %[[IF:.*]] = scf.if -// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]] +// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]] // CHECK: } else { // CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1] @@ -25,10 +25,10 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad // CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_f16( // CHECK-SAME: %[[ARG0:.+]]: memref<8x8xf16, #amdgpu.address_space>, // CHECK-SAME: %[[ARG1:.+]]: index, %[[ARG2:.+]]: index, -// CHECK-SAME: %[[ARG3:.+]]: vector<4xi1>) -func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgpu.address_space>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>) -> vector<4xf16> { - %cf0 = 
arith.constant 0.0 : f16 - %res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]} : memref<8x8xf16, #amdgpu.address_space>, vector<4xf16> +// CHECK-SAME: %[[ARG3:.+]]: vector<4xi1> +// CHECK-SAME: %[[ARG4:.+]]: vector<4xf16> +func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgpu.address_space>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>, %passthru : vector<4xf16>) -> vector<4xf16> { + %res = vector.maskedload %mem[%idx0, %idx1], %mask, %passthru : memref<8x8xf16, #amdgpu.address_space>, vector<4xi1>, vector<4xf16> into vector<4xf16> return %res : vector<4xf16> } // CHECK-DAG: %[[C0:.*]] = arith.constant 0 @@ -45,7 +45,7 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp // CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]] // CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (vector<4xf16>) { -// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG2]]] +// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]] // CHECK: } else { // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] // CHECK: return %[[IF]] : vector<4xf16> @@ -58,13 +58,11 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp // CHECK-SAME: %[[ARG0:.*]]: memref> // CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index // CHECK-SAME: %[[ARG3:.*]]: vector<4xi1> -func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>) -> vector<4xi8> { - %cf0 = arith.constant 0 : i8 - %res = vector.transfer_read %mem[%idx0, %idx1], %cf0, %mask {in_bounds = [true]} : memref>, vector<4xi8> +// CHECK-SAME: %[[ARG4:.*]]: vector<4xi8> +func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref>, %idx0 : index, %idx1 : index, %mask : vector<4xi1>, %passthru : vector<4xi8>) -> vector<4xi8> { + %res = vector.maskedload %mem[%idx0, %idx1], %mask, %passthru : memref>, vector<4xi1>, vector<4xi8> into vector<4xi8> return %res : vector<4xi8> } - -// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi8> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]] @@ -79,13 +77,12 @@ func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref // CHECK-SAME: %[[ARG1:.*]]: index // CHECK-SAME: %[[ARG2:.*]]: vector<4xi1> -func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> { - %cf0 = arith.constant 0.0 : f32 - %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32> +// CHECK-SAME: %[[ARG3:.*]]: vector<4xf32> +func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> { + %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32> return %res : vector<4xf32> } -// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 -// CHECK: %[[RES:.*]] = vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[CST]], %[[ARG2]] +// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[ARG3]] // CHECK: return %[[RES]] : vector<4xf32> // ----- @@ -94,49 +91,26 @@ func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, // CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, 
#gpu.address_space> // CHECK-SAME: %[[ARG1:.*]]: index // CHECK-SAME: %[[ARG2:.*]]: vector<4xi1> -func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_space>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> { - %cf0 = arith.constant 0.0 : f32 - %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space>, vector<4xf32> +// CHECK-SAME: %[[ARG3:.*]]: vector<4xf32> +func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_space>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> { + %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32, #gpu.address_space>, vector<4xi1>, vector<4xf32> into vector<4xf32> return %res : vector<4xf32> } -// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 -// CHECK: %[[RES:.*]] = vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[CST]], %[[ARG2]] +// CHECK: %[[RES:.*]] = vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]], %[[ARG2]], %[[ARG3]] // CHECK: return %[[RES]] : vector<4xf32> // ----- -// CHECK-LABEL: func @transfer_broadcasting( -// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space> -// CHECK-SAME: %[[ARG1:.*]]: index -// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1> -#broadcast_1d = affine_map<(d0, d1) -> (0)> -func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space>, %idx : index, %mask : vector<1xi1>) -> vector<4xf32> { - %cf0 = arith.constant 0.0 : f32 - %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask - {in_bounds = [true], permutation_map = #broadcast_1d} - : memref<8x8xf32, #amdgpu.address_space>, vector<4xf32> - return %res : vector<4xf32> -} -// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> -// CHECK: %[[IF:.*]] = scf.if -// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1] -// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]] -// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32> - -// ----- - // CHECK-LABEL: func @transfer_scalar( // CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space> // CHECK-SAME: %[[ARG1:.*]]: index // CHECK-SAME: %[[ARG2:.*]]: vector<1xi1> -func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space>, %idx : index, %mask : vector<1xi1>) -> vector<1xf32> { - %cf0 = arith.constant 0.0 : f32 - %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask - {in_bounds = [true]} - : memref<8x8xf32, #amdgpu.address_space>, vector<1xf32> +// CHECK-SAME: %[[ARG3:.*]]: vector<1xf32> +func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space>, %idx : index, %mask : vector<1xi1>, %passthru : vector<1xf32>) -> vector<1xf32> { + %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru + : memref<8x8xf32, #amdgpu.address_space>, vector<1xi1>, vector<1xf32> into vector<1xf32> return %res : vector<1xf32> } -// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> // CHECK: %[[IF:.*]] = scf.if // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG1]]] -// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]] +// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[ARG3]]