From 66fc41978f0f479c46b8be3ec3a70f95d74a838f Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Wed, 19 Mar 2025 11:36:08 -0700 Subject: [PATCH 1/2] [mlir][TilingInterface] Make `tileAndFuseConsumerOfSlice` take surrounding loops as an argument. This gets the consumer fusion method in sync with the corresponding producer fusion method `tileAndFuseProducerOfSlice`. Not taking this as input required use of complicated analysis to retrieve the surrounding loops which are very fragile. Just like the producer fusion method, the loops need to be taken in as an argument, with typically the loops being created by the tiling methods. Some utilities are added to check that the loops passed in are perfectly nested (in the case of an `scf.for` loop nest. This is change 1 of N to simplify the implementation of tile and fuse consumers. Signed-off-by: MaheshRavishankar --- .../SCF/Transforms/TileUsingInterface.h | 3 +- .../SCF/Transforms/TileUsingInterface.cpp | 152 ++++++++++++------ 2 files changed, 107 insertions(+), 48 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index d2cddfe00ac78..33a43ce2ee7bb 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -328,7 +328,8 @@ struct SCFFuseConsumerOfSliceResult { SmallVector tiledOps; }; FailureOr -tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp); +tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp, + MutableArrayRef loops); /// Method to lower an `op` that implements the `TilingInterface` to /// loops/scalars. diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index af87fb7a79d04..4fd10b0e30ab0 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1890,25 +1890,81 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) { return {nestLoops.rbegin(), nestLoops.rend()}; } +/// Check that the loop is perfectly nested. +static bool +isPerfectlyNestedForLoops(MutableArrayRef loops) { + assert(!loops.empty() && "unexpected empty loop nest"); + if (loops.size() == 1) { + return isa_and_nonnull(loops.front().getOperation()); + } + for (auto [outerLoop, innerLoop] : + llvm::zip_equal(loops.drop_back(), loops.drop_front())) { + auto outerFor = dyn_cast_or_null(outerLoop.getOperation()); + auto innerFor = dyn_cast_or_null(innerLoop.getOperation()); + if (!outerFor || !innerFor) { + return false; + } + auto outerBBArgs = outerFor.getRegionIterArgs(); + auto innerIterArgs = innerFor.getInitArgs(); + if (outerBBArgs.size() != innerIterArgs.size()) { + return false; + } + + for (auto [outerBBArg, innerIterArg] : + llvm::zip(outerBBArgs, innerIterArgs)) { + if (!llvm::hasSingleElement(outerBBArg.getUses()) || + innerIterArg != outerBBArg) { + return false; + } + } + + auto outerYields = + cast(outerFor.getBody()->getTerminator())->getOperands(); + auto innerResults = innerFor.getResults(); + if (outerYields.size() != innerResults.size()) { + return false; + } + for (auto [outerYield, innerResult] : + llvm::zip(outerYields, innerResults)) { + if (!llvm::hasSingleElement(innerResult.getUses()) || + outerYield != innerResult) { + return false; + } + } + } + return true; +} + /// Fetch the untiled consumer of a scf.for's result which is yielded by a /// tensor.insert_slice. This function makes the following assumptions : /// 1. tensor.insert_slice has scf.yield as its only user. /// 2. scf.for's corresponding result has only one use. static FailureOr getUntiledConsumerFromSlice(RewriterBase &rewriter, - tensor::InsertSliceOp candidateSliceOp) { + tensor::InsertSliceOp candidateSliceOp, + MutableArrayRef loops) { + assert(!loops.empty() && "unexpected loops to be empty"); + // 1. Expect slice to be part of the body of the inner most loop. + Operation *containingOp = candidateSliceOp->getParentOp(); + if (containingOp != loops.back()) { + return rewriter.notifyMatchFailure( + candidateSliceOp, + "expected slice to be within body of inner-most loop"); + } + + if (!isPerfectlyNestedForLoops(loops)) { + return rewriter.notifyMatchFailure( + candidateSliceOp, "expected passed loops to be perfectly nested."); + } + if (failed(checkAssumptionForFusingConsumer(candidateSliceOp))) return failure(); Value sliceResult = candidateSliceOp.getResult(); // Step 1. Fetch the corresponding output. OpOperand &yieldOpOperand = (*sliceResult.getUses().begin()); unsigned resultNumber = yieldOpOperand.getOperandNumber(); - // Step 2. Check containing op is scf.for. - Operation *containingOp = candidateSliceOp->getParentOp(); - auto forOp = dyn_cast(containingOp); - if (!forOp) - return failure(); - scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front(); + + scf::ForOp topLevelForOp = cast(loops.front().getOperation()); return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber); } @@ -1917,35 +1973,49 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, /// by a tensor.parallel_insert_slice. static FailureOr getUntiledConsumerFromSlice(RewriterBase &rewriter, - tensor::ParallelInsertSliceOp candidateSliceOp) { - // Step 1. Fetch the corresponding output + tensor::ParallelInsertSliceOp candidateSliceOp, + MutableArrayRef loops) { + assert(!loops.empty() && "unexpected loops to be empty"); + // 1. Check that the surrounding loop is a single scf.forall loop. + if (loops.size() != 1) { + return rewriter.notifyMatchFailure( + candidateSliceOp, "expected single surrounding scf.forall"); + } + auto forallOp = dyn_cast(loops.front().getOperation()); + if (!forallOp) { + return rewriter.notifyMatchFailure( + candidateSliceOp, "expected single surrounding scf.forall"); + } + + // 2. Fetch the corresponding output Value sliceDest = candidateSliceOp.getDest(); auto iterArg = dyn_cast(sliceDest); if (!iterArg) return failure(); - Operation *containingOp = iterArg.getOwner()->getParentOp(); - if (containingOp != candidateSliceOp->getParentOp()->getParentOp()) - return failure(); - // Step 2. Check that the containing op is scf.forall. - auto forallOp = dyn_cast(containingOp); - if (!forallOp) + if (iterArg.getOwner()->getParentOp() != forallOp) return failure(); + unsigned resultNumber = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)) .getResultNumber(); - return getConsumerFromLoopUses(rewriter, containingOp, resultNumber); + return getConsumerFromLoopUses(rewriter, forallOp, resultNumber); } /// A utility to fetch an untiled consumer of /// tensor.insert_slice/tensor.parallel_insert_slice. static FailureOr -getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) { +getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp, + MutableArrayRef loops) { + if (loops.empty()) { + return rewriter.notifyMatchFailure(sliceOp, "unexpected empty loops"); + } + if (auto insertSlice = dyn_cast(sliceOp)) { - return getUntiledConsumerFromSlice(rewriter, insertSlice); + return getUntiledConsumerFromSlice(rewriter, insertSlice, loops); } else if (auto parallelInsertSlice = dyn_cast(sliceOp)) { - return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice); + return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops); } else { return failure(); } @@ -1954,18 +2024,23 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) { /// Implementation of fusing consumer of a single slice by computing the /// slice of the consumer in-place for scf loop. FailureOr -mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, - Operation *candidateSliceOp) { +mlir::scf::tileAndFuseConsumerOfSlice( + RewriterBase &rewriter, Operation *candidateSliceOp, + MutableArrayRef loops) { + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return candidateSliceOp->emitOpError( + "cannot call tile and fuse consumer with an empty loop nest"); + } if (!isa( candidateSliceOp)) return failure(); - bool isInsertSliceOp = isa(candidateSliceOp); - // 1. Get the consumer of scf.for for the result yielded by // tensor.insert_slice/parallel_insert_slice. FailureOr maybeConsumerOpOperand = - getUntiledConsumerFromSlice(rewriter, candidateSliceOp); + getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops); if (failed(maybeConsumerOpOperand)) { return rewriter.notifyMatchFailure(candidateSliceOp, "could not fetch consumer to fuse"); @@ -1981,25 +2056,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, consumerOp, "consumer op's operand doesn't seem to be an OpResult"); } - // There are two possible cases regarding `oldLoopOp` here: - // 1. single `scf.forall` or `scf.for`. - // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the - // top-level loop is the outer-most one of these nested loops. - LoopLikeOpInterface innerMostLoop = - candidateSliceOp->getParentOfType(); - SmallVector nestedLoops; - if (isInsertSliceOp) { - nestedLoops = llvm::map_to_vector( - getPerfectlyNestedLoopsOutsideOf( - cast(innerMostLoop.getOperation())), - [](scf::ForOp forOp) { - return cast(forOp.getOperation()); - }); - } else { - nestedLoops = {innerMostLoop}; - } - - LoopLikeOpInterface outerMostLoop = nestedLoops.front(); + LoopLikeOpInterface outerMostLoop = loops.front(); + LoopLikeOpInterface innerMostLoop = loops.back(); // Check assumption for loop with `reorderOperations` disabled. if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { @@ -2165,7 +2223,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, return success(); }; // 14. Add new inits to [nested] loops. - if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits, + if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits, newYieldValuesFn))) { return rewriter.notifyMatchFailure(tiledConsumerOp, "unable to add new inits to nest loop"); @@ -2174,9 +2232,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, // 15. Replace the result of scf loop and consumer op with new loop's // results. - for (auto &&[oldResult, newResult] : llvm::zip( - consumerOp->getResults(), - nestedLoops.front()->getResults().take_back(newInits.size()))) { + for (auto &&[oldResult, newResult] : + llvm::zip(consumerOp->getResults(), + loops.front()->getResults().take_back(newInits.size()))) { rewriter.replaceAllUsesWith(oldResult, newResult); } From 9c0d42678b1a2fe87abe269771860d3802f0b0df Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Thu, 20 Mar 2025 22:31:10 -0700 Subject: [PATCH 2/2] Address comments. Signed-off-by: MaheshRavishankar --- .../SCF/Transforms/TileUsingInterface.cpp | 67 ++---- .../transform-tile-and-fuse-pack-unpack.mlir | 4 +- .../tile-and-fuse-consumer.mlir | 196 ++++-------------- .../TestTilingInterfaceTransformOps.cpp | 22 +- .../TestTilingInterfaceTransformOps.td | 10 +- 5 files changed, 81 insertions(+), 218 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 4fd10b0e30ab0..8e407cc1b348f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1846,11 +1846,9 @@ static FailureOr getConsumerFromLoopUses(RewriterBase &rewriter, return failure(); } -/// Find the perfectly nested loops outside of given loop(included) sorted -/// from outer to inner. -/// -/// E.g. -/// +/// Check that the loop is perfectly nested. +/// The loops are expected to be ordered from outer most to inner most. +/// For example: /// ``` /// %0 = scf.for() /// %1 = scf.for() @@ -1860,37 +1858,7 @@ static FailureOr getConsumerFromLoopUses(RewriterBase &rewriter, /// yield %2 /// yield %1 /// ``` -/// -/// This function will return three perfectly nested loops: %0 + %1 + %2, when -/// target inner loop is %2. -static SmallVector -getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) { - SmallVector nestLoops = {loop}; - auto outerLoop = dyn_cast(loop->getParentOp()); - - // Check if it is the ForOp that yield the result of inner loop. - auto isForOpYieldResultOfInnerLoop = - [](scf::ForOp outerLoop) -> LogicalResult { - Block *body = outerLoop.getBody(); - if (!llvm::hasSingleElement(body->without_terminator())) - return failure(); - auto yieldOp = cast(body->getTerminator()); - auto innerForOp = dyn_cast(body->front()); - if (!innerForOp) - return failure(); - // All of innerForOp results should be yielded. - return success(innerForOp->getNumResults() == yieldOp->getNumOperands()); - }; - - while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) { - nestLoops.push_back(outerLoop); - outerLoop = dyn_cast(outerLoop->getParentOp()); - } - // sorted from outer to inner - return {nestLoops.rbegin(), nestLoops.rend()}; -} - -/// Check that the loop is perfectly nested. +/// Here loops should be [%0, %1]. static bool isPerfectlyNestedForLoops(MutableArrayRef loops) { assert(!loops.empty() && "unexpected empty loop nest"); @@ -1911,21 +1879,21 @@ isPerfectlyNestedForLoops(MutableArrayRef loops) { } for (auto [outerBBArg, innerIterArg] : - llvm::zip(outerBBArgs, innerIterArgs)) { + llvm::zip_equal(outerBBArgs, innerIterArgs)) { if (!llvm::hasSingleElement(outerBBArg.getUses()) || innerIterArg != outerBBArg) { return false; } } - auto outerYields = + ValueRange outerYields = cast(outerFor.getBody()->getTerminator())->getOperands(); - auto innerResults = innerFor.getResults(); + ValueRange innerResults = innerFor.getResults(); if (outerYields.size() != innerResults.size()) { return false; } for (auto [outerYield, innerResult] : - llvm::zip(outerYields, innerResults)) { + llvm::zip_equal(outerYields, innerResults)) { if (!llvm::hasSingleElement(innerResult.getUses()) || outerYield != innerResult) { return false; @@ -1935,10 +1903,12 @@ isPerfectlyNestedForLoops(MutableArrayRef loops) { return true; } -/// Fetch the untiled consumer of a scf.for's result which is yielded by a -/// tensor.insert_slice. This function makes the following assumptions : -/// 1. tensor.insert_slice has scf.yield as its only user. -/// 2. scf.for's corresponding result has only one use. +/// Fetch the untiled consumer of the outermost scf.for's result which is +/// yielded by a tensor.insert_slice from the innermost scf.for. This function +/// makes the following assumptions : +/// 1. tensor.insert_slice has scf.yield as its only user. +/// 2. scf.for's corresponding result has only one use. +/// 3. The `loops` passed in are perfectly nested `scf.for` operations. static FailureOr getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, @@ -1952,6 +1922,7 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, "expected slice to be within body of inner-most loop"); } + // 2. Check that the loop is perfectly nested. if (!isPerfectlyNestedForLoops(loops)) { return rewriter.notifyMatchFailure( candidateSliceOp, "expected passed loops to be perfectly nested."); @@ -1960,7 +1931,8 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, if (failed(checkAssumptionForFusingConsumer(candidateSliceOp))) return failure(); Value sliceResult = candidateSliceOp.getResult(); - // Step 1. Fetch the corresponding output. + + // 3. Fetch the corresponding output. OpOperand &yieldOpOperand = (*sliceResult.getUses().begin()); unsigned resultNumber = yieldOpOperand.getOperandNumber(); @@ -2007,10 +1979,7 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, static FailureOr getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp, MutableArrayRef loops) { - if (loops.empty()) { - return rewriter.notifyMatchFailure(sliceOp, "unexpected empty loops"); - } - + assert(!loops.empty() && "unexpected empty loops"); if (auto insertSlice = dyn_cast(sliceOp)) { return getUntiledConsumerFromSlice(rewriter, insertSlice, loops); } else if (auto parallelInsertSlice = diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir index 5d4ae4f15d3fd..185fb9b358055 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir @@ -170,8 +170,8 @@ module { // Fuse the consumer operation into the tiled loop. %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - transform.test.fuse_consumer %slice_op - : (!transform.op<"tensor.parallel_insert_slice">) -> (!transform.any_op, !transform.any_op) + transform.test.fuse_consumer %slice_op in (%forall_op) + : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index 8ce05d94c4ad0..77e52946b830f 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -26,10 +26,12 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %yield - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -83,11 +85,13 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op %first_slice_op, %second_slice_op = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %first_slice_op - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -153,8 +157,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %yield - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -220,11 +226,13 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op %first_slice_op, %second_slice_op = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %first_slice_op - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -287,8 +295,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -348,8 +358,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -409,8 +421,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -437,143 +451,6 @@ module attributes {transform.with_named_sequence} { // ----- -module { - func.func @fuse_add_consumer_into_nested_scf_for(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> { - %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - %c256 = arith.constant 256 : index - %cst = arith.constant 0.000000e+00 : f32 - %dest0 = tensor.empty() : tensor<256x256xf32> - %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> - %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest1) -> (tensor<256x256xf32>) { - %2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %arg4) -> (tensor<256x256xf32>) { - %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32> - %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32> - %extracted_slice_3 = tensor.extract_slice %arg1[0, %arg5] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32> - %3 = linalg.matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_1 : tensor<64x64xf32>) -> tensor<64x64xf32> - %insert_slice = tensor.insert_slice %3 into %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32> - scf.yield %insert_slice : tensor<256x256xf32> - } - scf.yield %2 : tensor<256x256xf32> - } - %4 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> - return %4 : tensor<256x256xf32> - } -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} -// CHECK: func.func @fuse_add_consumer_into_nested_scf_for( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32> -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32> -// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32> -// CHECK: %[[dest1:.*]] = linalg.fill -// CHECK-SAME: outs(%[[dest0]] : -// CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[dest0]]) -// CHECK-SAME: { -// CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]]) -// CHECK-SAME: { -// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1] -// CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1] -// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul -// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : -// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add -// CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] : -// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] : -// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] : -// CHECK: } -// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 : -// CHECK: } -// CHECK: return %[[LOOP_RESULT1]]#1 : - -// ----- - -// This test case checks fusion of consumer even if the producer has multiple uses. -// The multiple uses of the producer essentially means that besides the consumer -// op in concern, the only other uses of the producer are allowed in :- -// 1. scf.yield -// 2. tensor.parallel_insert_slice - -module { - module { - func.func @fuse_consumer_for_multi_use_producer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) { - %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - %c256 = arith.constant 256 : index - %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<256x256xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> - %2:2 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %1, %arg5 = %arg2) -> (tensor<256x256xf32>, tensor<256x256xf32>) { - %3 = scf.for %arg6 = %c0 to %c256 step %c64 iter_args(%arg7 = %arg4) -> (tensor<256x256xf32>) { - %extracted_slice = tensor.extract_slice %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32> - %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32> - %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg6] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32> - %5 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice : tensor<64x64xf32>) -> tensor<64x64xf32> - %inserted_slice = tensor.insert_slice %5 into %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32> - scf.yield %inserted_slice : tensor<256x256xf32> - } - %4 = linalg.add ins(%3, %arg5 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32> - scf.yield %3, %4 : tensor<256x256xf32>, tensor<256x256xf32> - } - return %2#0, %2#1 : tensor<256x256xf32>, tensor<256x256xf32> - } - } - module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } - } -} -// CHECK: func.func @fuse_consumer_for_multi_use_producer( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32> -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32> -// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32> -// CHECK: %[[dest1:.*]] = linalg.fill -// CHECK-SAME: outs(%[[dest0]] : -// CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[ARG2]]) -// CHECK-SAME: { -// CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]] -// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[dest0]]) -// CHECK-SAME: { -// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1] -// CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1] -// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul -// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : -// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG1]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add -// CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] : -// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] : -// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1] -// CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] : -// CHECK: } -// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 : -// CHECK: } -// CHECK: return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 : - -// ----- - module { func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) { %c0 = arith.constant 0 : index @@ -599,8 +476,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 2 - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 2 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -662,9 +541,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1 - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -733,8 +613,10 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1 - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 7380b766935ff..45d6ae3820159 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -169,10 +169,10 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, /// Apply fusing of consumer transformation to all payload ops and store both /// the original consumer operation as well as the fused consumer operation. template -static LogicalResult -applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, - Range &&payloadOps, uint32_t numConsumerToFuse, - TransformResults &transformResults) { +static LogicalResult applyFuseConsumer( + RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, + MutableArrayRef loops, uint32_t numConsumerToFuse, + TransformResults &transformResults) { SmallVector originalConsumerOps; SmallVector fusedConsumerOps; @@ -181,7 +181,7 @@ applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, while (numConsumerToFuse--) { FailureOr fuseConsumerResults = - scf::tileAndFuseConsumerOfSlice(rewriter, target); + scf::tileAndFuseConsumerOfSlice(rewriter, target, loops); if (failed(fuseConsumerResults)) return failure(); @@ -203,8 +203,17 @@ DiagnosedSilenceableFailure transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state) { + SmallVector loops; + for (auto op : llvm::reverse(getLoops())) { + auto loopLikeOp = + dyn_cast(*state.getPayloadOps(op).begin()); + if (!loopLikeOp) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + loops.push_back(loopLikeOp); + } LogicalResult result = applyFuseConsumer( - rewriter, getOperation(), state.getPayloadOps(getTarget()), + rewriter, getOperation(), state.getPayloadOps(getTarget()), loops, getNumConsumerToFuse(), transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); @@ -213,6 +222,7 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, void transform::TestFuseConsumerOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTargetMutable(), effects); + consumesHandle(getLoopsMutable(), effects); producesHandle(getOperation()->getOpResults(), effects); modifiesPayload(effects); } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index 34b075a5c17f9..98f7145c99cb1 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -58,14 +58,16 @@ def TestFuseConsumerOp : Op:$num_consumer_to_fuse); + let arguments = (ins + TransformHandleTypeInterface:$target, + Variadic:$loops, + DefaultValuedAttr:$num_consumer_to_fuse); let results = (outs TransformHandleTypeInterface:$consumer, TransformHandleTypeInterface:$fused_consumer); let assemblyFormat = [{ - $target (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? + $target `in` `(` $loops `)` + (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)? attr-dict `:` functional-type(operands, results) }]; }