diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
index 5f40315a84909..094360e75ab61 100644
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -8,7 +8,7 @@
 #ifndef MLIR_CONVERSION_GPUCOMMON_GPUCOMMONPASS_H_
 #define MLIR_CONVERSION_GPUCOMMON_GPUCOMMONPASS_H_
 
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Types.h"
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 8eb711962583d..aaef91f31ab9c 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -13,8 +13,8 @@
 #ifndef MLIR_DIALECT_GPU_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_GPU_TRANSFORMS_PASSES_H_
 
-#include "Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include 
diff --git a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
new file mode 100644
index 0000000000000..ff8840a769779
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
@@ -0,0 +1,59 @@
+//===- DistributionUtils.h - Distribution Utilities -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
+#define MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir::gpu {
+struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  using Base = WarpDistributionPattern;
+
+  virtual LogicalResult
+  matchAndRewrite(WarpExecuteOnLane0Op op,
+                  PatternRewriter &rewriter) const override = 0;
+
+protected:
+  /// Return a value yielded by `warpOp` which satisfies the filter lambda
+  /// condition and is not dead.
+  OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
+                           llvm::function_ref<bool(Operation *)> fn) const;
+
+  /// Helper to create a new WarpExecuteOnLane0Op with different signature.
+  WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+      RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+      ValueRange newYieldedValues, TypeRange newReturnTypes) const;
+
+  /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+  /// `indices` returns the index of each new output.
+  WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+      RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+      ValueRange newYieldedValues, TypeRange newReturnTypes,
+      SmallVector<size_t> &indices) const;
+
+  /// Delinearize the given `laneId` into multiple dimensions, where each
+  /// dimension's size is determined by `originalShape` and `distributedShape`
+  /// together. This function expects the total number of threads needed for
+  /// distribution to be equal to `warpSize`. Returns true and updates
+  /// `delinearizedIds` if so.
+  bool delinearizeLaneId(OpBuilder &builder, Location loc,
+                         ArrayRef<int64_t> originalShape,
+                         ArrayRef<int64_t> distributedShape, int64_t warpSize,
+                         Value laneId,
+                         SmallVectorImpl<Value> &delinearizedIds) const;
+};
+
+} // namespace mlir::gpu
+
+#endif // MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/GPU/Utils/GPUUtils.h
similarity index 100%
rename from mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
rename to mlir/include/mlir/Dialect/GPU/Utils/GPUUtils.h
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index a59645480aba2..1026e9b509332 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -40,7 +40,6 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/ShuffleRewriter.cpp
   Transforms/SPIRVAttachTarget.cpp
   Transforms/SubgroupReduceLowering.cpp
-  Transforms/Utils.cpp
 
   OBJECT
@@ -59,6 +58,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   MLIRDataLayoutInterfaces
   MLIRExecutionEngineUtils
   MLIRGPUDialect
+  MLIRGPUUtils
   MLIRIR
   MLIRIndexDialect
   MLIRLLVMDialect
