diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index d2cddfe00ac78..33a43ce2ee7bb 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -328,7 +328,8 @@ struct SCFFuseConsumerOfSliceResult { SmallVector tiledOps; }; FailureOr -tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp); +tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp, + MutableArrayRef loops); /// Method to lower an `op` that implements the `TilingInterface` to /// loops/scalars. diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index af87fb7a79d04..8e407cc1b348f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1846,11 +1846,9 @@ static FailureOr getConsumerFromLoopUses(RewriterBase &rewriter, return failure(); } -/// Find the perfectly nested loops outside of given loop(included) sorted -/// from outer to inner. -/// -/// E.g. -/// +/// Check that the loop is perfectly nested. +/// The loops are expected to be ordered from outer most to inner most. +/// For example: /// ``` /// %0 = scf.for() /// %1 = scf.for() @@ -1860,55 +1858,85 @@ static FailureOr getConsumerFromLoopUses(RewriterBase &rewriter, /// yield %2 /// yield %1 /// ``` -/// -/// This function will return three perfectly nested loops: %0 + %1 + %2, when -/// target inner loop is %2. -static SmallVector -getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) { - SmallVector nestLoops = {loop}; - auto outerLoop = dyn_cast(loop->getParentOp()); - - // Check if it is the ForOp that yield the result of inner loop. - auto isForOpYieldResultOfInnerLoop = - [](scf::ForOp outerLoop) -> LogicalResult { - Block *body = outerLoop.getBody(); - if (!llvm::hasSingleElement(body->without_terminator())) - return failure(); - auto yieldOp = cast(body->getTerminator()); - auto innerForOp = dyn_cast(body->front()); - if (!innerForOp) - return failure(); - // All of innerForOp results should be yielded. - return success(innerForOp->getNumResults() == yieldOp->getNumOperands()); - }; +/// Here loops should be [%0, %1]. +static bool +isPerfectlyNestedForLoops(MutableArrayRef loops) { + assert(!loops.empty() && "unexpected empty loop nest"); + if (loops.size() == 1) { + return isa_and_nonnull(loops.front().getOperation()); + } + for (auto [outerLoop, innerLoop] : + llvm::zip_equal(loops.drop_back(), loops.drop_front())) { + auto outerFor = dyn_cast_or_null(outerLoop.getOperation()); + auto innerFor = dyn_cast_or_null(innerLoop.getOperation()); + if (!outerFor || !innerFor) { + return false; + } + auto outerBBArgs = outerFor.getRegionIterArgs(); + auto innerIterArgs = innerFor.getInitArgs(); + if (outerBBArgs.size() != innerIterArgs.size()) { + return false; + } + + for (auto [outerBBArg, innerIterArg] : + llvm::zip_equal(outerBBArgs, innerIterArgs)) { + if (!llvm::hasSingleElement(outerBBArg.getUses()) || + innerIterArg != outerBBArg) { + return false; + } + } - while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) { - nestLoops.push_back(outerLoop); - outerLoop = dyn_cast(outerLoop->getParentOp()); + ValueRange outerYields = + cast(outerFor.getBody()->getTerminator())->getOperands(); + ValueRange innerResults = innerFor.getResults(); + if (outerYields.size() != innerResults.size()) { + return false; + } + for (auto [outerYield, innerResult] : + llvm::zip_equal(outerYields, innerResults)) { + if (!llvm::hasSingleElement(innerResult.getUses()) || + outerYield != innerResult) { + return false; + } + } } - // sorted from outer to inner - return {nestLoops.rbegin(), nestLoops.rend()}; + return true; } -/// Fetch the untiled consumer of a scf.for's result which is yielded by a -/// tensor.insert_slice. This function makes the following assumptions : -/// 1. tensor.insert_slice has scf.yield as its only user. -/// 2. scf.for's corresponding result has only one use. +/// Fetch the untiled consumer of the outermost scf.for's result which is +/// yielded by a tensor.insert_slice from the innermost scf.for. This function +/// makes the following assumptions : +/// 1. tensor.insert_slice has scf.yield as its only user. +/// 2. scf.for's corresponding result has only one use. +/// 3. The `loops` passed in are perfectly nested `scf.for` operations. static FailureOr getUntiledConsumerFromSlice(RewriterBase &rewriter, - tensor::InsertSliceOp candidateSliceOp) { + tensor::InsertSliceOp candidateSliceOp, + MutableArrayRef loops) { + assert(!loops.empty() && "unexpected loops to be empty"); + // 1. Expect slice to be part of the body of the inner most loop. + Operation *containingOp = candidateSliceOp->getParentOp(); + if (containingOp != loops.back()) { + return rewriter.notifyMatchFailure( + candidateSliceOp, + "expected slice to be within body of inner-most loop"); + } + + // 2. Check that the loop is perfectly nested. + if (!isPerfectlyNestedForLoops(loops)) { + return rewriter.notifyMatchFailure( + candidateSliceOp, "expected passed loops to be perfectly nested."); + } + if (failed(checkAssumptionForFusingConsumer(candidateSliceOp))) return failure(); Value sliceResult = candidateSliceOp.getResult(); - // Step 1. Fetch the corresponding output. + + // 3. Fetch the corresponding output. OpOperand &yieldOpOperand = (*sliceResult.getUses().begin()); unsigned resultNumber = yieldOpOperand.getOperandNumber(); - // Step 2. Check containing op is scf.for. - Operation *containingOp = candidateSliceOp->getParentOp(); - auto forOp = dyn_cast(containingOp); - if (!forOp) - return failure(); - scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front(); + + scf::ForOp topLevelForOp = cast(loops.front().getOperation()); return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber); } @@ -1917,35 +1945,46 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, /// by a tensor.parallel_insert_slice. static FailureOr getUntiledConsumerFromSlice(RewriterBase &rewriter, - tensor::ParallelInsertSliceOp candidateSliceOp) { - // Step 1. Fetch the corresponding output + tensor::ParallelInsertSliceOp candidateSliceOp, + MutableArrayRef loops) { + assert(!loops.empty() && "unexpected loops to be empty"); + // 1. Check that the surrounding loop is a single scf.forall loop. + if (loops.size() != 1) { + return rewriter.notifyMatchFailure( + candidateSliceOp, "expected single surrounding scf.forall"); + } + auto forallOp = dyn_cast(loops.front().getOperation()); + if (!forallOp) { + return rewriter.notifyMatchFailure( + candidateSliceOp, "expected single surrounding scf.forall"); + } + + // 2. Fetch the corresponding output Value sliceDest = candidateSliceOp.getDest(); auto iterArg = dyn_cast(sliceDest); if (!iterArg) return failure(); - Operation *containingOp = iterArg.getOwner()->getParentOp(); - if (containingOp != candidateSliceOp->getParentOp()->getParentOp()) - return failure(); - // Step 2. Check that the containing op is scf.forall. - auto forallOp = dyn_cast(containingOp); - if (!forallOp) + if (iterArg.getOwner()->getParentOp() != forallOp) return failure(); + unsigned resultNumber = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)) .getResultNumber(); - return getConsumerFromLoopUses(rewriter, containingOp, resultNumber); + return getConsumerFromLoopUses(rewriter, forallOp, resultNumber); } /// A utility to fetch an untiled consumer of /// tensor.insert_slice/tensor.parallel_insert_slice. static FailureOr -getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) { +getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp, + MutableArrayRef loops) { + assert(!loops.empty() && "unexpected empty loops"); if (auto insertSlice = dyn_cast(sliceOp)) { - return getUntiledConsumerFromSlice(rewriter, insertSlice); + return getUntiledConsumerFromSlice(rewriter, insertSlice, loops); } else if (auto parallelInsertSlice = dyn_cast(sliceOp)) { - return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice); + return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops); } else { return failure(); } @@ -1954,18 +1993,23 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) { /// Implementation of fusing consumer of a single slice by computing the /// slice of the consumer in-place for scf loop. FailureOr -mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, - Operation *candidateSliceOp) { +mlir::scf::tileAndFuseConsumerOfSlice( + RewriterBase &rewriter, Operation *candidateSliceOp, + MutableArrayRef loops) { + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return candidateSliceOp->emitOpError( + "cannot call tile and fuse consumer with an empty loop nest"); + } if (!isa( candidateSliceOp)) return failure(); - bool isInsertSliceOp = isa(candidateSliceOp); - // 1. Get the consumer of scf.for for the result yielded by // tensor.insert_slice/parallel_insert_slice. FailureOr maybeConsumerOpOperand = - getUntiledConsumerFromSlice(rewriter, candidateSliceOp); + getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops); if (failed(maybeConsumerOpOperand)) { return rewriter.notifyMatchFailure(candidateSliceOp, "could not fetch consumer to fuse"); @@ -1981,25 +2025,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, consumerOp, "consumer op's operand doesn't seem to be an OpResult"); } - // There are two possible cases regarding `oldLoopOp` here: - // 1. single `scf.forall` or `scf.for`. - // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the - // top-level loop is the outer-most one of these nested loops. - LoopLikeOpInterface innerMostLoop = - candidateSliceOp->getParentOfType(); - SmallVector nestedLoops; - if (isInsertSliceOp) { - nestedLoops = llvm::map_to_vector( - getPerfectlyNestedLoopsOutsideOf( - cast(innerMostLoop.getOperation())), - [](scf::ForOp forOp) { - return cast(forOp.getOperation()); - }); - } else { - nestedLoops = {innerMostLoop}; - } - - LoopLikeOpInterface outerMostLoop = nestedLoops.front(); + LoopLikeOpInterface outerMostLoop = loops.front(); + LoopLikeOpInterface innerMostLoop = loops.back(); // Check assumption for loop with `reorderOperations` disabled. if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { @@ -2165,7 +2192,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, return success(); }; // 14. Add new inits to [nested] loops. - if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits, + if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits, newYieldValuesFn))) { return rewriter.notifyMatchFailure(tiledConsumerOp, "unable to add new inits to nest loop"); @@ -2174,9 +2201,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, // 15. Replace the result of scf loop and consumer op with new loop's // results. - for (auto &&[oldResult, newResult] : llvm::zip( - consumerOp->getResults(), - nestedLoops.front()->getResults().take_back(newInits.size()))) { + for (auto &&[oldResult, newResult] : + llvm::zip(consumerOp->getResults(), + loops.front()->getResults().take_back(newInits.size()))) { rewriter.replaceAllUsesWith(oldResult, newResult); } diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir index 5d4ae4f15d3fd..185fb9b358055 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir @@ -170,8 +170,8 @@ module { // Fuse the consumer operation into the tiled loop. %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - transform.test.fuse_consumer %slice_op - : (!transform.op<"tensor.parallel_insert_slice">) -> (!transform.any_op, !transform.any_op) + transform.test.fuse_consumer %slice_op in (%forall_op) + : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index 8ce05d94c4ad0..77e52946b830f 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -26,10 +26,12 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %yield - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -83,11 +85,13 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op %first_slice_op, %second_slice_op = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %first_slice_op - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -153,8 +157,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %yield - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -220,11 +226,13 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op %first_slice_op, %second_slice_op = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %first_slice_op - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -287,8 +295,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -348,8 +358,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -409,8 +421,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -437,143 +451,6 @@ module attributes {transform.with_named_sequence} { // ----- -module { - func.func @fuse_add_consumer_into_nested_scf_for(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> { - %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - %c256 = arith.constant 256 : index - %cst = arith.constant 0.000000e+00 : f32 - %dest0 = tensor.empty() : tensor<256x256xf32> - %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> - %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest1) -> (tensor<256x256xf32>) { - %2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %arg4) -> (tensor<256x256xf32>) { - %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32> - %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32> - %extracted_slice_3 = tensor.extract_slice %arg1[0, %arg5] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32> - %3 = linalg.matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_1 : tensor<64x64xf32>) -> tensor<64x64xf32> - %insert_slice = tensor.insert_slice %3 into %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32> - scf.yield %insert_slice : tensor<256x256xf32> - } - scf.yield %2 : tensor<256x256xf32> - } - %4 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> - return %4 : tensor<256x256xf32> - } -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} -// CHECK: func.func @fuse_add_consumer_into_nested_scf_for( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32> -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32> -// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32> -// CHECK: %[[dest1:.*]] = linalg.fill -// CHECK-SAME: outs(%[[dest0]] : -// CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[dest0]]) -// CHECK-SAME: { -// CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]]) -// CHECK-SAME: { -// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1] -// CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1] -// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul -// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : -// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add -// CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] : -// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] : -// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] : -// CHECK: } -// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 : -// CHECK: } -// CHECK: return %[[LOOP_RESULT1]]#1 : - -// ----- - -// This test case checks fusion of consumer even if the producer has multiple uses. -// The multiple uses of the producer essentially means that besides the consumer -// op in concern, the only other uses of the producer are allowed in :- -// 1. scf.yield -// 2. tensor.parallel_insert_slice - -module { - module { - func.func @fuse_consumer_for_multi_use_producer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) { - %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - %c256 = arith.constant 256 : index - %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<256x256xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> - %2:2 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %1, %arg5 = %arg2) -> (tensor<256x256xf32>, tensor<256x256xf32>) { - %3 = scf.for %arg6 = %c0 to %c256 step %c64 iter_args(%arg7 = %arg4) -> (tensor<256x256xf32>) { - %extracted_slice = tensor.extract_slice %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32> - %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32> - %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg6] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32> - %5 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice : tensor<64x64xf32>) -> tensor<64x64xf32> - %inserted_slice = tensor.insert_slice %5 into %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32> - scf.yield %inserted_slice : tensor<256x256xf32> - } - %4 = linalg.add ins(%3, %arg5 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> - scf.yield %3, %4 : tensor<256x256xf32>, tensor<256x256xf32> - } - return %2#0, %2#1 : tensor<256x256xf32>, tensor<256x256xf32> - } - } - module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } - } -} -// CHECK: func.func @fuse_consumer_for_multi_use_producer( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32> -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32> -// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32> -// CHECK: %[[dest1:.*]] = linalg.fill -// CHECK-SAME: outs(%[[dest0]] : -// CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[ARG2]]) -// CHECK-SAME: { -// CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[dest0]]) -// CHECK-SAME: { -// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1] -// CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1] -// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul -// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : -// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG1]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add -// CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] : -// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] : -// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] : -// CHECK: } -// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 : -// CHECK: } -// CHECK: return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 : - -// ----- - module { func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) { %c0 = arith.constant 0 : index @@ -599,8 +476,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 2 - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 2 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -662,9 +541,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1 - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -733,8 +613,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1 - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 7380b766935ff..45d6ae3820159 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -169,10 +169,10 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, /// Apply fusing of consumer transformation to all payload ops and store both /// the original consumer operation as well as the fused consumer operation. template -static LogicalResult -applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, - Range &&payloadOps, uint32_t numConsumerToFuse, - TransformResults &transformResults) { +static LogicalResult applyFuseConsumer( + RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, + MutableArrayRef loops, uint32_t numConsumerToFuse, + TransformResults &transformResults) { SmallVector originalConsumerOps; SmallVector fusedConsumerOps; @@ -181,7 +181,7 @@ applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, while (numConsumerToFuse--) { FailureOr fuseConsumerResults = - scf::tileAndFuseConsumerOfSlice(rewriter, target); + scf::tileAndFuseConsumerOfSlice(rewriter, target, loops); if (failed(fuseConsumerResults)) return failure(); @@ -203,8 +203,17 @@ DiagnosedSilenceableFailure transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { + SmallVector loops; + for (auto op : llvm::reverse(getLoops())) { + auto loopLikeOp = + dyn_cast(*state.getPayloadOps(op).begin()); + if (!loopLikeOp) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + loops.push_back(loopLikeOp); + } LogicalResult result = applyFuseConsumer( - rewriter, getOperation(), state.getPayloadOps(getTarget()), + rewriter, getOperation(), state.getPayloadOps(getTarget()), loops, getNumConsumerToFuse(), transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); @@ -213,6 +222,7 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, void transform::TestFuseConsumerOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTargetMutable(), effects); + consumesHandle(getLoopsMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index 34b075a5c17f9..98f7145c99cb1 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -58,14 +58,16 @@ def TestFuseConsumerOp : Op:$num_consumer_to_fuse); + let arguments = (ins + TransformHandleTypeInterface:$target, + Variadic:$loops, + DefaultValuedAttr:$num_consumer_to_fuse); let results = (outs TransformHandleTypeInterface:$consumer, TransformHandleTypeInterface:$fused_consumer); let assemblyFormat = [{ - $target (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? + $target `in` `(` $loops `)` + (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? attr-dict `:` functional-type(operands, results) }]; }