diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index e2feb10b31454..02e58141bdc30 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" @@ -1580,33 +1582,163 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { return success(); } -/// Fetches the OpOperand of the only user (and use) of the value `val` which -/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns -/// failure otherwise. -static FailureOr getConsumerFromUses(Value val, - Block *containingOpBlock) { - // Check that the value has exactly one use which isn't a scf.yield or a - // tensor.parallel_insert_slice op. - OpOperand *operand = nullptr; - for (OpOperand &opOperand : val.getUses()) { - Operation *consumerOp = opOperand.getOwner(); - if (isa(consumerOp)) - continue; - if (operand) - return failure(); - // TODO: We have to init result of consumer before scf.for, use - // DestinationStyleOpInterface to get result shape from init for now. - // Add support for other op such as op has InferTypeOpInterface. - if (!isa(consumerOp) || - !isa(consumerOp)) +/// An utility to get the first user of the given loopOp. If any of user stay in +/// different block of loopOp, return failure. +static FailureOr getFirstUserOfLoop(Operation *loopOp) { + if (!isa(loopOp)) + return failure(); + Operation *firstUserOfLoop = nullptr; + for (Operation *userOp : loopOp->getUsers()) { + // `ParallelInsertSlice` located inside `InParallelOp` has no same parent + // block with any other types of operation. Thus, just redirecting to its + // parent `InParallelOp`. E.g. + // + // ``` + // %1 = scf.for { + // ... + // } + // %2 = consumerOp ins(%1, ...) + // scf.forall.in_parallel { + // tensor.parallel_insert_slice %1 + // } + // ``` + // where `InParallelOp` but not `ParallelInsertSlice` stays in the same + // same block with `consumerOp`. + if (isa(userOp)) + userOp = userOp->getParentOfType(); + + if (loopOp->getBlock() != userOp->getBlock()) return failure(); - if (containingOpBlock != consumerOp->getBlock()) + + if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop)) + firstUserOfLoop = userOp; + } + return firstUserOfLoop; +} + +/// This utility currently checks whether the first userOp of loop is NOT before +/// the last defineOp of consumer operand. Because that we need to move the +/// whole loop structure right before the `firstUserOfLoop`. This utility thus +/// helps ensuring that no invalid IR is formed, i.e. no backward slice of +/// consumerOp is dominated by the `firstUserOfLoop`. Saying that: +/// +/// ``` +/// %0 = scf.for() { +/// ... +/// } +/// ... +/// %1 = firstUserOfLoop(%0) +/// ... +/// %2 = lastDefOfConsumerOperand +/// ... +/// %3 = consumerOp(%2) +/// ``` +/// +/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would +/// be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a. +/// use-def chain violation: +/// +/// ``` +/// %0:2 = scf.for() { +/// // use before define error +/// %3 = tiledConsumerOp(%2) +/// } +/// %1 = firstUserOfLoop(%0) +/// ... +/// %2 = lastDefOfConsumerOperand +/// ``` +/// +/// @param loopOp: loop operation +/// @param consumerOp: consumer operation +/// @param reorderOperations: the flag controls whether to reorder the backward +/// slice w.r.t. the defineOp of `consumerOp` operands. +/// @return: computed backward slice of consumerOp, but excluding those already +/// dominates `firstUserOfLoop`. +static FailureOr> +checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, + bool reorderOperations) { + FailureOr firstUserOfLoop = getFirstUserOfLoop(loopOp); + if (failed(firstUserOfLoop)) + return failure(); + + BackwardSliceOptions options; + DominanceInfo dominanceInfo; + options.inclusive = true; + options.omitBlockArguments = true; + bool includeLoopOp = false; + options.filter = [&](Operation *op) { + if (op == loopOp) { + includeLoopOp = true; + return false; + } + // Cut off the slice to not include any operation that already dominates + // firstUserOfLoop. + return !dominanceInfo.properlyDominates(op, *firstUserOfLoop); + }; + llvm::SetVector slice; + for (auto operand : consumerOp->getOperands()) { + getBackwardSlice(operand, &slice, options); + } + + if (!slice.empty()) { + // If consumerOp has one producer, which is also the user of loopOp. + // E.g. + // ``` + // %0 = %loopOp + // %1 = consumerOp1 ins(%0) + // %2 = consumerOp2 ins(%0, %1) + // ``` + // We can not fuse consumerOp2 into loopOp due to UD chain, unless + // consumerOp1 has already been fused into loopOp before. + if (includeLoopOp || !reorderOperations) return failure(); - operand = &opOperand; } - if (operand) - return operand; + return slice; +} + +/// Fetches the OpOperand of the first valid user (and use) of the value `val` +/// which implements `TilingInterface` and `DestinationStyleOpInterface`. +/// Returns failure otherwise. +static FailureOr getConsumerFromLoopUses(RewriterBase &rewriter, + Operation *loopOp, + unsigned resultNumber) { + if (!isa(loopOp)) + return failure(); + Value val = loopOp->getResult(resultNumber); + Block *loopBlock = loopOp->getBlock(); + for (OpOperand &opOperand : val.getUses()) { + Operation *consumerOp = opOperand.getOwner(); + // Step 1. Check if the user is tilable. + if (!isa(consumerOp)) { + // TODO: We have to init result of consumer before scf.for, use + // DestinationStyleOpInterface to get result shape from init for now. Add + // support for other op such as op has InferTypeOpInterface. + continue; + } + // Step 2. Check if user stay in the same block. + if (loopBlock != consumerOp->getBlock()) + continue; + // Step 3. Check if user has succeeding user. Otherwise, it usually + // represents already tiled. + if (consumerOp->use_empty()) + continue; + // Step 4. Check assumption for loop with `reorderOperations` enabled. + FailureOr> slice = + checkAssumptionForLoop(loopOp, consumerOp, true); + if (failed(slice)) + continue; + // Step 5. If backward sice is not empty, move them before firstUserOfLoop. + if (!slice->empty()) { + mlir::topologicalSort(*slice); + FailureOr firstUserOfLoop = getFirstUserOfLoop(loopOp); + assert(succeeded(firstUserOfLoop) && "First user of loop is not found"); + for (auto op : *slice) { + rewriter.moveOpBefore(op, *firstUserOfLoop); + } + } + return &opOperand; + } return failure(); } @@ -1659,7 +1791,8 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) { /// 1. tensor.insert_slice has scf.yield as its only user. /// 2. scf.for's corresponding result has only one use. static FailureOr -getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) { +getUntiledConsumerFromSlice(RewriterBase &rewriter, + tensor::InsertSliceOp candidateSliceOp) { if (failed(checkAssumptionForFusingConsumer(candidateSliceOp))) return failure(); Value sliceResult = candidateSliceOp.getResult(); @@ -1672,15 +1805,15 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) { if (!forOp) return failure(); scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front(); - Value resultingValue = topLevelForOp->getResult(resultNumber); - return getConsumerFromUses(resultingValue, topLevelForOp->getBlock()); + return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber); } /// Fetch the first untiled consumer of a scf.forall's result which is yielded /// by a tensor.parallel_insert_slice. static FailureOr -getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) { +getUntiledConsumerFromSlice(RewriterBase &rewriter, + tensor::ParallelInsertSliceOp candidateSliceOp) { // Step 1. Fetch the corresponding output Value sliceDest = candidateSliceOp.getDest(); auto iterArg = dyn_cast(sliceDest); @@ -1693,45 +1826,22 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) { auto forallOp = dyn_cast(containingOp); if (!forallOp) return failure(); - Value resultingValue = - forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)); - - return getConsumerFromUses(resultingValue, containingOp->getBlock()); -} + unsigned resultNumber = + forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)) + .getResultNumber(); -/// This utility currently checks whether the loop either :- -/// 1. Yields exactly one result. -/// 2. Has consumer op as its first user and other users to be in the same -/// containing block as that of consumer op's. Currently we clone the loop op -/// right before the consumer op in order to maintain a valid def-use chain. -/// This utility thus helps ensuring that no invalid IR is formed due to the -/// same. -static LogicalResult checkAssumptionForLoop(Operation *loopOp, - Operation *consumerOp) { - // Check if the loop op yields one result. - if (loopOp->getNumResults() == 1) - return success(); - // Check if the consumerOp is the first user of the loopOp and if other users - // are in the same containing block as that of consumer op's. - Block *parentBlock = consumerOp->getBlock(); - for (Operation *userOp : loopOp->getUsers()) { - if (userOp == consumerOp) - continue; - if (parentBlock != userOp->getBlock() || - !consumerOp->isBeforeInBlock(userOp)) - return failure(); - } - return success(); + return getConsumerFromLoopUses(rewriter, containingOp, resultNumber); } /// A utility to fetch an untiled consumer of /// tensor.insert_slice/tensor.parallel_insert_slice. -static FailureOr getUntiledConsumerFromSlice(Operation *sliceOp) { +static FailureOr +getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) { if (auto insertSlice = dyn_cast(sliceOp)) { - return getUntiledConsumerFromSlice(insertSlice); + return getUntiledConsumerFromSlice(rewriter, insertSlice); } else if (auto parallelInsertSlice = dyn_cast(sliceOp)) { - return getUntiledConsumerFromSlice(parallelInsertSlice); + return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice); } else { return failure(); } @@ -1751,7 +1861,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, // 1. Get the consumer of scf.for for the result yielded by // tensor.insert_slice/parallel_insert_slice. FailureOr maybeConsumerOpOperand = - getUntiledConsumerFromSlice(candidateSliceOp); + getUntiledConsumerFromSlice(rewriter, candidateSliceOp); if (failed(maybeConsumerOpOperand)) { return rewriter.notifyMatchFailure(candidateSliceOp, "could not fetch consumer to fuse"); @@ -1787,11 +1897,11 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, LoopLikeOpInterface outerMostLoop = nestedLoops.front(); - if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) { + // Check assumption for loop with `reorderOperations` disabled. + if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { return rewriter.notifyMatchFailure( - outerMostLoop, - "containing loop op should either yield just one value or " - "have the consumer op as its first user"); + outerMostLoop, "the first user of loop should not dominate any define " + "of consumer operand(s)"); } OpBuilder::InsertionGuard g(rewriter); @@ -1812,9 +1922,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Location loc = outerMostLoop->getLoc(); - // 3. Move the whole loop structure right before consumer Op, the dominance - // should be already ensured by `checkAssumptionForLoop`. - rewriter.moveOpBefore(outerMostLoop, consumerOp); + // 3. Move the whole loop structure right before firstUserOfLoop, the + // dominance should be already ensured by `checkAssumptionForLoop`. + FailureOr firstUserOfLoop = getFirstUserOfLoop(outerMostLoop); + if (failed(firstUserOfLoop)) { + return rewriter.notifyMatchFailure( + outerMostLoop, "could not find the first user of outer most loop"); + } + rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop); // 4. Set insertion point before terminator op of the loop and create a new // tensor.insert_slice. In the scf.for case this is a clone of the diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index f5f703d95e2d5..af836d18e8c02 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -508,3 +508,65 @@ module { // CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 : // CHECK: } // CHECK: return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 : + +// ----- + +module { + func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %cst = arith.constant 0.000000e+00 : f32 + %dest0 = tensor.empty() : tensor<256x256xf32> + %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) { + %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32> + %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> + scf.yield %insert_slice : tensor<256x256xf32> + } + %4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + %5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 2 + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_add_multiple_tilable_consumers( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32> +// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32> +// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]]) +// CHECK-SAME: { +// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add +// CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] : +// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] : +// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp +// CHECK-SAME: ins(%[[TILED_ADD_OUT]] : +// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] : +// CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul +// CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] : +// CHECK-SAME: outs(%[[MUL_OUT_SLICE]] : +// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] : +// CHECK: } +// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 : diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index b6da47977cb4c..5e903e378daf8 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -171,24 +171,27 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, template static LogicalResult applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, - Range &&payloadOps, TransformResults &transformResults) { + Range &&payloadOps, uint32_t numConsumerToFuse, + TransformResults &transformResults) { SmallVector originalConsumerOps; SmallVector fusedConsumerOps; for (Operation *target : payloadOps) { rewriter.setInsertionPoint(target); - FailureOr fuseConsumerResults = - scf::tileAndFuseConsumerOfSlice(rewriter, target); + while (numConsumerToFuse--) { + FailureOr fuseConsumerResults = + scf::tileAndFuseConsumerOfSlice(rewriter, target); - if (failed(fuseConsumerResults)) - return failure(); + if (failed(fuseConsumerResults)) + return failure(); - // Report back the relevant handles to the transform op. - originalConsumerOps.push_back( - fuseConsumerResults->origConsumerOperand->getOwner()); - fusedConsumerOps.push_back( - fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner()); + // Report back the relevant handles to the transform op. + originalConsumerOps.push_back( + fuseConsumerResults->origConsumerOperand->getOwner()); + fusedConsumerOps.push_back( + fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner()); + } } transformResults.set(transformOp->getOpResult(0), originalConsumerOps); @@ -200,9 +203,9 @@ DiagnosedSilenceableFailure transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { - LogicalResult result = - applyFuseConsumer(rewriter, getOperation(), - state.getPayloadOps(getTarget()), transformResults); + LogicalResult result = applyFuseConsumer( + rewriter, getOperation(), state.getPayloadOps(getTarget()), + getNumConsumerToFuse(), transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index d55d746bd6aa9..34b075a5c17f9 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -59,12 +59,14 @@ def TestFuseConsumerOp : Op:$num_consumer_to_fuse); let results = (outs TransformHandleTypeInterface:$consumer, TransformHandleTypeInterface:$fused_consumer); let assemblyFormat = [{ - $target attr-dict `:` functional-type(operands, results) + $target (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? + attr-dict `:` functional-type(operands, results) }]; }