@@ -76,3 +76,4 @@ add_mlir_dialect_library(MLIRGPUTransforms
 
 add_subdirectory(TransformOps)
 add_subdirectory(Pipelines)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index b2fa3a99c53fc..41a5e39e55064 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -16,7 +16,7 @@
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index ba0c80c50211e..a6a36848b5635 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -18,7 +18,7 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 185f824351a23..43eff3eddcc49 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -13,7 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
diff --git a/mlir/lib/Dialect/GPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/GPU/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..69094c518a159
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Utils/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRGPUUtils
+  Utils.cpp
+  DistributionUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU/Utils
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRAffineDialect
+  MLIRGPUDialect
+  MLIRSupport
+  MLIRIR
+  )
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
new file mode 100644
index 0000000000000..9d51ac3fc4bdc
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -0,0 +1,144 @@
+//===- DistributionUtils.cpp - Distribution tools for GPUOps --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements distribution utility methods.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Value.h"
+
+#include <numeric>
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+WarpExecuteOnLane0Op
+WarpDistributionPattern::moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) const {
+  // Create a new op before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  auto yield =
+      cast<gpu::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+  rewriter.modifyOpInPlace(
+      yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
+
+WarpExecuteOnLane0Op
+WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    SmallVector<size_t> &indices) const {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  auto yield = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+                                              yield.getOperands().end());
+  for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(value)) {
+      types.push_back(type);
+      indices.push_back(yieldValues.size() - 1);
+    } else {
+      // If the value already exits the region, don't create a new output.
+      for (auto [idx, yieldOperand] :
+           llvm::enumerate(yieldValues.getArrayRef())) {
+        if (yieldOperand == value) {
+          indices.push_back(idx);
+          break;
+        }
+      }
+    }
+  }
+  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
+  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+      rewriter, warpOp, yieldValues.getArrayRef(), types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
+
+OpOperand *WarpDistributionPattern::getWarpResult(
+    WarpExecuteOnLane0Op warpOp,
+    llvm::function_ref<bool(Operation *)> fn) const {
+  auto yield = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  for (OpOperand &yieldOperand : yield->getOpOperands()) {
+    Value yieldValues = yieldOperand.get();
+    Operation *definedOp = yieldValues.getDefiningOp();
+    if (definedOp && fn(definedOp)) {
+      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+        return &yieldOperand;
+    }
+  }
+  return nullptr;
+}
+
+bool WarpDistributionPattern::delinearizeLaneId(
+    OpBuilder &builder, Location loc, ArrayRef<int64_t> originalShape,
+    ArrayRef<int64_t> distributedShape, int64_t warpSize, Value laneId,
+    SmallVectorImpl<Value> &delinearizedIds) const {
+  // If the original shape and the distributed shape are the same, we don't
+  // distribute at all--every thread handles the whole shape. In that case, we
+  // should not rely on lane IDs later, so just return an empty lane ID vector.
+  if (originalShape == distributedShape) {
+    delinearizedIds.clear();
+    return true;
+  }
+
+  SmallVector<int64_t> sizes;
+  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
+    if (large % small != 0)
+      return false;
+    sizes.push_back(large / small);
+  }
+  if (std::accumulate(sizes.begin(), sizes.end(), 1,
+                      std::multiplies<int64_t>()) != warpSize)
+    return false;
+
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+
+  int64_t usedThreads = 1;
+
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  delinearizedIds.assign(sizes.size(), zero);
+
+  for (int i = sizes.size() - 1; i >= 0; --i) {
+    usedThreads *= sizes[i];
+    if (usedThreads == warpSize) {
+      // We've used up all available threads; no modulo is needed anymore, and
+      // we can stop the calculation for further dimensions.
+ delinearizedIds[i] = laneId; + break; + } + delinearizedIds[i] = + affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId}); + laneId = affine::makeComposedAffineApply( + builder, loc, s0.floorDiv(usedThreads), {laneId}); + } + return true; +} diff --git a/mlir/lib/Dialect/GPU/Transforms/Utils.cpp b/mlir/lib/Dialect/GPU/Utils/Utils.cpp similarity index 96% rename from mlir/lib/Dialect/GPU/Transforms/Utils.cpp rename to mlir/lib/Dialect/GPU/Utils/Utils.cpp index e91aa18128c7b..1f09875b3e273 100644 --- a/mlir/lib/Dialect/GPU/Transforms/Utils.cpp +++ b/mlir/lib/Dialect/GPU/Utils/Utils.cpp @@ -10,7 +10,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/GPU/Transforms/Utils.h" +#include "mlir/Dialect/GPU/Utils/GPUUtils.h" #include "llvm/Support/ErrorHandling.h" namespace mlir::gpu { diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 9a3bd5d4593d6..8ca5cb6c6dfab 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRVectorTransforms MLIRArithDialect MLIRDialectUtils MLIRGPUDialect + MLIRGPUUtils MLIRIR MLIRLinalgDialect MLIRMemRefDialect diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 3e14259836995..e214257de2cdf 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Utils/DistributionUtils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -18,7 +19,6 @@ #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/FormatVariadic.h" -#include #include using namespace mlir; @@ -162,92 +162,6 @@ struct DistributedLoadStoreHelper { } // namespace -/// Helper to create a new WarpExecuteOnLane0Op with different signature. -static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns( - RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, - ValueRange newYieldedValues, TypeRange newReturnTypes) { - // Create a new op before the existing one, with the extra operands. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(warpOp); - auto newWarpOp = rewriter.create( - warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(), - warpOp.getArgs(), warpOp.getBody()->getArgumentTypes()); - - Region &opBody = warpOp.getBodyRegion(); - Region &newOpBody = newWarpOp.getBodyRegion(); - Block &newOpFirstBlock = newOpBody.front(); - rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin()); - rewriter.eraseBlock(&newOpFirstBlock); - assert(newWarpOp.getWarpRegion().hasOneBlock() && - "expected WarpOp with single block"); - - auto yield = - cast(newOpBody.getBlocks().begin()->getTerminator()); - - rewriter.modifyOpInPlace( - yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); }); - return newWarpOp; -} - -/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs. -/// `indices` return the index of each new output. 
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns( - RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, - ValueRange newYieldedValues, TypeRange newReturnTypes, - llvm::SmallVector &indices) { - SmallVector types(warpOp.getResultTypes().begin(), - warpOp.getResultTypes().end()); - auto yield = cast( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - llvm::SmallSetVector yieldValues(yield.getOperands().begin(), - yield.getOperands().end()); - for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) { - if (yieldValues.insert(std::get<0>(newRet))) { - types.push_back(std::get<1>(newRet)); - indices.push_back(yieldValues.size() - 1); - } else { - // If the value already exit the region don't create a new output. - for (auto [idx, yieldOperand] : - llvm::enumerate(yieldValues.getArrayRef())) { - if (yieldOperand == std::get<0>(newRet)) { - indices.push_back(idx); - break; - } - } - } - } - yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end()); - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, yieldValues.getArrayRef(), types); - rewriter.replaceOp(warpOp, - newWarpOp.getResults().take_front(warpOp.getNumResults())); - return newWarpOp; -} - -/// Helper to know if an op can be hoisted out of the region. -static bool canBeHoisted(Operation *op, - function_ref definedOutside) { - return llvm::all_of(op->getOperands(), definedOutside) && - isMemoryEffectFree(op) && op->getNumRegions() == 0; -} - -/// Return a value yielded by `warpOp` which statifies the filter lamdba -/// condition and is not dead. -static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp, - const std::function &fn) { - auto yield = cast( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - for (OpOperand &yieldOperand : yield->getOpOperands()) { - Value yieldValues = yieldOperand.get(); - Operation *definedOp = yieldValues.getDefiningOp(); - if (definedOp && fn(definedOp)) { - if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty()) - return &yieldOperand; - } - } - return {}; -} - // Clones `op` into a new operation that takes `operands` and returns // `resultTypes`. static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter, @@ -289,12 +203,11 @@ namespace { /// /// All this assumes the vector distribution occurs along the most minor /// distributed vector dimension. -struct WarpOpToScfIfPattern : public OpRewritePattern { +struct WarpOpToScfIfPattern : public WarpDistributionPattern { WarpOpToScfIfPattern(MLIRContext *context, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - options(options) {} + : WarpDistributionPattern(context, benefit), options(options) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { @@ -402,39 +315,6 @@ struct WarpOpToScfIfPattern : public OpRewritePattern { const WarpExecuteOnLane0LoweringOptions &options; }; -/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute -/// op with the proper return type. -/// The new write op is updated to write the result of the new warp execute op. -/// The old `writeOp` is deleted. 
-static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter, - WarpExecuteOnLane0Op warpOp, - vector::TransferWriteOp writeOp, - VectorType targetType, - VectorType maybeMaskType) { - assert(writeOp->getParentOp() == warpOp && - "write must be nested immediately under warp"); - OpBuilder::InsertionGuard g(rewriter); - SmallVector newRetIndices; - WarpExecuteOnLane0Op newWarpOp; - if (maybeMaskType) { - newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()}, - TypeRange{targetType, maybeMaskType}, newRetIndices); - } else { - newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, ValueRange{{writeOp.getVector()}}, - TypeRange{targetType}, newRetIndices); - } - rewriter.setInsertionPointAfter(newWarpOp); - auto newWriteOp = - cast(rewriter.clone(*writeOp.getOperation())); - rewriter.eraseOp(writeOp); - newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0])); - if (maybeMaskType) - newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1])); - return newWriteOp; -} - /// Return the distributed vector type based on the original type and the /// distribution map. The map is expected to have a dimension equal to the /// original type rank and should be a projection where the results are the @@ -487,11 +367,10 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map, /// gpu.yield %v : vector<32xf32> /// } /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32> -struct WarpOpTransferWrite : public OpRewritePattern { +struct WarpOpTransferWrite : public WarpDistributionPattern { WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn, unsigned maxNumElementsToExtract, PatternBenefit b = 1) - : OpRewritePattern(ctx, b), - distributionMapFn(std::move(fn)), + : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)), maxNumElementsToExtract(maxNumElementsToExtract) {} /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that @@ -654,6 +533,38 @@ struct WarpOpTransferWrite : public OpRewritePattern { } private: + /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp + /// execute op with the proper return type. The new write op is updated to + /// write the result of the new warp execute op. The old `writeOp` is deleted. 
+  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
+                                       WarpExecuteOnLane0Op warpOp,
+                                       vector::TransferWriteOp writeOp,
+                                       VectorType targetType,
+                                       VectorType maybeMaskType) const {
+    assert(writeOp->getParentOp() == warpOp &&
+           "write must be nested immediately under warp");
+    OpBuilder::InsertionGuard g(rewriter);
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp;
+    if (maybeMaskType) {
+      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
+          TypeRange{targetType, maybeMaskType}, newRetIndices);
+    } else {
+      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
+          TypeRange{targetType}, newRetIndices);
+    }
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newWriteOp =
+        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
+    rewriter.eraseOp(writeOp);
+    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+    if (maybeMaskType)
+      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
+    return newWriteOp;
+  }
+
   DistributionMapFn distributionMapFn;
   unsigned maxNumElementsToExtract = 1;
 };
