@@ -953,49 +953,122 @@ mlir::scf::tileAndFuseProducerOfSlice(
953953LogicalResult mlir::scf::yieldReplacementForFusedProducer (
954954 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
955955 scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
956- MutableArrayRef<LoopLikeOpInterface> loops) {
956+ MutableArrayRef<LoopLikeOpInterface> loops,
957+ ArrayRef<unsigned > yieldResultNumber) {
957958 if (loops.empty ())
958959 return success ();
959960
960- OpResult fusableProducer = fusedProducerInfo.origProducer ;
961- Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer ;
962- FailureOr<Value> initValue = tensor::getOrCreateDestination (
963- rewriter, fusableProducer.getOwner ()->getLoc (), fusableProducer);
964- if (succeeded (initValue)) {
965-
966- YieldTiledValuesFn newYieldValuesFn =
967- [&](RewriterBase &innerRewriter, Location loc, ValueRange /* ivs*/ ,
968- ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
969- SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
970- SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
971- -> LogicalResult {
972- OpBuilder::InsertionGuard g (innerRewriter);
973- if (auto tiledDestStyleOp =
974- tiledAndFusedProducer
975- .getDefiningOp <DestinationStyleOpInterface>()) {
976- rewriter.setInsertionPoint (tiledDestStyleOp);
977- Value newRegionArg = newRegionIterArgs.back ();
961+ Operation *originalOwner = fusedProducerInfo.origProducer .getOwner (),
962+ *tiledOwner = fusedProducerInfo.tiledOps [0 ];
963+
964+ Location loc = originalOwner->getLoc ();
965+ // a. Collect one init (destination) value per result to yield; an empty `yieldResultNumber` selects every result.
966+ SmallVector<unsigned > initNumberList =
967+ yieldResultNumber.empty () ? llvm::to_vector (llvm::seq<unsigned >(
968+ 0 , originalOwner->getNumResults ()))
969+ : llvm::to_vector (yieldResultNumber);
970+ SmallVector<Value> initValueList;
971+ for (const auto &resultNumber : initNumberList) {
972+ FailureOr<Value> initValue = tensor::getOrCreateDestination (
973+ rewriter, loc, originalOwner->getResult (resultNumber));
974+ if (succeeded (initValue)) {
975+ initValueList.push_back (initValue.value ());
976+ } else {
977+ return failure ();
978+ }
979+ }
980+
981+ YieldTiledValuesFn newYieldValuesFn =
982+ [&](RewriterBase &innerRewriter, Location loc, ValueRange /* ivs*/ ,
983+ ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
984+ SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
985+ SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
986+ OpBuilder::InsertionGuard g (innerRewriter);
987+
988+ // Get the offsets and sizes of the tile extracted by `sliceOp`.
989+ SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets (),
990+ sliceSizes = sliceOp.getMixedSizes ();
991+
992+ // Only unit strides are supported: fail if any stride of `sliceOp` is not 1.
993+ if (llvm::any_of (sliceOp.getMixedStrides (), [](OpFoldResult ofr) {
994+ return !isConstantIntValue (ofr, 1 );
995+ }))
996+ return failure ();
997+
998+ unsigned sliceResultNumber =
999+ fusedProducerInfo.origProducer .getResultNumber ();
1000+
1001+ auto tilableOp = cast<TilingInterface>(originalOwner);
1002+ // b. Get the iteration domain offsets and sizes from the `sliceOp` tile.
1003+ SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1004+ // Only queried for multi-result ops; single-result ops (e.g. tensor.pack/unpack/pad) are skipped.
1005+ if (tilableOp->getNumResults () > 1 &&
1006+ failed (tilableOp.getIterationDomainTileFromResultTile (
1007+ rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1008+ iterDomainOffset, iterDomainSizes))) {
1009+ // In theory, it is unnecessary to raise an error here: although it fails
1010+ // to reconstruct the result tensor, it should not break the current
1011+ // fusion anyway. The reason why we must return failure currently is that
1012+ // the callback function `newYieldValuesFn` is called after the new init
1013+ // operand(s) have already been appended. It will take more refactoring to
1014+ // make sure the init operands are added consistently in the future. For
1015+ // more details, please refer to:
1016+ // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1017+ return failure ();
1018+ }
1019+
1020+ // c. Compute the offsets and sizes of each yielded result; the sliced
1021+ // result reuses the `sliceOp` tile, the others use the iteration domain tile.
1022+ SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1023+ for (const auto &resultNumber : initNumberList) {
1024+ if (resultNumber == sliceResultNumber) {
1025+ offsetList.push_back (sliceOffset);
1026+ sizesList.push_back (sliceSizes);
1027+ } else {
1028+ assert (!iterDomainOffset.empty () && !iterDomainSizes.empty ());
1029+ // Infer the result tile from the iteration domain tile.
1030+ SmallVector<OpFoldResult> offset, sizes;
1031+ if (failed (tilableOp.getResultTilePosition (
1032+ rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1033+ offset, sizes))) {
1034+ return failure ();
1035+ }
1036+ offsetList.push_back (offset);
1037+ sizesList.push_back (sizes);
1038+ }
1039+ }
1040+
1041+ // d. For a DPS tiled op, slice each new `iter_arg` and set it as the matching init (assumes one `iter_arg` per entry of `initNumberList` - TODO confirm).
1042+ if (auto tiledDestStyleOp =
1043+ dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1044+ rewriter.setInsertionPoint (tiledDestStyleOp);
1045+ for (const auto &&[index, newRegionArg] :
1046+ llvm::enumerate (newRegionIterArgs)) {
9781047 auto destSlice = rewriter.create <tensor::ExtractSliceOp>(
979- sliceOp.getLoc (), newRegionArg, sliceOp.getMixedOffsets (),
980- sliceOp.getMixedSizes (), sliceOp.getMixedStrides ());
981- unsigned resultNumber = fusableProducer.getResultNumber ();
1048+ loc, newRegionArg, offsetList[index], sizesList[index],
1049+ SmallVector<OpFoldResult>(offsetList[index].size (),
1050+ rewriter.getIndexAttr (1 )));
1051+ unsigned resultNumber = initNumberList[index];
9821052 rewriter.modifyOpInPlace (tiledDestStyleOp, [&]() {
9831053 tiledDestStyleOp.getDpsInitsMutable ()[resultNumber].set (destSlice);
9841054 });
9851055 }
986- Block *block = rewriter.getInsertionPoint ()->getBlock ();
987- rewriter.setInsertionPoint (block->getTerminator ());
988- tiledResult.push_back (fusedProducerInfo.tiledAndFusedProducer );
989- tiledOffset.emplace_back (sliceOp.getMixedOffsets ());
990- tiledSizes.emplace_back (sliceOp.getMixedSizes ());
991- return success ();
992- };
1056+ }
9931057
994- return addInitOperandsToLoopNest (rewriter, loops,
995- SmallVector<Value>{initValue.value ()},
996- newYieldValuesFn);
997- }
998- return success ();
1058+ // e. Record the tiled results, offsets and sizes for later `insert_slice`
1059+ // creation by the caller.
1060+ Block *block = rewriter.getInsertionPoint ()->getBlock ();
1061+ rewriter.setInsertionPoint (block->getTerminator ());
1062+ for (const auto &&[index, resultNumber] : llvm::enumerate (initNumberList)) {
1063+ tiledResult.push_back (tiledOwner->getResult (resultNumber));
1064+ tiledOffset.emplace_back (offsetList[index]);
1065+ tiledSizes.emplace_back (sizesList[index]);
1066+ }
1067+ return success ();
1068+ };
1069+
1070+ return addInitOperandsToLoopNest (rewriter, loops, initValueList,
1071+ newYieldValuesFn);
9991072}
10001073
10011074/// Implementation of tile consumer and fuse producer greedily.
@@ -1085,14 +1158,22 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
10851158 continue ;
10861159
10871160 if (yieldReplacement) {
1161+ // Reconstruct and yield all results of fusableProducerOp by default. The
1162+ // caller can specify which ones to yield via the optional argument
1163+ // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
1164+ Operation *fusableProducerOp = fusableProducer.getOwner ();
10881165 if (failed (yieldReplacementForFusedProducer (
10891166 rewriter, candidateSliceOp, fusedResult.value (), loops))) {
10901167 return rewriter.notifyMatchFailure (
1091- fusableProducer.getOwner (), " failed to replacement value for this "
1092- " oepration from within the tiled loop" );
1168+ fusableProducerOp, " failed to replacement value for this "
1169+ " operation from within the tiled loop" );
1170+ }
1171+ for (auto [index, result] :
1172+ llvm::enumerate (fusableProducerOp->getResults ())) {
1173+ origValToResultNumber[result] = loops.front ()->getNumResults () -
1174+ fusableProducerOp->getNumResults () +
1175+ index;
10931176 }
1094- origValToResultNumber[fusableProducer] =
1095- loops.front ()->getNumResults () - 1 ;
10961177 }
10971178
10981179 if (Operation *tiledAndFusedOp =
0 commit comments