diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index c3ae7930e8ec8..94dd9e3a29331 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -22,6 +22,7 @@ namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
 #define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
+#define GEN_PASS_DECL_AMDGPUTRANSFERREADTOLOADPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
@@ -30,6 +31,9 @@ void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
                                           Chipset chipset);
 
 void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns);
+
+void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns);
+
 } // namespace amdgpu
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 6d0bcd6e1066e..761caa448a57c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -51,4 +51,18 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
   ];
 }
 
+def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
+  let summary = "Lower vector.transfer_read ops to vector.load";
+  let description = [{
+    This pass provides a lowering for vector.transfer_read. A masked vector
+    transfer read op is rewritten into a combination of vector.load,
+    arith.select and vector.broadcast.
+
+    This rewrite makes it possible to lower a masked transfer_read to a
+    buffer load with bounds checking, giving a more efficient global load
+    access pattern than the existing lowering through llvm.intr.masked.load
+    on vectors.
+  }];
+  let dependentDialects = [];
+}
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 3d4567bff1e32..bc5b6e9186449 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
   ResolveStridedMetadata.cpp
+  TransferReadToLoad.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
new file mode 100644
index 0000000000000..3c1a2eb962037
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -0,0 +1,154 @@
+//===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace mlir::amdgpu
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// This pattern supports lowering of `vector.transfer_read` to a combination
+/// of `vector.load`, `arith.select` and `vector.broadcast` if all of the
+/// following hold:
+/// - The transfer op is masked.
+/// - The memref is in the buffer address space.
+/// - The stride of the most minor memref dimension is 1.
+/// - Out-of-bounds masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+///   result type.
+/// - The permutation map doesn't perform permutation (broadcasting is allowed).
+/// Note: these conditions mostly come from the
+/// TransferReadToVectorLoadLowering pattern.
+static LogicalResult transferPreconditions(
+    PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
+    bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
+  if (!xferOp.getMask())
+    return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");
+
+  // Permutations are handled by VectorToSCF or
+  // populateVectorTransferPermutationMapLoweringPatterns.
+  // We let the 0-d corner case pass through as it is supported.
+  SmallVector<unsigned> broadcastedDims;
+  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
+          &broadcastedDims))
+    return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");
+
+  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
+  if (!memRefType)
+    return rewriter.notifyMatchFailure(xferOp, "not a memref source");
+
+  Attribute addrSpace = memRefType.getMemorySpace();
+  if (!addrSpace || !dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace))
+    return rewriter.notifyMatchFailure(xferOp, "no address space");
+
+  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
+      amdgpu::AddressSpace::FatRawBuffer)
+    return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");
+
+  // Non-unit strides are handled by VectorToSCF.
+  if (!memRefType.isLastDimUnitStride())
+    return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
+
+  // If there is broadcasting involved then we first load the unbroadcasted
+  // vector, and then broadcast it with `vector.broadcast`.
+  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
+  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
+  for (unsigned i : broadcastedDims)
+    unbroadcastedVectorShape[i] = 1;
+  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
+      unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
+  requiresBroadcasting = !broadcastedDims.empty();
+
+  // `vector.load` supports vector types as memref's elements only when the
+  // resulting vector type is the same as the element type.
+  auto memrefElTy = memRefType.getElementType();
+  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
+    return rewriter.notifyMatchFailure(xferOp, "incompatible element type");
+
+  // Otherwise, element types of the memref and the vector must match.
+  if (!isa<VectorType>(memrefElTy) &&
+      memrefElTy != xferOp.getVectorType().getElementType())
+    return rewriter.notifyMatchFailure(xferOp, "non-matching element type");
+
+  // Out-of-bounds dims are handled by MaterializeTransferMask.
+  if (xferOp.hasOutOfBoundsDim())
+    return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");
+
+  if (xferOp.getVectorType().getRank() != 1)
+    // vector.maskedload operates on 1-D vectors.
+    return rewriter.notifyMatchFailure(
+        xferOp, "vector type is not rank 1, can't create masked load, needs "
+                "VectorToSCF");
+
+  return success();
+}
+
+namespace {
+
+struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    bool requiresBroadcasting = false;
+    VectorType unbroadcastedVectorType;
+    if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
+                                     unbroadcastedVectorType))) {
+      return failure();
+    }
+
+    Location loc = readOp.getLoc();
+    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getPadding());
+    Value load = rewriter.create<vector::LoadOp>(
+        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
+    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+                                                 readOp.getMask(), load, fill);
+
+    // Insert a broadcasting op if required.
+    if (requiresBroadcasting) {
+      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+                                                 res);
+    }
+
+    rewriter.replaceOp(readOp, res);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferReadLowering>(patterns.getContext());
+}
+
+struct AmdgpuTransferReadToLoadPass final
+    : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
+          AmdgpuTransferReadToLoadPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAmdgpuTransferReadToLoadPatterns(patterns);
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
+};
diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
new file mode 100644
index 0000000000000..3e1283579f2b1
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s --amdgpu-transfer-read-to-load --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+// CHECK: return %[[SELECT]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_to_maskedload_regular(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
func.func @transfer_to_maskedload_regular(%mem : memref<8x8xf32>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_to_maskedload_addrspace(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #gpu.address_space<workgroup>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_addrspace(%mem : memref<8x8xf32, #gpu.address_space<workgroup>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[RES:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {in_bounds = [true]} : memref<8x8xf32, #gpu.address_space<workgroup>>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_broadcasting(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
+#broadcast_1d = affine_map<(d0, d1) -> (0)>
+func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<4xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
+    {in_bounds = [true], permutation_map = #broadcast_1d}
+      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
+// CHECK: return %[[BROADCAST]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_scalar(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<1xi1>
+func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<1xi1>) -> vector<1xf32> {
+  %cf0 = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask
+    {in_bounds = [true]}
+      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<1xf32>
+  return %res : vector<1xf32>
+}
+// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
+// CHECK: return %[[SELECT]] : vector<1xf32>
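
For reference, a minimal before/after sketch of the rewrite this patch implements, mirroring the @transfer_to_maskedload_fatrawbuffer test above (SSA value names are illustrative, not produced verbatim by the pass):

    // Before: masked transfer_read from a fat raw buffer memref.
    %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]}
        : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>

    // After --amdgpu-transfer-read-to-load: an unconditional vector.load plus a
    // mask-driven select against the splatted padding value.
    %splat = vector.splat %cf0 : vector<4xf32>
    %load = vector.load %mem[%idx, %idx]
        : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
    %res = arith.select %mask, %load, %splat : vector<4xi1>, vector<4xf32>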