@@ -676,8 +587,8 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
 ///   vector<32xf32>
 /// }
 /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
-struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+struct WarpOpElementwise : public WarpDistributionPattern {
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
@@ -742,8 +653,8 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
 ///   ...
 /// }
 /// %0 = arith.constant dense<2.0> : vector<1xf32>
-struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+struct WarpOpConstant : public WarpDistributionPattern {
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand =
@@ -770,57 +681,6 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
-/// Delinearize the given `laneId` into multiple dimensions, where each
-/// dimension's size is determined by `originalShape` and `distributedShape`
-/// together. This function expects the total numbers of threads needed for
-/// distribution is equal to `warpSize`. Returns true and updates
-/// `delinearizedIds` if so.
-bool delinearizeLaneId(OpBuilder &builder, Location loc,
-                       ArrayRef<int64_t> originalShape,
-                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
-                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
-  // If the original shape and the distributed shape is the same, we don't
-  // distribute at all--every thread is handling the whole. For such case, we
-  // should not rely on lane IDs later. So just return an empty lane ID vector.
- if (originalShape == distributedShape) { - delinearizedIds.clear(); - return true; - } - - SmallVector sizes; - for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) { - if (large % small != 0) - return false; - sizes.push_back(large / small); - } - if (std::accumulate(sizes.begin(), sizes.end(), 1, - std::multiplies()) != warpSize) - return false; - - AffineExpr s0, s1; - bindSymbols(builder.getContext(), s0, s1); - - int64_t usedThreads = 1; - - Value zero = builder.create(loc, 0); - delinearizedIds.assign(sizes.size(), zero); - - for (int i = sizes.size() - 1; i >= 0; --i) { - usedThreads *= sizes[i]; - if (usedThreads == warpSize) { - // We've used up all available threads. Don't need to perform modulo - // anymore. And we can stop the calculation for further dimensions. - delinearizedIds[i] = laneId; - break; - } - delinearizedIds[i] = - affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId}); - laneId = affine::makeComposedAffineApply( - builder, loc, s0.floorDiv(usedThreads), {laneId}); - } - return true; -} - /// Sink out transfer_read op feeding into a warp op yield. /// ``` /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { @@ -839,8 +699,8 @@ bool delinearizeLaneId(OpBuilder &builder, Location loc, /// vector<32xf32> gpu.yield %2 : vector<32xf32> /// } /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32> -struct WarpOpTransferRead : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct WarpOpTransferRead : public WarpDistributionPattern { + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { // Try to find a distributable yielded read. Note that this pattern can @@ -951,8 +811,8 @@ struct WarpOpTransferRead : public OpRewritePattern { /// Remove any result that has no use along with the matching yieldOp operand. // TODO: Move this in WarpExecuteOnLane0Op canonicalization. -struct WarpOpDeadResult : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct WarpOpDeadResult : public WarpDistributionPattern { + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { SmallVector newResultTypes; @@ -1012,8 +872,8 @@ struct WarpOpDeadResult : public OpRewritePattern { // If an operand is directly yielded out of the region we can forward it // directly and it doesn't need to go through the region. -struct WarpOpForwardOperand : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct WarpOpForwardOperand : public WarpDistributionPattern { + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { SmallVector resultTypes; @@ -1056,8 +916,8 @@ struct WarpOpForwardOperand : public OpRewritePattern { } }; -struct WarpOpBroadcast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct WarpOpBroadcast : public WarpDistributionPattern { + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = @@ -1093,8 +953,8 @@ struct WarpOpBroadcast : public OpRewritePattern { /// Pattern to move shape cast out of the warp op. shape cast is basically a /// no-op for warp distribution; we need to handle the shape though. 
-struct WarpOpShapeCast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct WarpOpShapeCast : public WarpDistributionPattern { + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = @@ -1152,8 +1012,8 @@ struct WarpOpShapeCast : public OpRewritePattern { /// %cmp = arith.cmpi ult, %laneid, %0 /// %ub = arith.select %cmp, %c0, %c1 /// %1 = vector.create_mask %ub : vector<1xi1> -struct WarpOpCreateMask : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct WarpOpCreateMask : public WarpDistributionPattern { + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *yieldOperand = @@ -1218,8 +1078,8 @@ struct WarpOpCreateMask : public OpRewritePattern { /// Pattern to move out vector.extract of single element vector. Those don't /// need to be distributed and can just be propagated outside of the region. -struct WarpOpExtract : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct WarpOpExtract : public WarpDistributionPattern { + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = @@ -1298,11 +1158,10 @@ struct WarpOpExtract : public OpRewritePattern { /// Pattern to move out vector.extract with a scalar result. /// Only supports 1-D and 0-D sources for now. -struct WarpOpExtractScalar : public OpRewritePattern { +struct WarpOpExtractScalar : public WarpDistributionPattern { WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn, PatternBenefit b = 1) - : OpRewritePattern(ctx, b), - warpShuffleFromIdxFn(std::move(fn)) {} + : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = @@ -1397,9 +1256,8 @@ struct WarpOpExtractScalar : public OpRewritePattern { }; /// Pattern to convert vector.extractelement to vector.extract. -struct WarpOpExtractElement : public OpRewritePattern { - WarpOpExtractElement(MLIRContext *ctx, PatternBenefit b = 1) - : OpRewritePattern(ctx, b) {} +struct WarpOpExtractElement : public WarpDistributionPattern { + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = @@ -1420,9 +1278,8 @@ struct WarpOpExtractElement : public OpRewritePattern { /// Pattern to move out vector.insert with a scalar input. /// Only supports 1-D and 0-D destinations for now. 
-struct WarpOpInsertScalar : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - +struct WarpOpInsertScalar : public WarpDistributionPattern { + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred); @@ -1513,9 +1370,8 @@ struct WarpOpInsertScalar : public OpRewritePattern { } }; -struct WarpOpInsert : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - +struct WarpOpInsert : public WarpDistributionPattern { + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred); @@ -1627,9 +1483,8 @@ struct WarpOpInsert : public OpRewritePattern { } }; -struct WarpOpInsertElement : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - +struct WarpOpInsertElement : public WarpDistributionPattern { + using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = @@ -1680,12 +1535,10 @@ struct WarpOpInsertElement : public OpRewritePattern { /// scf.yield %iw : vector<4xf32> /// } /// ``` -struct WarpOpScfForOp : public OpRewritePattern { +struct WarpOpScfForOp : public WarpDistributionPattern { WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) - : OpRewritePattern(ctx, b), - distributionMapFn(std::move(fn)) {} - using OpRewritePattern::OpRewritePattern; + : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { auto yield = cast( @@ -1824,11 +1677,11 @@ struct WarpOpScfForOp : public OpRewritePattern { /// %a = vector.extract %0[0] : f32 from vector<1xf32> /// %r = ("warp.reduction %a") /// ``` -struct WarpOpReduction : public OpRewritePattern { +struct WarpOpReduction : public WarpDistributionPattern { WarpOpReduction(MLIRContext *context, DistributedReductionFn distributedReductionFn, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), + : WarpDistributionPattern(context, benefit), distributedReductionFn(std::move(distributedReductionFn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, @@ -1927,6 +1780,13 @@ void mlir::vector::populateDistributeReduction( benefit); } +/// Helper to know if an op can be hoisted out of the region. +static bool canBeHoisted(Operation *op, + function_ref definedOutside) { + return llvm::all_of(op->getOperands(), definedOutside) && + isMemoryEffectFree(op) && op->getNumRegions() == 0; +} + void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) { Block *body = warpOp.getBody();