diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 7b43aa43c7517..3205da6e448fc 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -313,19 +313,23 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
                                      TilingInterface consumer,
                                      const SCFTileAndFuseOptions &options);
 
-/// Fuse the consumer of the source of `candidateSliceOp` by computing the
-/// required slice of the consumer in-place. Note that the method
-/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
-/// value but does not delete the slice operation.
+/// Fuse the consumer of `candidateSlices` by computing the required slice of
+/// the consumer in-place. All entries of `candidateSlices` are expected to map
+/// to the same consumer. The method returns an error if the consumer cannot be
+/// tiled in a manner that is consistent for all the passed slices. Note that
+/// the method replaces the uses of `candidateSlices` with the tiled and fused
+/// consumer values but does not delete the slice operations.
 struct SCFFuseConsumerOfSliceResult {
-  OpOperand *origConsumerOperand; // Original untiled consumer's operand.
-  OpOperand
-      *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
+  // Original untiled consumer operands.
+  SmallVector<OpOperand *> origConsumerOperands;
+  // Tiled and fused consumer operands.
+  SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
   SmallVector<Operation *> tiledOps;
 };
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
-                           MutableArrayRef<LoopLikeOpInterface> loops);
+tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
+                            ArrayRef<Operation *> candidateSlices,
+                            MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 18981337742eb..87deef9ca7466 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -31,12 +31,16 @@ namespace tensor {
 FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
-/// Method to swap an `tensor.insert_slice` with its consumer when the
-/// consumer implements the `TilingInterface`.
+/// Method to swap `tensor.insert_slice`s with their consumer when the
+/// consumer implements the `TilingInterface`. The sizes of `sliceOps` and
+/// `consumerOperands` are expected to be the same. Every entry in
+/// `consumerOperands` represents a use of the corresponding
+/// entry in `sliceOps` in the consumer. All entries of `consumerOperands` are
+/// expected to be uses in the same consumer.
 FailureOr<TilingResult>
-replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
-                                    OffsetSizeAndStrideOpInterface sliceOp,
-                                    OpOperand &consumerOp);
+replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
+                                     ArrayRef<tensor::InsertSliceOp> sliceOps,
+                                     ArrayRef<OpOperand *> consumerOperands);
 
 //===----------------------------------------------------------------------===//
 // Populate functions.
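A minimal usage sketch of the new multi-slice entry point declared above (illustrative only, not part of the patch; `rewriter`, the candidate slice ops, and the surrounding `loops` are assumed to come from an earlier tiling step, and the variable names are placeholders):

    // All candidate slices must belong to the same loop nest and map to the
    // same consumer; otherwise the helper returns failure.
    SmallVector<Operation *> candidateSlices = {parallelInsertSlice0,
                                                parallelInsertSlice1};
    FailureOr<scf::SCFFuseConsumerOfSliceResult> fused =
        scf::tileAndFuseConsumerOfSlices(rewriter, candidateSlices, loops);
    if (failed(fused))
      return failure();
    // One original/tiled-and-fused operand pair is reported per slice; all of
    // them belong to the single tiled consumer in `fused->tiledOps`.
    Operation *tiledConsumer = fused->tiledOps.front();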
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 31f54413a5ff0..663c256c848df 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -272,7 +272,7 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
   using PointerUnion::PointerUnion;
 
 public:
-  void dump() const { llvm::errs() << *this << "\n"; }
+  LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this << "\n"; }
 
   MLIRContext *getContext() const {
     PointerUnion pu = *this;
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 0de37338c95e4..0c0fc88aec95a 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
     InterfaceMethod<
       /*desc=*/[{
         Method to generate the tiled implementation of an operation that uses
-        exactly a tile of the given operand.
+        the exact tiles of the given operands.
 
         This method is required to allow operations to be "tiled and fused"
-        with an (already tiled) producer. Given a tile of the producer, this
-        method generates the tile of the consumer that uses exactly this
-        produced tile. In some sense it is the "reverse" of
+        with an (already tiled) producer. Given tiles of the producer, this
+        method generates the tile of the consumer that uses exactly these
+        produced tiles. In some sense it is the "reverse" of
         `generateResultTileValue`.
-        - `operandNumber` is the result of the producer used by the consumer.
-        - `offsets` is the offset of the slice of the producer result used by
-          the tiled implementation of the consumer.
-        - `sizes` is the size of the slice of the producer result used by the
+        - `operandNumbers` is the list of operands whose tiles are "producers".
+        - `allOffsets` are the offsets of the producer slices used by the
+          tiled implementation of the consumer.
+        - `allSizes` are the sizes of the producer slices used by the
           consumer.
-        If it is illegal to fuse with a producer along the given operand for
+        If it is illegal to fuse with a producer along the given operand tiles for
         an operation, the implementation should return a failure.
       }],
       /*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
-      /*methodName=*/"getTiledImplementationFromOperandTile",
+      /*methodName=*/"getTiledImplementationFromOperandTiles",
       /*args=*/(ins
           "::mlir::OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return failure();
@@ -235,16 +235,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
         tile of the operand.
 
         This method is required to allow operations to be "tiled and fused"
-        with an (already tiled) producer. Given a tile of an operand,
-        returns the tile of the iteration space that uses this tile.
-        - `operandNumber` is the result of the producer used by the consumer.
-        - `offsets` is the offset of the slice of the producer result used by
+        with an (already tiled) producer. Given tiles of operands,
+        returns the tile of the iteration space that uses these tiles.
+        - `operandNumbers` is the list of operands whose tiles are "produced"
+          by the producer(s).
+        - `allOffsets` are the offsets of the producer slices used by
           the tiled implementation of the consumer.
-        - `sizes` is the size of the slice of the producer result used by the
+        - `allSizes` are the sizes of the producer slices used by the
           consumer.
-        If it is illegal to fuse with a producer along the given operand for
-        an operation, or if this mapping cannot be computed, the
-        implementation should return a failure.
+        If it is illegal to fuse with the producer slices for an operation,
+        or if this mapping cannot be computed, the implementation should
+        return a failure.
 
         Note that unlike the "tile consumer and fuse producer" case, the
         "tile producer and fuse consumer" requires an additional method to get
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
         transformation. It does not provide guarantees on whether such a
         transformation is profitable.
 
-        For most cases `getTiledImplementationFromOperandTile` could be a
-        implemented using `getIterationDomainTileFromOperandTile` +
+        For most cases `getTiledImplementationFromOperandTiles` could be
+        implemented using `getIterationDomainTileFromOperandTiles` +
         `getTiledImplementation` methods.
       }],
       /*retType=*/"::llvm::LogicalResult",
-      /*methodName=*/"getIterationDomainTileFromOperandTile",
+      /*methodName=*/"getIterationDomainTileFromOperandTiles",
       /*args=*/(ins
           "::mlir::OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
-          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
           "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
           "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
       /*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 19d484a3bb701..513cecef29b61 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -22,8 +22,11 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/Support/Debug.h"
 #include <optional>
 
+#define DEBUG_TYPE "linalg-tiling-interface-impl"
+
 using namespace mlir;
 using namespace mlir::linalg;
 
@@ -148,55 +151,82 @@ struct LinalgOpTilingInterface
   /// Utility to fetch the offsets and sizes when applied as per the indexing
   /// map of the linalg op. This helps in fusing the linalg op as a consumer of
   /// a given slice op.
- void - getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap, - ArrayRef offsets, - ArrayRef sizes, - SmallVectorImpl &mappedOffsets, - SmallVectorImpl &mappedSizes) const { - unsigned numLoops = linalgOp.getNumLoops(); - auto tilingInterfaceOp = cast(linalgOp.getOperation()); - mappedOffsets.resize(numLoops); - mappedSizes.resize(numLoops); - if (!indexingMap.isPermutation()) { - SmallVector iterationDomain = - tilingInterfaceOp.getIterationDomain(b); - for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) { - mappedOffsets[index] = value.offset; - mappedSizes[index] = value.size; + static LogicalResult + getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, + ArrayRef indexingMaps, + ArrayRef> allOffsets, + ArrayRef> allSizes, + SmallVectorImpl &mappedOffsetsVec, + SmallVectorImpl &mappedSizesVec) { + DenseMap mappedOffsets, mappedSizes; + + for (auto [indexingMap, offsets, sizes] : + llvm::zip_equal(indexingMaps, allOffsets, allSizes)) { + for (auto [resultExpr, offset, size] : + llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) { + auto dimExpr = dyn_cast(resultExpr); + if (!dimExpr) + continue; + unsigned position = dimExpr.getPosition(); + auto it = mappedOffsets.find(position); + if (it != mappedOffsets.end()) { + OpFoldResult seenOffset = it->second; + OpFoldResult seenSize = mappedSizes.lookup(position); + if (seenOffset != offset || seenSize != size) { + LLVM_DEBUG({ + llvm::dbgs() << "inconsistent iteration space mapping from " + "offsets/sizes of operands/results"; + }); + return failure(); + } + } else { + mappedOffsets[position] = offset; + mappedSizes[position] = size; + } } } - for (const auto &&[index, value] : - llvm::enumerate(indexingMap.getResults())) { - unsigned dimPosition = cast(value).getPosition(); - mappedOffsets[dimPosition] = offsets[index]; - mappedSizes[dimPosition] = sizes[index]; + + // Aggregate from the given operand offsets and sizes, or default to + // iteration space values. + SmallVector iterationDomain = + cast(linalgOp.getOperation()).getIterationDomain(b); + mappedOffsetsVec.resize(iterationDomain.size()); + mappedSizesVec.resize(iterationDomain.size()); + for (auto [index, domain] : llvm::enumerate(iterationDomain)) { + auto it = mappedOffsets.find(index); + if (it != mappedOffsets.end()) { + mappedOffsetsVec[index] = it->second; + mappedSizesVec[index] = mappedSizes.lookup(index); + continue; + } + mappedOffsetsVec[index] = domain.offset; + mappedSizesVec[index] = domain.size; } + return success(); } /// Method to return the position of the result tile computed by the tiled /// operation. - LogicalResult getIterationDomainTileFromOperandTile( - Operation *op, OpBuilder &b, unsigned operandNumber, - ArrayRef offsets, ArrayRef sizes, + LogicalResult getIterationDomainTileFromOperandTiles( + Operation *op, OpBuilder &b, ArrayRef operandNumbers, + ArrayRef> allOffsets, + ArrayRef> allSizes, SmallVectorImpl &iterDomainOffsets, SmallVectorImpl &iterDomainSizes) const { auto linalgOp = cast(op); - // Check that the indexing map used for the operand is a projected - // permutation. This could be relaxed with a more general approach that can - // map the offsets and sizes from the operand to iteration space tiles - // (filling in full extent for dimensions not used to access the result). 
- AffineMap indexingMap = - linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber)); - if (!indexingMap.isProjectedPermutation()) { - return op->emitError() - << "unhandled get iter domain position when operand is not " - "accessed using a permuted projection"; + std::optional> iterationSpaceOffsets, + iterationSpaceSizes; + SmallVector indexingMaps = + llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) { + OpOperand &opOperand = linalgOp->getOpOperand(operandNumber); + return linalgOp.getMatchingIndexingMap(&opOperand); + }); + if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets, + allSizes, iterDomainOffsets, + iterDomainSizes))) { + return failure(); } - - getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, - iterDomainOffsets, iterDomainSizes); return success(); } @@ -247,8 +277,13 @@ struct LinalgOpTilingInterface "accessed using a permuted projection"); } - getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes, - iterDomainOffsets, iterDomainSizes); + SmallVector allOffsets = llvm::to_vector(offsets); + SmallVector allSizes = llvm::to_vector(sizes); + auto status = + getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets}, + {allSizes}, iterDomainOffsets, iterDomainSizes); + (void)status; + assert(succeeded(status) && "unexpected error in offset calculation"); return success(); } @@ -279,12 +314,13 @@ struct LinalgOpTilingInterface /// Method to generate the tiled implementation of an operation from the tile /// of the operand. - FailureOr getTiledImplementationFromOperandTile( - Operation *op, OpBuilder &b, unsigned operandNumber, - ArrayRef offsets, ArrayRef sizes) const { + FailureOr getTiledImplementationFromOperandTiles( + Operation *op, OpBuilder &b, ArrayRef operandNumbers, + ArrayRef> allOffsets, + ArrayRef> allSizes) const { SmallVector mappedOffsets, mappedSizes; - if (failed(getIterationDomainTileFromOperandTile( - op, b, operandNumber, offsets, sizes, mappedOffsets, + if (failed(getIterationDomainTileFromOperandTiles( + op, b, operandNumbers, allOffsets, allSizes, mappedOffsets, mappedSizes))) { return failure(); } @@ -837,13 +873,20 @@ struct PackOpTiling /// Method to return the position of iteration domain tile computed by the /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and /// `resultSizes` only cover outer dimensions. - LogicalResult getIterationDomainTileFromOperandTile( - Operation *op, OpBuilder &b, unsigned operandNumber, - ArrayRef offsets, ArrayRef sizes, + LogicalResult getIterationDomainTileFromOperandTiles( + Operation *op, OpBuilder &b, ArrayRef operandNumbers, + ArrayRef> allOffsets, + ArrayRef> allSizes, SmallVectorImpl &resultOffsets, SmallVectorImpl &resultSizes) const { - if (operandNumber != 0) + if (operandNumbers.size() != 1 || operandNumbers[0] != 0) { + LLVM_DEBUG( + { llvm::dbgs() << "unsupported operands for consumer fusion"; }); return failure(); + } + + ArrayRef offsets(allOffsets[0]); + ArrayRef sizes(allSizes[0]); auto packOp = cast(op); // It is not trivial to infer dest tile from source tile if `packOp` has @@ -904,11 +947,18 @@ struct PackOpTiling } /// Method to return the tiled implementation of tensor.pack as a consumer. 
- FailureOr getTiledImplementationFromOperandTile( - Operation *op, OpBuilder &b, unsigned operandNumber, - ArrayRef offsets, ArrayRef sizes) const { - if (operandNumber != 0) + FailureOr getTiledImplementationFromOperandTiles( + Operation *op, OpBuilder &b, ArrayRef operandNumbers, + ArrayRef> allOffsets, + ArrayRef> allSizes) const { + if (operandNumbers.size() != 1 || operandNumbers[0] != 0) { + LLVM_DEBUG( + { llvm ::dbgs() << "unhandled operands for consumer fusion"; }); return failure(); + } + + ArrayRef offsets(allOffsets[0]); + ArrayRef sizes(allSizes[0]); auto packOp = cast(op); Location loc = packOp.getLoc(); @@ -923,8 +973,8 @@ struct PackOpTiling tiledOperands.push_back(sourceSlice); SmallVector outerDimOffsets, outerDimSizes; - if (failed(getIterationDomainTileFromOperandTile( - op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets, + if (failed(getIterationDomainTileFromOperandTiles( + op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets, outerDimSizes))) return failure(); @@ -1182,12 +1232,21 @@ struct UnPackOpTiling /// Method to return the position of iteration domain tile computed by the /// tiled operation. - LogicalResult getIterationDomainTileFromOperandTile( - Operation *op, OpBuilder &b, unsigned operandNumber, - ArrayRef offsets, ArrayRef sizes, + LogicalResult getIterationDomainTileFromOperandTiles( + Operation *op, OpBuilder &b, ArrayRef operandNumbers, + ArrayRef> allOffsets, + ArrayRef> allSizes, SmallVectorImpl &resultOffsets, SmallVectorImpl &resultSizes) const { + if (operandNumbers.size() != 1) { + LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands"; }); + return failure(); + } auto unPackOp = cast(op); + unsigned operandNumber = operandNumbers[0]; + ArrayRef offsets(allOffsets[0]); + ArrayRef sizes(allSizes[0]); + // If the operand tile is the dest, then no adjustment is needed. if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) { resultOffsets = llvm::to_vector(offsets); @@ -1241,10 +1300,18 @@ struct UnPackOpTiling } /// Method to return the tiled implementation of tensor.unpack as a consumer. - FailureOr getTiledImplementationFromOperandTile( - Operation *op, OpBuilder &b, unsigned operandNumber, - ArrayRef offsets, ArrayRef sizes) const { + FailureOr getTiledImplementationFromOperandTiles( + Operation *op, OpBuilder &b, ArrayRef operandNumbers, + ArrayRef> allOffsets, + ArrayRef> allSizes) const { + if (operandNumbers.size() != 1 || operandNumbers[0] != 0) { + LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion"; }); + return failure(); + } auto unPackOp = cast(op); + ArrayRef offsets(allOffsets[0]); + ArrayRef sizes(allSizes[0]); + // tensor.unpack op is fusible (as a consumer) only if inner dims are not // tiled. int64_t numTiles = unPackOp.getInnerDimsPos().size(); @@ -1259,8 +1326,8 @@ struct UnPackOpTiling // Fetch offset/size for creating the slice of the dest operand of // unpack op. 
SmallVector outputOffsets, outputSizes; - if (failed(getIterationDomainTileFromOperandTile( - op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets, + if (failed(getIterationDomainTileFromOperandTiles( + op, b, operandNumbers, allOffsets, allSizes, outputOffsets, outputSizes))) return failure(); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index ddcae8481a5b4..995120ad8680e 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -2047,53 +2047,119 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, /// A utility to fetch an untiled consumer of /// tensor.insert_slice/tensor.parallel_insert_slice. -static FailureOr -getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp, - MutableArrayRef loops) { +static FailureOr> getUntiledConsumerOperandsFromSlices( + RewriterBase &rewriter, ArrayRef sliceOps, + MutableArrayRef loops) { assert(!loops.empty() && "unexpected empty loops"); - if (auto insertSlice = dyn_cast(sliceOp)) { - return getUntiledConsumerFromSlice(rewriter, insertSlice, loops); - } else if (auto parallelInsertSlice = - dyn_cast(sliceOp)) { - return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops); - } else { - return failure(); + assert(!sliceOps.empty() && "unexpected empty list of candidate slices"); + SmallVector fusedOperands; + for (auto sliceOp : sliceOps) { + FailureOr fusedOperand = + TypeSwitch>(sliceOp) + .Case( + [&](auto op) { + return getUntiledConsumerFromSlice(rewriter, op, loops); + }) + .Default([&](Operation *op) { + return rewriter.notifyMatchFailure(op, "unhandled slice type"); + }); + if (failed(fusedOperand)) { + return failure(); + } + if (!fusedOperands.empty() && + fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) { + return rewriter.notifyMatchFailure( + fusedOperand.value()->getOwner(), + "all candidate slices must be to the same consumer"); + } + fusedOperands.push_back(fusedOperand.value()); } + return fusedOperands; +} + +template +static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter, + InsertSliceOpTy sliceOp); + +template <> +tensor::InsertSliceOp +cloneAsInsertSlice(RewriterBase &rewriter, + tensor::InsertSliceOp insertSliceOp) { + return cast( + rewriter.clone(*insertSliceOp.getOperation())); +} + +template <> +tensor::InsertSliceOp cloneAsInsertSlice( + RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) { + return rewriter.create( + insertSliceOp->getLoc(), insertSliceOp.getSource(), + insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(), + insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); +} + +static SmallVector +cloneAsInsertSlices(RewriterBase &rewriter, + ArrayRef candidateSlices) { + assert(!candidateSlices.empty() && + "unexpected empty list of slices to clone"); + SmallVector clonedSlices; + for (auto sliceOp : candidateSlices) { + TypeSwitch(sliceOp) + .Case( + [&](auto op) { + auto clonedOp = cloneAsInsertSlice(rewriter, op); + clonedSlices.push_back(clonedOp); + }) + .Default([&](Operation *op) { + // Assert here assuming this has already been checked. + assert(0 && "unexpected slice type while cloning as insert slice"); + }); + } + return clonedSlices; } /// Implementation of fusing consumer of a single slice by computing the /// slice of the consumer in-place for scf loop. 
FailureOr -mlir::scf::tileAndFuseConsumerOfSlice( - RewriterBase &rewriter, Operation *candidateSliceOp, +mlir::scf::tileAndFuseConsumerOfSlices( + RewriterBase &rewriter, ArrayRef candidateSlices, MutableArrayRef loops) { + if (candidateSlices.empty()) { + return rewriter.notifyMatchFailure( + rewriter.getUnknownLoc(), + "no candidate slices provided for consumer fusion"); + } // Return if `loops` is empty, return an error for now. Caller is expected // to handle this case. if (loops.empty()) { - return candidateSliceOp->emitOpError( + return rewriter.notifyMatchFailure( + candidateSlices.front(), "cannot call tile and fuse consumer with an empty loop nest"); } - if (!isa( - candidateSliceOp)) - return failure(); + + if (!(llvm::all_of(candidateSlices, llvm::IsaPred) || + llvm::all_of(candidateSlices, + llvm::IsaPred))) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "candidates slices need to be all `tensor.extract_slice`s or " + "`tensor.parallel_insert_slice`s"); + } // 1. Get the consumer of scf.for for the result yielded by // tensor.insert_slice/parallel_insert_slice. - FailureOr maybeConsumerOpOperand = - getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops); - if (failed(maybeConsumerOpOperand)) { - return rewriter.notifyMatchFailure(candidateSliceOp, - "could not fetch consumer to fuse"); - } - OpOperand *consumerOpOperand = *maybeConsumerOpOperand; - Operation *consumerOp = consumerOpOperand->getOwner(); - unsigned operandNumber = consumerOpOperand->getOperandNumber(); - unsigned resultNumber = 0; - if (auto producerResult = dyn_cast(consumerOpOperand->get())) { - resultNumber = producerResult.getResultNumber(); - } else { - return rewriter.notifyMatchFailure( - consumerOp, "consumer op's operand doesn't seem to be an OpResult"); + SmallVector consumerOpOperands; + Operation *consumerOp; + { + FailureOr> maybeConsumerOpOperand = + getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); + if (failed(maybeConsumerOpOperand)) { + return rewriter.notifyMatchFailure(candidateSlices.front(), + "could not fetch consumer to fuse"); + } + std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); + consumerOp = consumerOpOperands.front()->getOwner(); } LoopLikeOpInterface outerMostLoop = loops.front(); @@ -2113,16 +2179,14 @@ mlir::scf::tileAndFuseConsumerOfSlice( if (!dstOp) return rewriter.notifyMatchFailure(consumerOp, "consumer op is not DPS operation"); - SmallVector dpsInits = - llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; }); - if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) { + if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) { + return dstOp.isDpsInit(opOperand); + })) { return rewriter.notifyMatchFailure( consumerOp, "consumer op taking the result of scf.for as init is not supported"); } - SmallVector newInits = dpsInits; - - Location loc = outerMostLoop->getLoc(); + SmallVector newInits = llvm::to_vector(dstOp.getDpsInits()); // 3. Move the whole loop structure right before firstUserOfLoop, the // dominance should be already ensured by `checkAssumptionForLoop`. @@ -2137,43 +2201,52 @@ mlir::scf::tileAndFuseConsumerOfSlice( // tensor.insert_slice. In the scf.for case this is a clone of the // candidateSliceOp whereas in the scf.forall case this is created from the // operands of tensor.parallel_insert_slice. 
- tensor::InsertSliceOp clonedInsertSliceOp; if (auto sliceOp = - dyn_cast(candidateSliceOp)) { + dyn_cast(candidateSlices.front())) { auto newForallOp = cast(innerMostLoop.getOperation()); rewriter.setInsertionPoint(newForallOp.getTerminator()); - clonedInsertSliceOp = rewriter.create( - loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(), - sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); } else { - rewriter.setInsertionPoint(candidateSliceOp); - clonedInsertSliceOp = - cast(rewriter.clone(*candidateSliceOp)); + rewriter.setInsertionPoint(candidateSlices.front()); } + // 5.a. Clone all the candidate slices as equivalent insert slice ops. + SmallVector clonedInsertSlices = + cloneAsInsertSlices(rewriter, candidateSlices); - // 5.a. Clone consumer op. + // 5.b. Clone consumer op. auto clonedConsumerOp = cast(rewriter.clone(*consumerOp)); + SmallVector operandNumbers = + llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) { + return opOperand->getOperandNumber(); + }); + SmallVector clonedOpFusedOperandsList = + llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) { + return &clonedConsumerOp->getOpOperand(operandNum); + }); - // 5.b. Replace all uses of the loop result with the result of the cloned + // 5.c. Replace all uses of the loop result with the result of the cloned // tensor.insert_slice. - OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber); rewriter.modifyOpInPlace(clonedConsumerOp, [&]() { - operandToReplace.set(clonedInsertSliceOp.getResult()); + for (auto [operandToReplace, clonedSliceOp] : + llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) { + operandToReplace->set(clonedSliceOp.getResult()); + } }); // 6. Perform tiling of the cloned consumer and replace the operand at // `operandNumber` with the source of the cloned tensor.insert_slice op. - auto ossSliceOp = - cast(clonedInsertSliceOp.getOperation()); FailureOr tileAndFuseResult = - tensor::replaceInsertSliceWithTiledConsumer( - rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber)); + tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices, + clonedOpFusedOperandsList); if (failed(tileAndFuseResult)) { return failure(); } + auto tiledConsumerOp = cast(tileAndFuseResult->tiledOps[0]); - rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber), - clonedInsertSliceOp.getSource()); + for (auto [operandNum, clonedSliceOp] : + llvm::zip_equal(operandNumbers, clonedInsertSlices)) { + rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum), + clonedSliceOp.getSource()); + } // 7. Reconstruct [nested] loop with new inits. YieldTiledValuesFn newYieldValuesFn = @@ -2185,14 +2258,20 @@ mlir::scf::tileAndFuseConsumerOfSlice( // 8. Set inner insertPoint right before tiled consumer op. innerRewriter.setInsertionPoint(tiledConsumerOp); - SmallVector offsets = ossSliceOp.getMixedOffsets(); - SmallVector sizes = ossSliceOp.getMixedSizes(); - SmallVector strides = ossSliceOp.getMixedStrides(); + SmallVector> allOffsets, allSizes; + for (auto candidateSliceOp : clonedInsertSlices) { + SmallVector offsets = candidateSliceOp.getMixedOffsets(); + SmallVector sizes = candidateSliceOp.getMixedSizes(); + SmallVector strides = candidateSliceOp.getMixedStrides(); - // 9. Check all insert stride is 1. - if (!llvm::all_of(strides, isOneInteger)) { - return rewriter.notifyMatchFailure( - candidateSliceOp, "containingOp's result yield with stride"); + // 9. Check all insert stride is 1. 
+ if (!llvm::all_of(strides, isOneInteger)) { + return rewriter.notifyMatchFailure( + candidateSliceOp, "containingOp's result yield with stride"); + } + + allOffsets.emplace_back(std::move(offsets)); + allSizes.emplace_back(std::move(sizes)); } // 10. Try to get iter domain position from input position. Use @@ -2202,8 +2281,8 @@ mlir::scf::tileAndFuseConsumerOfSlice( // tiledConsumerOp could lead to some chained unnecessary extra index // computation. SmallVector iterDomainOffsets, iterDomainSizes; - if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile( - rewriter, operandNumber, offsets, sizes, iterDomainOffsets, + if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles( + rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets, iterDomainSizes))) { return rewriter.notifyMatchFailure( clonedConsumerOp, @@ -2279,10 +2358,13 @@ mlir::scf::tileAndFuseConsumerOfSlice( // 16. Need to erase the old scf loop and the cloned consumer op. rewriter.eraseOp(clonedConsumerOp); + SmallVector tiledAndFusedOpOperands = + llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) { + return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum); + }); return scf::SCFFuseConsumerOfSliceResult{ - consumerOpOperand, - &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)), - tileAndFuseResult->tiledOps}; + std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands), + std::move(tileAndFuseResult->tiledOps)}; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp index 6f33f9b55ceb6..4392a2c0eb839 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp @@ -17,6 +17,9 @@ #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Interfaces/TilingInterface.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tensor-swap-slices" using namespace mlir; @@ -39,21 +42,55 @@ FailureOr tensor::replaceExtractSliceWithTiledProducer( return *tiledResult; } -FailureOr tensor::replaceInsertSliceWithTiledConsumer( - OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, - OpOperand &consumer) { - auto consumerOp = dyn_cast(consumer.getOwner()); +FailureOr tensor::replaceInsertSlicesWithTiledConsumer( + OpBuilder &builder, ArrayRef sliceOps, + ArrayRef consumerOperands) { + if (sliceOps.empty()) { + LLVM_DEBUG( + { llvm::dbgs() << "expected candidate slices list to be non-empty"; }); + return failure(); + } + if (sliceOps.size() != consumerOperands.size()) { + LLVM_DEBUG({ + llvm::dbgs() + << "expected as many operands as the number of slices passed"; + }); + return failure(); + } + auto consumerOp = + dyn_cast(consumerOperands.front()->getOwner()); if (!consumerOp) return failure(); + for (auto opOperand : consumerOperands.drop_front()) { + if (opOperand->getOwner() != consumerOp) { + LLVM_DEBUG({ + llvm::dbgs() + << "expected all consumer operands to be from the same operation"; + }); + return failure(); + } + } - // `TilingInterface` currently only supports strides being 1. 
- if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) - return failure(); + auto consumerOperandNums = llvm::map_to_vector( + consumerOperands, [](OpOperand *opOperand) -> unsigned { + return opOperand->getOperandNumber(); + }); + SmallVector> allOffsets; + SmallVector> allSizes; + for (auto sliceOp : sliceOps) { + + // `TilingInterface` currently only supports strides being 1. + if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) + return failure(); + SmallVector offsets = sliceOp.getMixedOffsets(); + SmallVector sizes = sliceOp.getMixedSizes(); + allOffsets.emplace_back(std::move(offsets)); + allSizes.emplace_back(std::move(sizes)); + } FailureOr tiledResult = - consumerOp.getTiledImplementationFromOperandTile( - builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(), - sliceOp.getMixedSizes()); + consumerOp.getTiledImplementationFromOperandTiles( + builder, consumerOperandNums, allOffsets, allSizes); if (failed(tiledResult)) return failure(); diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir index 572a2ae70e0a4..5bdb5073ee865 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -653,6 +653,7 @@ module { %5 = affine.min #map2(%i)[%d0, %idx] %6 = tensor.extract_slice %o[%4] [%5] [1] : tensor to tensor + // CHECK: linalg.generic // CHECK: %[[T1:.*]] = linalg.generic {{.*}} // CHECK: %[[T2:.*]] = linalg.generic {{.*}} %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor to tensor diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index 77e52946b830f..0f69875d596f1 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s +// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s #map = affine_map<(d0) -> (d0)> module { @@ -620,3 +620,294 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +func.func @multi_slice_fusion1(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor, %arg3 : index) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg0, %c0 : tensor + %dim1 = tensor.dim %arg0, %c1 : tensor + %loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor, tensor) { + %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3] + %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor to tensor + %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor to tensor + %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor to tensor + %generic:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor) outs(%init0_slice, %init1_slice : tensor, tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.mulf %b0, %b1 : f32 + %1 = arith.addf %b0, %b2 : f32 + linalg.yield %0, %1 : f32, f32 + } -> (tensor, tensor) + scf.forall.in_parallel { + tensor.parallel_insert_slice %generic#0 into 
%init0[%iv0] [%tilesize] [1] : tensor into tensor + tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor into tensor + } + } + %empty = tensor.empty(%dim0) : tensor + %result = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%loop#0, %loop#1 : tensor, tensor) outs(%empty : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor + return %result : tensor +} +// CHECK-LABEL: func @multi_slice_fusion1( +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK: %[[C0:.+]] = arith.constant 0 +// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]]) +// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) = +// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]]) +// CHECK: %[[TILESIZE:.+]] = affine.min +// CHECK-DAG: %[[GENERIC:.+]]:2 = linalg.generic +// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]] +// CHECK: %[[FUSED:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC]]#0, %[[GENERIC]]#1 : +// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]] +// CHECK: return %[[RESULT]]#2 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// Check that when the given operand tiles are inconsistent, tiling fails. 
+ +func.func @multi_slice_fusion2(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor, %arg3 : index) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg0, %c0 : tensor + %dim1 = tensor.dim %arg0, %c1 : tensor + %loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor, tensor) { + %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3] + %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor to tensor + %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor to tensor + %generic0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor) outs(%init0_slice : tensor) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.mulf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor + %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor to tensor + %generic1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor) outs(%init1_slice: tensor) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0: f32 + } -> tensor + scf.forall.in_parallel { + tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor into tensor + tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor into tensor + } + } + %empty = tensor.empty(%dim0) : tensor + %result = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%loop#0, %loop#1 : tensor, tensor) outs(%empty : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor + return %result : tensor +} +// CHECK-LABEL: func @multi_slice_fusion2( +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK: %[[C0:.+]] = arith.constant 0 +// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]]) +// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) = +// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]]) +// CHECK: %[[TILESIZE:.+]] = affine.min +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK: %[[GENERIC1:.+]] = linalg.generic +// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]] +// CHECK: %[[FUSED:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] : +// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]] +// CHECK: return %[[RESULT]]#2 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func 
@multi_slice_fusion_with_broadcast(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor, + %arg3 : index, %arg4 : index) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %dim0 = tensor.dim %arg0, %c0 : tensor + %dim1 = tensor.dim %arg0, %c1 : tensor + %dim2 = tensor.dim %arg0, %c2 : tensor + %loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4) + shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor, tensor) { + %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3] + %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4] + %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1] + : tensor to tensor + %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor to tensor + %generic0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0_slice : tensor) outs(%init0_slice : tensor) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.mulf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor + %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor to tensor + %generic1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%generic0 : tensor) outs(%init1_slice: tensor) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0: f32 + } -> tensor + scf.forall.in_parallel { + tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor into tensor + tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor into tensor + } + } + %empty = tensor.empty(%dim0, %dim1) : tensor + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%loop#0, %loop#1 : tensor, tensor) outs(%empty : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor + return %result : tensor +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-LABEL: func @multi_slice_fusion_with_broadcast( +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 +// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) +// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = +// CHECK-SAME: , 
%[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]]) +// CHECK-DAG: %[[TILESIZE0:.+]] = affine.min {{.+}}(%[[IV0]]) +// CHECK-DAG: %[[TILESIZE1:.+]] = affine.min {{.+}}(%[[IV1]]) +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK: %[[GENERIC1:.+]] = linalg.generic +// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]] +// CHECK: %[[FUSED:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] : +// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]] +// CHECK: return %[[RESULT]]#2 + +// ----- + +func.func @multi_slice_fusion_invalid(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor, + %arg3 : index, %arg4 : index) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %dim0 = tensor.dim %arg0, %c0 : tensor + %dim1 = tensor.dim %arg0, %c1 : tensor + %dim2 = tensor.dim %arg0, %c2 : tensor + %loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4) + shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor, tensor) { + %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3] + %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4] + %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1] + : tensor to tensor + %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor to tensor + %generic0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0_slice : tensor) outs(%init0_slice : tensor) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.mulf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor + %init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor to tensor + %generic1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0_slice : tensor) outs(%init1_slice: tensor) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0: f32 + } -> tensor + scf.forall.in_parallel { + // expected-error @below {{failed to fuse consumer of slice}} + tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor into tensor + tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor into tensor + } + } + %empty = tensor.empty(%dim0, %dim1) : tensor + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%loop#0, %loop#1 : tensor, tensor) outs(%empty : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor + return %result : tensor +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield0, %yield1 = 
transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 9971f0cde4ed2..ee3eb9522db7e 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -21,6 +21,9 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/TilingInterface.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "test-tiling-interface" #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.h.inc" @@ -168,29 +171,30 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, /// Apply fusing of consumer transformation to all payload ops and store both /// the original consumer operation as well as the fused consumer operation. -template static LogicalResult applyFuseConsumer( - RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, - MutableArrayRef loops, uint32_t numConsumerToFuse, - TransformResults &transformResults) { + RewriterBase &rewriter, Operation *transformOp, + ArrayRef slices, MutableArrayRef loops, + uint32_t numConsumerToFuse, TransformResults &transformResults) { SmallVector originalConsumerOps; SmallVector fusedConsumerOps; - for (Operation *target : payloadOps) { - rewriter.setInsertionPoint(target); + rewriter.setInsertionPoint(slices.front()); - while (numConsumerToFuse--) { - FailureOr fuseConsumerResults = - scf::tileAndFuseConsumerOfSlice(rewriter, target, loops); + while (numConsumerToFuse--) { + FailureOr fuseConsumerResults = + scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops); - if (failed(fuseConsumerResults)) - return failure(); + if (failed(fuseConsumerResults)) + return slices.front()->emitOpError("failed to fuse consumer of slice"); - // Report back the relevant handles to the transform op. - originalConsumerOps.push_back( - fuseConsumerResults->origConsumerOperand->getOwner()); - fusedConsumerOps.push_back( - fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner()); + // Report back the relevant handles to the transform op. 
+ for (OpOperand *origConsumerOperand : + fuseConsumerResults->origConsumerOperands) { + originalConsumerOps.push_back(origConsumerOperand->getOwner()); + } + for (OpOperand *tiledAndFusedConsumerOperand : + fuseConsumerResults->tiledAndFusedConsumerOperands) { + fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner()); } } @@ -203,6 +207,12 @@ DiagnosedSilenceableFailure transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { + SmallVector slices; + for (auto op : getTargets()) { + auto sliceOp = *state.getPayloadOps(op).begin(); + slices.push_back(sliceOp); + } + SmallVector loops; for (auto op : llvm::reverse(getLoops())) { auto loopLikeOp = @@ -212,16 +222,16 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, } loops.push_back(loopLikeOp); } - LogicalResult result = applyFuseConsumer( - rewriter, getOperation(), state.getPayloadOps(getTarget()), loops, - getNumConsumerToFuse(), transformResults); + LogicalResult result = + applyFuseConsumer(rewriter, getOperation(), slices, loops, + getNumConsumerToFuse(), transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } void transform::TestFuseConsumerOp::getEffects( SmallVectorImpl &effects) { - consumesHandle(getTargetMutable(), effects); + consumesHandle(getTargetsMutable(), effects); consumesHandle(getLoopsMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index 98f7145c99cb1..3c09082e192ea 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -50,7 +50,8 @@ def TestFuseAndYieldOp : Op, + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, ReportTrackingListenerFailuresOpTrait]> { let description = [{ @@ -59,14 +60,14 @@ def TestFuseConsumerOp : Op:$targets, Variadic:$loops, DefaultValuedAttr:$num_consumer_to_fuse); let results = (outs TransformHandleTypeInterface:$consumer, TransformHandleTypeInterface:$fused_consumer); let assemblyFormat = [{ - $target `in` `(` $loops `)` + $targets `in` `(` $loops `)` (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? attr-dict `:` functional-type(operands, results) }];
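As the updated TilingInterface documentation above suggests, most implementations of the renamed hooks can derive the tiled implementation from the iteration-domain tile. A sketch of that pattern follows, mirroring the LinalgOp external model in this patch; it is not a prescribed implementation, and an op-specific model may differ:

    FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
        Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
        ArrayRef<SmallVector<OpFoldResult>> allOffsets,
        ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
      // Map the per-operand tiles onto one consistent iteration-space tile.
      SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
      if (failed(getIterationDomainTileFromOperandTiles(
              op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
              mappedSizes)))
        return failure();
      // Reuse the regular tiling entry point on that iteration-space tile.
      return cast<TilingInterface>(op).getTiledImplementation(b, mappedOffsets,
                                                              mappedSizes);
    }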