diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 3205da6e448fc..668ee6386f71f 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -33,6 +33,14 @@ using SCFTileSizeComputationFunction = /// Options to use to control tiling. struct SCFTilingOptions { + /// Specify which loop construct to use for tile and fuse. + enum class LoopType { ForOp, ForallOp, CustomOp }; + LoopType loopType = LoopType::ForOp; + SCFTilingOptions &setLoopType(LoopType type) { + loopType = type; + return *this; + } + /// Computation function that returns the tile sizes to use for each loop. /// Returning a tile size of zero implies no tiling for that loop. If the /// size of the returned vector is smaller than the number of loops, the inner @@ -50,6 +58,17 @@ struct SCFTilingOptions { /// proper interaction with folding. SCFTilingOptions &setTileSizes(ArrayRef tileSizes); + /// The interchange vector to reorder the tiled loops. + SmallVector interchangeVector = {}; + SCFTilingOptions &setInterchange(ArrayRef interchange) { + interchangeVector = llvm::to_vector(interchange); + return *this; + } + + //-------------------------------------------------------------------------// + // Options related to tiling using `scf.forall`. + //-------------------------------------------------------------------------// + /// Computation function that returns the number of threads to use for /// each loop. Returning a num threads of zero implies no tiling for that /// loop. If the size of the returned vector is smaller than the number of @@ -70,21 +89,6 @@ struct SCFTilingOptions { /// function that computes num threads at the point they are needed. SCFTilingOptions &setNumThreads(ArrayRef numThreads); - /// The interchange vector to reorder the tiled loops. 
- SmallVector interchangeVector = {}; - SCFTilingOptions &setInterchange(ArrayRef interchange) { - interchangeVector = llvm::to_vector(interchange); - return *this; - } - - /// Specify which loop construct to use for tile and fuse. - enum class LoopType { ForOp, ForallOp }; - LoopType loopType = LoopType::ForOp; - SCFTilingOptions &setLoopType(LoopType type) { - loopType = type; - return *this; - } - /// Specify mapping of loops to devices. This is only respected when the loop /// constructs support such a mapping (like `scf.forall`). Will be ignored /// when using loop constructs that dont support such a mapping (like @@ -117,6 +121,96 @@ struct SCFTilingOptions { reductionDims.insert(dims.begin(), dims.end()); return *this; } + + //-------------------------------------------------------------------------// + // Options related to tiling using custom loop. + //-------------------------------------------------------------------------// + + // For generating the inter-tile loops using a custom loop, two callback + // functions are needed + // 1. That generates the "loop header", i.e. the loop that iterates over the + // different tiles. + // 2. That generates the loop terminator + // + // For `scf.forall` case the call back to generate loop header would generate + // + // ```mlir + // scf.forall (...) = ... { + // .. + // } + // ``` + // + // and the call back to generate the loop terminator would generate the + // `scf.in_parallel` region + // + // ```mlir + // scf.forall (...) = ... { + // scf.in_parallel { + // tensor.parallel_insert_slice ... + // } + // } + // ``` + // + + // Information that is to be returned by loop header callback needed for the + // rest of the tiled codegeneration. + // - `loops`: The generated loops + // - `tileOffset`: The values that represent the offset of the iteration space + // tile. + // - `tileSizes` : The values that represent the size of the iteration space + // tile. 
+ // - `destinationTensors` : The tensors to use as destinations during tiling. + struct CustomLoopHeaderInfo { + SmallVector loops; + SmallVector tileOffset; + SmallVector tileSizes; + SmallVector destinationTensors; + }; + + // Type of the callback function that generates the loop headers. + // - `loopRanges` : Values that represent the full size of the iteration space + // being tiled. + // - `givenTileSizes` : The tile sizes that are to be used to tile the + // iteration space. + // - `destinationTensors` : The tensors to use as destinations for the results + // of the tiled loop for loops that implement + // `DestinationStyleOpInterface`. + // Returns the `CustomLoopHeaderInfo` object (described above). It is expected + // that this function sets the insertion point of `rewriter` to the program + // point where the intra-tile loop computation is to be generated. + using GenerateLoopHeaderFn = std::function( + RewriterBase &rewriter, Location loc, ArrayRef loopRanges, + ArrayRef givenTileSizes, ValueRange destinationTensors)>; + + // Type of the callback function that generates the loop terminator. + // - `tiledResults` : Tiles of the result computed for the iteration space + // tile. + // - `resultOffsets` : For each of the `tiledResults`, the offset at which + // the result tile is to be "inserted" back into the + // destination tensor. + // - `resultSizes` : For each of the `tiledResults`, the size of the result + // tile that is to be "inserted" back into the destination + // tensor. + // Returns failure if the loop terminator could not be generated. + using GenerateLoopTerminatorFn = std::function> resultOffsets, + ArrayRef> resultSizes, + ValueRange destinationTensors)>; + + // Callback function to generate the inter-tile loop header. + GenerateLoopHeaderFn generateLoopHeaderFn = nullptr; + // Callback function to generate the inter-tile loop terminator. 
+ GenerateLoopTerminatorFn generateLoopTerminatorFn = nullptr; + // Helper function to set the callbacks for inter-tile loop header and + // terminator functions when using a custom operation for the loop. + SCFTilingOptions & + setCustomLoopGenerationFns(GenerateLoopHeaderFn headerFn, + GenerateLoopTerminatorFn terminatorFn) { + generateLoopHeaderFn = std::move(headerFn); + generateLoopTerminatorFn = std::move(terminatorFn); + return *this; + } }; /// Transformation information returned after tiling. diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index f24310ecd7beb..89e2c57d709dd 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -155,18 +155,18 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, static LogicalResult checkTileSizes(TilingInterface op, scf::SCFTilingOptions::LoopType loopType, ReductionTilingStrategy reductionStrategy, - ArrayRef tileSizes, + ArrayRef givenTileSizes, ArrayRef numThreads) { auto iterators = op.getLoopIteratorTypes(); - assert(iterators.size() == tileSizes.size() && + assert(iterators.size() == givenTileSizes.size() && "expected as many tile size values as number of loops"); assert((numThreads.empty() || (numThreads.size() == iterators.size())) && "when specified, expected number of threads to use for each loop"); bool isParallelTiling = false; - for (auto [index, iterator, tileSize] : - llvm::enumerate(iterators, tileSizes)) { - if (!isConstantIntValue(tileSize, 0)) { + for (auto [index, iterator, givenTileSize] : + llvm::enumerate(iterators, givenTileSizes)) { + if (!isConstantIntValue(givenTileSize, 0)) { isParallelTiling |= iterator == utils::IteratorType::parallel; } @@ -186,7 +186,7 @@ static LogicalResult checkTileSizes(TilingInterface op, } if (std::optional constTileSize = - getConstantIntValue(tileSize)) { + 
getConstantIntValue(givenTileSize)) { if (constTileSize.value() > 0 && iterator != utils::IteratorType::parallel) { op.emitWarning() << "tiling is not thread safe at axis #" << index; @@ -207,11 +207,11 @@ static LogicalResult checkTileSizes(TilingInterface op, /// Get the reduction dims that are tiled. This accounts for reduction dims /// that are specified as tiled, but the tile size is 0. static SetVector -getSanitizedReductionDims(ArrayRef tileSizes, +getSanitizedReductionDims(ArrayRef givenTileSizes, const scf::SCFTilingOptions &options) { SetVector reductionDims; for (auto dim : options.reductionDims) { - if (isConstantIntValue(tileSizes[dim], 0)) + if (isConstantIntValue(givenTileSizes[dim], 0)) continue; reductionDims.insert(dim); } @@ -236,14 +236,14 @@ static bool tileDividesIterationDomain(Range loopRange) { /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`. static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, - OpFoldResult tileSize) { - std::optional ts = getConstantIntValue(tileSize); + OpFoldResult givenTileSize) { + std::optional ts = getConstantIntValue(givenTileSize); if (ts && ts.value() == 1) - return tileSize; + return givenTileSize; if (tileDividesIterationDomain( - Range{loopRange.offset, loopRange.size, tileSize})) - return tileSize; + Range{loopRange.offset, loopRange.size, givenTileSize})) + return givenTileSize; // The tile size to use (to avoid out of bounds access) is minimum of // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled @@ -254,15 +254,15 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext()); Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); return affine::makeComposedFoldedAffineMin( - b, loc, minMap, SmallVector{offset, size, tileSize}); + b, loc, minMap, SmallVector{offset, size, givenTileSize}); } /// Returns true if the maximum 
tile offset `tileSize * numThreads-1` is less /// than `iterationSize`. -static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, +static bool canOmitTileOffsetInBoundsCheck(OpFoldResult givenTileSize, OpFoldResult numThreads, OpFoldResult iterationSize) { - std::optional tileSizeConst = getConstantIntValue(tileSize); + std::optional tileSizeConst = getConstantIntValue(givenTileSize); std::optional numThreadsConst = getConstantIntValue(numThreads); std::optional iterSizeConst = getConstantIntValue(iterationSize); if (!tileSizeConst || !numThreadsConst || !iterSizeConst) @@ -274,114 +274,51 @@ static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, /// `offset`s and `size`s of the tile of the iteration space that the /// innermost loop body of the generated tiled loops corresponds to. static std::tuple, SmallVector> -getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, - ReductionTilingStrategy strategy, ValueRange ivs, +getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef iterationDomain, - ArrayRef tileSizes, - ArrayRef numThreads, - const llvm::SetVector &reductionDims) { + ArrayRef givenTileSizes) { SmallVector offsets, sizes; int materializedLoopNum = 0; - - if (!numThreads.empty()) { - AffineExpr d0, d1, s0, s1; - AffineExpr offsetExpr, residualTileSizeExpr; - bindDims(rewriter.getContext(), d0, d1); - bindSymbols(rewriter.getContext(), s0, s1); - offsetExpr = d0 + d1 * s0; - residualTileSizeExpr = s1 - (d0 + d1 * s0); - - for (auto [index, nt, tileSize, loopRange] : - llvm::enumerate(numThreads, tileSizes, iterationDomain)) { - - // Non-tiled cases, set the offset and size to the - // `loopRange.offset/size`. 
- if (isZeroInteger(nt)) { - offsets.push_back(loopRange.offset); - sizes.push_back(loopRange.size); - continue; - } - - Value iv = ivs[materializedLoopNum++]; - OpFoldResult offset = affine::makeComposedFoldedAffineApply( - rewriter, loc, offsetExpr, - ArrayRef{loopRange.offset, iv, tileSize}); - OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( - rewriter, loc, residualTileSizeExpr, - {loopRange.offset, nt, tileSize, loopRange.size}); - - OpFoldResult size = tileSize; - if (!isZeroInteger(residualTileSize)) { - OpFoldResult sizeMinusOffsetPerThread = - affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, - {offset, loopRange.size}); - size = affine::makeComposedFoldedAffineMin( - rewriter, loc, - AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), - {sizeMinusOffsetPerThread, tileSize}); - } - - // Consider the case where the original loop was `[0, 100)`. - // If number of threads are `7`, the tile size would be computed as - // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6) - // - `offset = 0 + 6 * 15 = 105` - // - `tileSize = min(15, 100 - 105) = -5` - // To avoid negative tile sizes, we need to do a further - // `nonNegativeTileSize = affine.max(0, tileSize)`. - // This `max` can be avoided if - // `offset + tileSize * (numThreads - 1) < (ub - lb)` - if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) { - AffineMap maxMap = - AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); - size = affine::makeComposedFoldedAffineMax( - rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size}); - } - - offsets.push_back(offset); - sizes.push_back(size); + for (auto [givenTileSize, loopRange] : + llvm::zip_equal(givenTileSizes, iterationDomain)) { + + // Non-tiled cases, set the offset and size to the + // `loopRange.offset/size`. 
+ if (isZeroInteger(givenTileSize)) { + offsets.push_back(loopRange.offset); + sizes.push_back(loopRange.size); + continue; } - return {offsets, sizes}; - } else { - for (auto [tileSize, loopRange] : - llvm::zip_equal(tileSizes, iterationDomain)) { - - // Non-tiled cases, set the offset and size to the - // `loopRange.offset/size`. - if (isZeroInteger(tileSize)) { - offsets.push_back(loopRange.offset); - sizes.push_back(loopRange.size); - continue; - } - Value iv = ivs[materializedLoopNum++]; - OpFoldResult offset = getAsOpFoldResult(iv); - offsets.push_back(offset); - OpFoldResult size = - getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize); - sizes.push_back(size); - } - return {offsets, sizes}; + Value iv = ivs[materializedLoopNum++]; + OpFoldResult offset = getAsOpFoldResult(iv); + offsets.push_back(offset); + OpFoldResult size = + getBoundedTileSize(rewriter, loc, loopRange, offset, givenTileSize); + sizes.push_back(size); } + return {offsets, sizes}; } /// Function to return the bounds of the loops to be generated. static std::tuple, SmallVector, SmallVector> getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef loopRanges, - ArrayRef tileSizes) { + ArrayRef givenTileSizes) { SmallVector lbs, ubs, steps; - for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) { + for (auto [loopRange, givenTileSize] : + llvm::zip_equal(loopRanges, givenTileSizes)) { // No loop if the tile size is 0. - if (isZeroInteger(tileSize)) + if (isZeroInteger(givenTileSize)) continue; lbs.push_back(loopRange.offset); ubs.push_back(loopRange.size); - steps.push_back(tileSize); + steps.push_back(givenTileSize); } return {lbs, ubs, steps}; } -/// A function that allows returning additional yielded values during +/// Typedef for function that allows returning additional yielded values during /// `yieldTiledValuesAndReplace`. /// - `ivs` induction variable for the loop. /// - `newBbArgs` basic block arguments corresponding to newly added iter_args. 
@@ -402,6 +339,30 @@ using YieldTiledValuesFn = std::function> &resultOffsets, SmallVector> &resultSizes)>; +/// Typedef for function that implements the body of a tiled loop. +/// - `ivs` induction variable for the loop. +/// - `tileOffsets` represents offsets for the tiled iteration space. +/// - `tileSizes` represents the sizes for the tiled iteration space. +/// - `outerDestinationTensors` tensor that holds the result. Is same size +/// as the destination operands of the original operations. +/// - `tiledResults` results of the tiled computation, corresponds to +/// tiles of the original operation computed by the loop body. +/// Should be same size as the `outerDestinationTensors` +/// - `resultOffsets` is of the same size as `tiledResults` and represents +/// the offset to use when writing the corresponding element from +/// `tiledResults` into `outerDestinationTensors`. +/// - `resultSizes` is of the same size as `tiledResults` and represents +/// the size to use when writing the corresponding element from +/// `tiledResults` into `outerDestinationTensors`. +/// In case the method needs to return `failure()` the method is expected +/// to clean up any inserted operations. +using GenerateTiledBodyFn = std::function tileOffsets, ArrayRef tileSizes, + ValueRange outerDestinationTensors, SmallVector &tiledResults, + SmallVector> &resultOffsets, + SmallVector> &resultSizes)>; + /// Clones the operation and updates the destination if the operation /// implements the `DestinationStyleOpInterface`. static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, @@ -417,26 +378,25 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, /// Generate the tile-loop nest using `scf.for` operation. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. -/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. -/// - `destinationTensors` are the init values to use for the outer most loop. 
-/// - `yieldTiledValuesFn` is called to generated the loop body of the inner +/// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops. +/// - `outerDestinationTensors` are the init values to use for the outer most +/// loop. +/// - `tiledBodyFn` is called to generated the loop body of the inner /// most /// loop. -/// - `loops` is an in-out parameter into which the generated loops are -/// populated. -static LogicalResult generateLoopNestUsingForOp( +/// Returns the generated `scf.for` loops on success. +static FailureOr> generateLoopNestUsingForOp( RewriterBase &rewriter, Location loc, ArrayRef loopRanges, - ArrayRef tileSizes, ValueRange destinationTensors, - YieldTiledValuesFn yieldTiledValuesFn, - SmallVector &loops) { + ArrayRef givenTileSizes, ValueRange outerDestinationTensors, + GenerateTiledBodyFn tiledBodyFn) { assert(!loopRanges.empty() && "unexpected empty loop ranges"); - assert(loopRanges.size() == tileSizes.size() && + assert(loopRanges.size() == givenTileSizes.size() && "expected as many tile sizes as loop ranges"); OpBuilder::InsertionGuard guard(rewriter); SmallVector lbs, ubs, steps; std::tie(lbs, ubs, steps) = - getLoopBounds(rewriter, loc, loopRanges, tileSizes); + getLoopBounds(rewriter, loc, loopRanges, givenTileSizes); SmallVector lbVals = getValueOrCreateConstantIndexOp(rewriter, loc, lbs); SmallVector ubVals = @@ -445,33 +405,44 @@ static LogicalResult generateLoopNestUsingForOp( getValueOrCreateConstantIndexOp(rewriter, loc, steps); SmallVector ivs; + SmallVector loops; + ValueRange innerDestinationTensors(outerDestinationTensors); for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) { auto loop = - scf::ForOp::create(rewriter, loc, lb, ub, step, destinationTensors, + scf::ForOp::create(rewriter, loc, lb, ub, step, innerDestinationTensors, [](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, ValueRange /*iterArgs*/) {}); loops.push_back(loop); ivs.push_back(loop.getInductionVar()); 
rewriter.setInsertionPointToEnd(loop.getBody()); - destinationTensors = loop.getRegionIterArgs(); + innerDestinationTensors = loop.getRegionIterArgs(); } if (loops.empty()) return success(); + // Compute the `offsets` and `sizes` to use for tiling. + SmallVector offsets, sizes; + std::tie(offsets, sizes) = + getTileOffsetAndSizes(rewriter, loc, ivs, loopRanges, givenTileSizes); + SmallVector tiledResults; SmallVector> resultOffsets, resultSizes; - if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors, - tiledResults, resultOffsets, resultSizes))) { + if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes, + innerDestinationTensors, tiledResults, resultOffsets, + resultSizes))) { return rewriter.notifyMatchFailure( loc, "failed to generate inner tile loop body"); } - assert(tiledResults.size() == destinationTensors.size() && + if (loops.empty()) + return loops; + + assert(tiledResults.size() == innerDestinationTensors.size() && "Number of results of body should be equal to number of iter args"); // 6. Yield all the results of the tiled operation. SmallVector yieldedValues; for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : - llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, + llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets, resultSizes)) { SmallVector resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); @@ -490,27 +461,108 @@ static LogicalResult generateLoopNestUsingForOp( cast(outerLoop.getOperation()).getBody()); scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults()); } - return success(); + return loops; +} + +/// Compute the `OpFoldResult`s that represents the multi-dimensional +/// `offset`s and `size`s of the tile of the iteration space that the +/// innermost loop body of the generated tiled loops corresponds to +/// when tiling using `forall` op. 
This is handled separately due to +/// the special case handling needed for when the tiling is done by +/// specifying number of threads. +static std::tuple, SmallVector> +getTileOffsetAndSizesWithForAllOp(RewriterBase &rewriter, Location loc, + ValueRange ivs, + ArrayRef iterationDomain, + ArrayRef givenTileSizes, + ArrayRef numThreads) { + if (numThreads.empty()) { + return getTileOffsetAndSizes(rewriter, loc, ivs, iterationDomain, + givenTileSizes); + } + + SmallVector offsets, sizes; + int materializedLoopNum = 0; + + AffineExpr d0, d1, s0, s1; + AffineExpr offsetExpr, residualTileSizeExpr; + bindDims(rewriter.getContext(), d0, d1); + bindSymbols(rewriter.getContext(), s0, s1); + offsetExpr = d0 + d1 * s0; + residualTileSizeExpr = s1 - (d0 + d1 * s0); + + for (auto [index, nt, givenTileSize, loopRange] : + llvm::enumerate(numThreads, givenTileSizes, iterationDomain)) { + + // Non-tiled cases, set the offset and size to the + // `loopRange.offset/size`. + if (isZeroInteger(nt)) { + offsets.push_back(loopRange.offset); + sizes.push_back(loopRange.size); + continue; + } + + Value iv = ivs[materializedLoopNum++]; + OpFoldResult offset = affine::makeComposedFoldedAffineApply( + rewriter, loc, offsetExpr, + ArrayRef{loopRange.offset, iv, givenTileSize}); + OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( + rewriter, loc, residualTileSizeExpr, + {loopRange.offset, nt, givenTileSize, loopRange.size}); + + OpFoldResult size = givenTileSize; + if (!isZeroInteger(residualTileSize)) { + OpFoldResult sizeMinusOffsetPerThread = + affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, + {offset, loopRange.size}); + size = affine::makeComposedFoldedAffineMin( + rewriter, loc, + AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), + {sizeMinusOffsetPerThread, givenTileSize}); + } + + // Consider the case where the original loop was `[0, 10)`. + // If number of threads are `7`, the tile size would be computed as + // `ceilDiv(10, 7) = 2`. 
For the last thread (thread_id = 6) + // - `offset = 0 + 6 * 2 = 12` + // - `tileSize = min(2, 10 - 12) = -2` + // To avoid negative tile sizes, we need to do a further + // `nonNegativeTileSize = affine.max(0, tileSize)`. + // This `max` can be avoided if + // `offset + tileSize * (numThreads - 1) < (ub - lb)` + if (!canOmitTileOffsetInBoundsCheck(givenTileSize, nt, loopRange.size)) { + AffineMap maxMap = + AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); + size = affine::makeComposedFoldedAffineMax( + rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size}); + } + + offsets.push_back(offset); + sizes.push_back(size); + } + return {offsets, sizes}; } /// Generate the tile-loop nest using `scf.forall` operation. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. -/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. -/// - `destinationTensors` are the init values to use for the outer most loop. +/// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops. +/// - `outerDestinationTensors` are the init values to use for the loop. /// - `mappingVector` is the mapping attributes to use for loop construction. /// Can be empty. -/// - `yieldTiledValuesFn` is called to generated the loop body of the inner +/// - `tiledBodyFn` is called to generate the loop body of the inner /// most /// loop. -/// - `loops` is an in-out parameter into which the generated loops are -/// populated. -static LogicalResult generateLoopNestUsingForallOp( - RewriterBase &rewriter, Location loc, ArrayRef loopRanges, - ArrayRef tileSizes, ArrayRef numThreads, - ArrayRef mappingVector, ValueRange destinationTensors, - YieldTiledValuesFn tiledBodyFn, SmallVector &loops) { +/// Returns the generated `scf.forall` loop on success. 
+static FailureOr> +generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, + ArrayRef loopRanges, + ArrayRef givenTileSizes, + ArrayRef numThreads, + ArrayRef mappingVector, + ValueRange outerDestinationTensors, + GenerateTiledBodyFn tiledBodyFn) { assert(!loopRanges.empty() && "unexpected empty loop ranges"); - assert(loopRanges.size() == tileSizes.size() && + assert(loopRanges.size() == givenTileSizes.size() && "expected as many tile sizes as loop ranges"); OpBuilder::InsertionGuard guard(rewriter); @@ -521,6 +573,7 @@ static LogicalResult generateLoopNestUsingForallOp( scf::ForallOp forallOp; bool useNumThreads = !numThreads.empty(); + SmallVector loops; if (useNumThreads) { // Prune the zero numthreads. SmallVector nonZeroNumThreads; @@ -530,29 +583,35 @@ static LogicalResult generateLoopNestUsingForallOp( nonZeroNumThreads.push_back(nt); } forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads, - destinationTensors, mappingAttr); + outerDestinationTensors, mappingAttr); } else { SmallVector lbs, ubs, steps; std::tie(lbs, ubs, steps) = - getLoopBounds(rewriter, loc, loopRanges, tileSizes); + getLoopBounds(rewriter, loc, loopRanges, givenTileSizes); forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps, - destinationTensors, mappingAttr); + outerDestinationTensors, mappingAttr); } loops.push_back(forallOp); rewriter.setInsertionPoint(forallOp.getTerminator()); - destinationTensors = forallOp.getRegionOutArgs(); + ValueRange innerDestinationTensors = forallOp.getRegionOutArgs(); + SmallVector ivs = forallOp.getInductionVars(); + + // Compute the `offsets` and `sizes` to use for tiling. 
+ SmallVector offsets, sizes; + std::tie(offsets, sizes) = getTileOffsetAndSizesWithForAllOp( + rewriter, loc, ivs, loopRanges, givenTileSizes, numThreads); SmallVector tiledResults; SmallVector> resultOffsets, resultSizes; - if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(), - destinationTensors, tiledResults, resultOffsets, + if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes, + innerDestinationTensors, tiledResults, resultOffsets, resultSizes))) return rewriter.notifyMatchFailure(loc, "failed to generate loop body"); rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : - llvm::zip_equal(tiledResults, destinationTensors, resultOffsets, + llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets, resultSizes)) { SmallVector resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); @@ -561,41 +620,105 @@ static LogicalResult generateLoopNestUsingForallOp( destinationTensor, resultOffset, resultSize, resultStride); } - return success(); + return loops; +} + +/// Generate the tile-loop nest using custom loop operation. +/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. +/// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops. +/// - `outerDestinationTensors` are the init values passed to the loop header +/// callback. +/// - `generateLoopHeaderFn` / `generateLoopTerminatorFn` are the callbacks +/// that create the loop header and the loop terminator. +/// - `tiledBodyFn` is called to generate the loop body of the inner most +/// loop. +/// Returns the loops reported by the loop header callback on success. 
+static FailureOr> +generateLoopNestUsingCustomOp( + RewriterBase &rewriter, Location loc, ArrayRef loopRanges, + ArrayRef givenTileSizes, ValueRange outerDestinationTensors, + const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn, + const scf::SCFTilingOptions::GenerateLoopTerminatorFn + &generateLoopTerminatorFn, + GenerateTiledBodyFn tiledBodyFn) { + assert(!loopRanges.empty() && "unexpected empty loop ranges"); + assert(loopRanges.size() == givenTileSizes.size() && + "expected as many tile sizes as loop ranges"); + assert(generateLoopHeaderFn && generateLoopTerminatorFn && + "expected loop header/terminator generation function"); + OpBuilder::InsertionGuard guard(rewriter); + + FailureOr loopHeaderInfo = + generateLoopHeaderFn(rewriter, loc, loopRanges, givenTileSizes, + outerDestinationTensors); + if (failed(loopHeaderInfo)) { + return failure(); + } + + SmallVector ivs; + SmallVector tiledResults; + SmallVector> resultOffsets, resultSizes; + if (failed(tiledBodyFn(rewriter, loc, ivs, loopHeaderInfo->tileOffset, + loopHeaderInfo->tileSizes, + loopHeaderInfo->destinationTensors, tiledResults, + resultOffsets, resultSizes))) { + return failure(); + } + + if (failed(generateLoopTerminatorFn(rewriter, loc, tiledResults, + resultOffsets, resultSizes, + loopHeaderInfo->destinationTensors))) { + return failure(); + } + + return loopHeaderInfo->loops; } /// Generate the tile-loop nest using the loop construct specifed in `options`. /// - `options`: Tiling options specified. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. -/// - `destinationTensors` are the init values to use for the outer most loop. +/// - `outerDestinationTensors` are the init values to use for the outer most +/// loop. /// - `yieldTiledValuesFn` is called to generated the loop body of the inner /// most /// loop. 
-/// - `loops` is an in-out parameter into which the generated loops are -/// populated. -static LogicalResult generateLoopNest( - RewriterBase &rewriter, Location loc, - scf::SCFTilingOptions::LoopType loopType, ArrayRef loopRanges, - ArrayRef tileSizes, ArrayRef numThreads, - ValueRange destinationTensors, ArrayRef mappingVector, - YieldTiledValuesFn tiledBodyFn, SmallVector &loops) { +/// Returns the generated loops on success. +static FailureOr> generateLoopNest( + RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, + ArrayRef loopRanges, ArrayRef givenTileSizes, + ArrayRef numThreads, ValueRange destinationTensors, + GenerateTiledBodyFn tiledBodyFn) { // If the tile sizes are all zero, no loops are generated. Just call the // callback function to handle untiled case. - if (llvm::all_of(tileSizes, isZeroInteger)) { + if (llvm::all_of(givenTileSizes, isZeroInteger)) { SmallVector tiledResults; SmallVector> resultOffsets, resultSizes; - return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, - tiledResults, resultOffsets, resultSizes); + auto tileOffsets = + llvm::map_to_vector(loopRanges, [](Range r) { return r.offset; }); + auto tileSizes = + llvm::map_to_vector(loopRanges, [](Range r) { return r.size; }); + if (failed(tiledBodyFn(rewriter, loc, ValueRange{}, tileOffsets, tileSizes, + destinationTensors, tiledResults, resultOffsets, + resultSizes))) { + return failure(); + } + return SmallVector{}; } - if (loopType == scf::SCFTilingOptions::LoopType::ForOp) { - return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, - destinationTensors, tiledBodyFn, loops); + if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { + return generateLoopNestUsingForOp(rewriter, loc, loopRanges, givenTileSizes, + destinationTensors, tiledBodyFn); } - if (loopType == scf::SCFTilingOptions::LoopType::ForallOp) { + if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { return generateLoopNestUsingForallOp( - 
rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector, - destinationTensors, tiledBodyFn, loops); + rewriter, loc, loopRanges, givenTileSizes, numThreads, + options.mappingVector, destinationTensors, tiledBodyFn); + } + if (options.loopType == scf::SCFTilingOptions::LoopType::CustomOp) { + return generateLoopNestUsingCustomOp( + rewriter, loc, loopRanges, givenTileSizes, destinationTensors, + options.generateLoopHeaderFn, options.generateLoopTerminatorFn, + tiledBodyFn); } return rewriter.notifyMatchFailure(loc, "unhandled loop type"); } @@ -603,7 +726,7 @@ static LogicalResult generateLoopNest( static FailureOr> createInitialTensorsForTiling( RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ArrayRef iterationDomain, - ArrayRef numThreads, ArrayRef tileSizes, + ArrayRef numThreads, ArrayRef givenTileSizes, const SetVector &reductionDims) { SmallVector initTensors; Location loc = op->getLoc(); @@ -625,7 +748,7 @@ static FailureOr> createInitialTensorsForTiling( AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2)); AffineExpr divExpr = s0.ceilDiv(s1); for (auto [index, domain, tileSize] : - llvm::enumerate(iterationDomain, tileSizes)) { + llvm::enumerate(iterationDomain, givenTileSizes)) { if (!numThreads.empty()) { // Untiled case. 
if (isConstantIntValue(numThreads[index], 0)) { @@ -671,7 +794,7 @@ static SmallVector getSplitReductionIvs(RewriterBase &rewriter, Location loc, ReductionTilingStrategy reductionStrategy, ValueRange ivs, ArrayRef numThreads, - ArrayRef tileSizes, + ArrayRef givenTileSizes, const SetVector &reductionDims) { SmallVector splitReductionIvs; splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0)); @@ -688,7 +811,7 @@ getSplitReductionIvs(RewriterBase &rewriter, Location loc, } splitReductionIvs[index] = affine::makeComposedFoldedAffineApply( rewriter, loc, divExpr, - ArrayRef{ivs[ivIndex++], tileSizes[reductionDim]}); + ArrayRef{ivs[ivIndex++], givenTileSizes[reductionDim]}); } } return splitReductionIvs; @@ -700,7 +823,7 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef offsets, ArrayRef sizes, ValueRange ivs, ArrayRef numThreads, - ArrayRef tileSizes, + ArrayRef givenTileSizes, const SetVector &reductionDims) { if (reductionStrategy == ReductionTilingStrategy::FullReduction) { return op.getTiledImplementation(rewriter, offsets, sizes); @@ -716,7 +839,7 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op, SmallVector splitReductionIvs = getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs, - numThreads, tileSizes, reductionDims); + numThreads, givenTileSizes, reductionDims); return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy, regionIterArg, offsets, sizes, reductionDims, splitReductionIvs); @@ -727,7 +850,8 @@ static LogicalResult getResultTilePosition( int64_t index, Value tiledResult, TilingInterface op, ArrayRef offsets, ArrayRef sizes, ValueRange ivs, ArrayRef numThreads, - ArrayRef tileSizes, const SetVector &reductionDims, + ArrayRef givenTileSizes, + const SetVector &reductionDims, SmallVector &resultOffset, SmallVector &resultSize) { @@ -743,7 +867,7 @@ static LogicalResult getResultTilePosition( } SmallVector splitReductionIvs = 
getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs, - numThreads, tileSizes, reductionDims); + numThreads, givenTileSizes, reductionDims); return redOp.getPartialResultTilePosition( rewriter, index, reductionStrategy, offsets, sizes, reductionDims, splitReductionIvs, resultOffset, resultSize); @@ -998,20 +1122,20 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, SmallVector iterationDomain = op.getIterationDomain(rewriter); // 2. Materialize the tile sizes and/or number of threads; - SmallVector tileSizes, numThreads; - std::tie(tileSizes, numThreads) = + SmallVector givenTileSizes, numThreads; + std::tie(givenTileSizes, numThreads) = getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); // Check if it is safe to tile. This is hold over from previous iterations // of tile to for-all. Consider dropping it. if (failed(checkTileSizes(op, options.loopType, options.reductionStrategy, - tileSizes, numThreads))) { + givenTileSizes, numThreads))) { return failure(); } // Get the reduction dims SetVector reductionDims = - getSanitizedReductionDims(tileSizes, options); + getSanitizedReductionDims(givenTileSizes, options); // 3. If there is an interchange specified, permute the iteration domain and // the tile sizes. @@ -1023,7 +1147,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, "expected interchange vector to be a permutation"); applyPermutationToVector(iterationDomain, interchangeVector); - applyPermutationToVector(tileSizes, interchangeVector); + applyPermutationToVector(givenTileSizes, interchangeVector); if (!numThreads.empty()) applyPermutationToVector(numThreads, interchangeVector); } @@ -1031,24 +1155,21 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, FailureOr tilingResult; // 4. Define the lambda function used later to generate the body of the // innermost tiled loop. 
- YieldTiledValuesFn innerYieldTiledValuesFn = + GenerateTiledBodyFn innerYieldTiledValuesFn = [&](RewriterBase &rewriter, Location loc, ValueRange ivs, + ArrayRef tileOffsets, ArrayRef tileSizes, ValueRange regionIterArgs, SmallVector &tiledResults, SmallVector> &resultOffsets, SmallVector> &resultSizes) -> LogicalResult { - // 4a. Compute the `offsets` and `sizes` to use for tiling. - SmallVector offsets, sizes; - std::tie(offsets, sizes) = getTileOffsetAndSizes( - rewriter, loc, options.reductionStrategy, ivs, iterationDomain, - tileSizes, numThreads, reductionDims); - // 4b. If interchange was provided, apply inverse of the interchange // to get back the offsets/sizes in the order to be specified. + SmallVector tileOffsetsVec = llvm::to_vector(tileOffsets); + SmallVector tileSizesVec = llvm::to_vector(tileSizes); if (!interchangeVector.empty()) { auto inversePermutation = invertPermutationVector(interchangeVector); - applyPermutationToVector(offsets, inversePermutation); - applyPermutationToVector(sizes, inversePermutation); + applyPermutationToVector(tileOffsetsVec, inversePermutation); + applyPermutationToVector(tileSizesVec, inversePermutation); } // 5. Generate the tiled implementation within the inner most loop. @@ -1060,7 +1181,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // 5b. Early return cloned op if tiling is not happening. We can not // return the original op because it could lead to `rewriter.replaceOp(op, // op->getResults())` and users would get crash. - if (llvm::all_of(tileSizes, isZeroInteger)) { + if (llvm::all_of(givenTileSizes, isZeroInteger)) { tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); tilingResult = TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), @@ -1069,9 +1190,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, } // 5c. Tile the cloned operation. 
- tilingResult = getTiledImplementation( - rewriter, clonedOp, options.reductionStrategy, regionIterArgs, offsets, - sizes, ivs, numThreads, tileSizes, reductionDims); + tilingResult = + getTiledImplementation(rewriter, clonedOp, options.reductionStrategy, + regionIterArgs, tileOffsetsVec, tileSizesVec, + ivs, numThreads, givenTileSizes, reductionDims); if (failed(tilingResult)) { rewriter.eraseOp(clonedOp); return op.emitOpError("faild to tile operation"); @@ -1088,8 +1210,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, SmallVector resultOffset, resultSize; if (failed(getResultTilePosition( rewriter, options.reductionStrategy, index, tiledValue, op, - offsets, sizes, ivs, numThreads, tileSizes, reductionDims, - resultOffset, resultSize))) { + tileOffsetsVec, tileSizesVec, ivs, numThreads, givenTileSizes, + reductionDims, resultOffset, resultSize))) { for (auto op : tilingResult->tiledOps) { rewriter.eraseOp(op); } @@ -1106,7 +1228,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // 6. Find the destination tensors to use for the operation. FailureOr> maybeInits = createInitialTensorsForTiling( rewriter, op, options.reductionStrategy, iterationDomain, numThreads, - tileSizes, reductionDims); + givenTileSizes, reductionDims); if (failed(maybeInits)) { return rewriter.notifyMatchFailure( op, "unable to create initial tensors for tiling"); @@ -1115,13 +1237,16 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // 7. Generate the tiled loops nest using the callback defined above. 
SmallVector loops; - if (failed(generateLoopNest(rewriter, op.getLoc(), options.loopType, - iterationDomain, tileSizes, numThreads, - initTensors, options.mappingVector, - innerYieldTiledValuesFn, loops))) - return op.emitOpError("failed to generate tiling loops"); - assert(succeeded(tilingResult) && - "expected tiling result to be computed after loop generation"); + { + FailureOr> loopsOr = generateLoopNest( + rewriter, op.getLoc(), options, iterationDomain, givenTileSizes, + numThreads, initTensors, innerYieldTiledValuesFn); + if (failed(loopsOr)) + return op.emitOpError("failed to generate tiling loops"); + assert(succeeded(tilingResult) && + "expected tiling result to be computed after loop generation"); + std::swap(loops, loopsOr.value()); + } if (loops.empty()) { // If loops are empty, the tiled op is used as the replacement for the diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-custom-op.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-custom-op.mlir new file mode 100644 index 0000000000000..d335e9c3fb5d0 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-using-custom-op.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-opt --transform-interpreter --cse --split-input-file --mlir-print-local-scope %s | FileCheck %s + +module { + func.func @generic_parallel(%arg0 : tensor, %arg1 : tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg0, %c1 : tensor + %empty = tensor.empty(%d0, %d1) : tensor + %generic = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %arg1 : tensor, tensor) outs(%empty : tensor) { + ^bb(%b0 : f32, %b1 : f32, %b2 : f32): + %add = arith.addf %b0, %b1 : f32 + linalg.yield %add : f32 + } -> tensor + return %generic : tensor + } +} + +module attributes {transform.with_named_sequence} { + 
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %op = transform.structured.match ops {["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %tiled_op, %loop = transform.test.tile_using_custom_loop %op tile_sizes = [10, 20] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-LABEL: func @generic_parallel +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:.+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor +// CHECK-DAG: %[[NITERS0:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 10)>()[%[[D0]]] +// CHECK-DAG: %[[NITERS1:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 20)>()[%[[D1]]] +// CHECK-DAG: %[[NITERS:.+]] = affine.apply affine_map<()[s0, s1] -> ((s0 ceildiv 10) * (s1 ceildiv 20))>()[%[[D0]], %[[D1]]] +// CHECK: %[[FOR:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[NITERS]] step %[[C1]] +// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[EMPTY]]) +// CHECK: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[IV]] into (%[[NITERS0]], %[[NITERS1]]) +// CHECK-DAG: %[[SIZE0:.+]] = affine.min affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>(%[[DELINEARIZE]]#0)[%[[D0]]] +// CHECK-DAG: %[[SIZE1:.+]] = affine.min affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>(%[[DELINEARIZE]]#1)[%[[D1]]] +// CHECK-DAG: %[[OFFSET0:.+]] = affine.apply affine_map<(d0) -> (d0 * 10)>(%[[DELINEARIZE]]#0) +// CHECK-DAG: %[[OFFSET1:.+]] = affine.apply affine_map<(d0) -> (d0 * 20)>(%[[DELINEARIZE]]#1) +// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[OFFSET0]], %[[OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]] [1, 1] +// CHECK-DAG: %[[ARG1_SLICE:.+]] = tensor.extract_slice %[[ARG1]][%[[OFFSET1]]] 
[%[[SIZE1]]] [1] +// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[OFFSET0]], %[[OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]] [1, 1] +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : +// CHECK-SAME: outs(%[[INIT_SLICE]] : +// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[INIT]] +// CHECK-SAME: [%[[OFFSET0]], %[[OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]] [1, 1] +// CHECK: scf.yield %[[INSERT_SLICE]] +// CHECK: return %[[FOR]] diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 3d24d4ecc4d0d..7981c72c2f2c8 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" @@ -468,6 +469,158 @@ transform::TestTileAndFuseOuterParallelPartialReductionOp::apply( : DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// TestTileUsingCustomLoopOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply( + TransformRewriter &transformRewriter, TransformResults &transformResults, + TransformState &state) { + auto target = + dyn_cast(*state.getPayloadOps(getRootOp()).begin()); + if (!target) { + emitOpError("expected root operation to implement `TilingInterface`"); + return DiagnosedSilenceableFailure::definiteFailure(); + } + + OpFoldResult oneOfr = 
transformRewriter.getIndexAttr(1); + + scf::SCFTilingOptions::GenerateLoopHeaderFn loopHeaderFn = + [&](RewriterBase &rewriter, Location loc, ArrayRef loopRanges, + ArrayRef givenTileSizes, + ValueRange outerDestinationTensors) + -> FailureOr { + // Check that the strides are all 1 (to make it easier in the test). + if (llvm::any_of(loopRanges, [](Range r) { + return !isConstantIntValue(r.stride, 1); + })) { + return emitOpError("unable to handle loop ranges with strides != 1"); + } + // Check number of tile sizes is equal to loop dimensions. + if (loopRanges.size() != givenTileSizes.size()) { + return emitOpError("expected number of tile sizes to be same as the " + "number of loops in the operation"); + } + // For testing disallow any of the tile sizes being 0. + if (llvm::any_of(givenTileSizes, isZeroInteger)) { + return emitOpError("unhandled case of zero tile size"); + } + // For testing, only handle tensor tiling. + if (outerDestinationTensors.empty()) { + return emitOpError("expected destination tensors"); + } + + // Compute the number of iterations for each of the loops. 
+ AffineExpr s0, s1, s2; + bindSymbols(rewriter.getContext(), s0, s1, s2); + AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2); // (ub - lb) / tileSize + + SmallVector allNumIters; + allNumIters.reserve(loopRanges.size()); + for (auto [loopRange, tileSize] : + llvm::zip_equal(loopRanges, givenTileSizes)) { + OpFoldResult numIters = affine::makeComposedFoldedAffineApply( + rewriter, loc, numItersExpr, + {loopRange.offset, loopRange.size, tileSize}); + allNumIters.push_back(numIters); + } + if (allNumIters.empty()) { + return emitOpError("invalid empty tile sizes and loop ranges"); + } + + AffineExpr mulExpr = s0 * s1; + OpFoldResult cumulative = oneOfr; + for (auto numIters : allNumIters) { + cumulative = affine::makeComposedFoldedAffineApply( + rewriter, loc, mulExpr, {cumulative, numIters}); + } + + Value zeroVal = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value oneVal = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, cumulative); + + SmallVector offsets; + SmallVector sizes; + SmallVector innerDestinationTensors; + offsets.reserve(loopRanges.size()); + sizes.reserve(loopRanges.size()); + + AffineExpr d0; + bindDims(rewriter.getContext(), d0); + AffineExpr offsetExpr = s0 + d0 * s1; // lb + iv * tileSize + AffineMap minMap = + AffineMap::get(1, 2, {s0 - d0, s1}, + rewriter.getContext()); // min(ub - offset, tileSize) + auto forOp = scf::ForOp::create( + rewriter, loc, zeroVal, ub, oneVal, outerDestinationTensors, + [&](OpBuilder &b, Location bodyLoc, Value linearizedIv, + ValueRange destinations) { + auto delinearizeOp = affine::AffineDelinearizeIndexOp::create( + b, bodyLoc, linearizedIv, allNumIters); + for (auto [normalizedIv, range, tileSize] : llvm::zip_equal( + delinearizeOp.getResults(), loopRanges, givenTileSizes)) { + + OpFoldResult normalizedIvOfr = getAsOpFoldResult(normalizedIv); + OpFoldResult offset = affine::makeComposedFoldedAffineApply( + b, bodyLoc, offsetExpr, + 
{normalizedIvOfr, range.offset, tileSize}); + offsets.push_back(offset); + + OpFoldResult size = affine::makeComposedFoldedAffineMin( + b, bodyLoc, minMap, {offset, range.size, tileSize}); + sizes.push_back(size); + } + innerDestinationTensors = llvm::to_vector(destinations); + }); + rewriter.setInsertionPointToEnd(forOp.getBody()); + return scf::SCFTilingOptions::CustomLoopHeaderInfo{ + {cast(forOp.getOperation())}, + offsets, + sizes, + innerDestinationTensors}; + }; + + scf::SCFTilingOptions::GenerateLoopTerminatorFn terminatorFn = + [&](RewriterBase &rewriter, Location loc, ValueRange tiledResults, + ArrayRef> resultOffsets, + ArrayRef> resultSizes, + ValueRange destinationTensors) -> LogicalResult { + SmallVector yieldValues; + yieldValues.reserve(destinationTensors.size()); + for (auto [tiledResult, offsets, sizes, destination] : llvm::zip_equal( + tiledResults, resultOffsets, resultSizes, destinationTensors)) { + SmallVector strides(offsets.size(), oneOfr); + Value insertedVal = tensor::InsertSliceOp::create( + rewriter, loc, tiledResult, destination, offsets, sizes, strides); + yieldValues.push_back(insertedVal); + } + scf::YieldOp::create(rewriter, loc, yieldValues); + return success(); + }; + + scf::SCFTilingOptions tilingOptions; + SmallVector staticTileSizes = + extractFromIntegerArrayAttr(getTileSizes()); + SmallVector tileSizes = + getAsIndexOpFoldResult(transformRewriter.getContext(), staticTileSizes); + tilingOptions.setTileSizes(tileSizes) + .setLoopType(scf::SCFTilingOptions::LoopType::CustomOp) + .setCustomLoopGenerationFns(loopHeaderFn, terminatorFn); + + OpBuilder::InsertionGuard g(transformRewriter); + transformRewriter.setInsertionPoint(target); + FailureOr tiledResults = + scf::tileUsingSCF(transformRewriter, target, tilingOptions); + if (failed(tiledResults)) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + transformRewriter.replaceOp(target, tiledResults->replacements); + transformResults.set(getOperation()->getResult(0), 
tiledResults->tiledOps); + transformResults.set(getOperation()->getResult(1), tiledResults->loops); + + return DiagnosedSilenceableFailure::success(); +} + #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.cpp.inc" diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index 58ccd30bb99a2..694c4229eef62 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -150,4 +150,27 @@ def TestTileAndFuseOuterParallelPartialReductionOp : Op< }]; } +def TestTileUsingCustomLoopOp : Op< + Transform_Dialect, "test.tile_using_custom_loop", + [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Test Transform op to tile an operation using custom loops. + + The test just folds all the loops into a single loop and then + delinearizes the indices. + }]; + + let arguments = (ins TransformHandleTypeInterface:$root_op, + DefaultValuedAttr:$tile_sizes); + let results = (outs TransformHandleTypeInterface:$tiled_ops, + Variadic:$loops); + + let assemblyFormat = [{ + $root_op `tile_sizes` `=` $tile_sizes + attr-dict `:` functional-type(operands, results) + }]; +} + #endif // TEST_TILINGINTERFACE_TRANSFORM_OPS