From 66748e7f049045308bca02d028d14ecc47ba96c2 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 19 Sep 2024 23:01:17 -0400 Subject: [PATCH 1/2] [mlir] Add option for a cleanup pattern set to SCF tiling helper The SCF helper for tiling an operation implementing the TilingInterface and greedily fusing producers requires an uninterrupted chain of operations implementing the tiling interface to succeed. There can be cases with intermediate ops that don't implement the interface but have producers that could be fused if various canonicalization/simplification patterns could run in between fusion steps. This adds an option to SCFTileAndFuseOptions for a pattern set to run between fusion steps on the ops that result from fusion/tiling. Removed and newly inserted slices are tracked for continued fusion applications. See this RFC for more discussion: https://discourse.llvm.org/t/rfc-split-fusion-portions-of-the-tilinginterface-into-a-new-interface/81155 --- .../Linalg/TransformOps/LinalgTransformOps.td | 9 +- .../SCF/Transforms/TileUsingInterface.h | 6 + .../TransformOps/LinalgTransformOps.cpp | 9 + .../SCF/Transforms/TileUsingInterface.cpp | 225 ++++++++++++++++-- .../Dialect/Linalg/transform-op-fuse.mlir | 65 +++++ 5 files changed, 286 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index a997502c34299..f9036cf96e9a1 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -295,18 +295,23 @@ def FuseOp : Op:$tile_sizes, - DefaultValuedAttr:$tile_interchange); + DefaultValuedAttr:$tile_interchange, + DefaultValuedAttr:$apply_cleanup); let results = (outs TransformHandleTypeInterface:$transformed, Variadic:$loops); let assemblyFormat = [{ $target ($tile_sizes^)? (`interchange` $tile_interchange^)? 
- attr-dict `:` functional-type(operands, results) + (`apply_cleanup` `=` $apply_cleanup^)? attr-dict + `:` functional-type(operands, results) }]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 77c812cde7153..9f5f9f3fca97a 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -15,6 +15,7 @@ #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" #include @@ -153,6 +154,11 @@ struct SCFTileAndFuseOptions { fusionControlFn = controlFn; return *this; } + + /// An optional set of rewrite patterns to apply to the results of tiling + /// before fusion. This will track deleted and newly inserted + /// `tensor.extract_slice` ops and update the worklist. + std::optional cleanupPatterns = std::nullopt; }; /// Fuse the producer of the source of `candidateSliceOp` by computing the diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 0b9223013a0f1..8e7621754f76b 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -562,6 +562,15 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; + + if (getApplyCleanup()) { + MLIRContext *context = rewriter.getContext(); + RewritePatternSet patterns(context); + tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context); + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + tileAndFuseOptions.cleanupPatterns = std::move(patterns); 
+ } + LogicalResult result = applyTilingToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), transformResults, diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 50cfd29e6bf90..110eba0356706 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -24,6 +24,8 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include @@ -1315,6 +1317,172 @@ FailureOr> mlir::scf::yieldReplacementForFusedProducer( return generatedSlices; } +namespace { + +//===----------------------------------------------------------------------===// +// SliceWorklist +//===----------------------------------------------------------------------===// + +/// Struct for tracking the number of stale entries on the worklist and whether +/// there is a remaining valid entry. +struct EntryCount { + bool isValid = true; + unsigned count = 0; +}; + +/// A FIFO worklist of operations with efficient removal and set semantics. +/// +/// This class maintains a queue of operations and a mapping of operations to +/// positions in the vector, so that operations can be removed efficiently at +/// random. When an operation is removed, it is replaced with nullptr. Such +/// nullptr are skipped when pop'ing elements. +/// +/// This is similar to the worklist used by the GreedyPatternRewriteDriver, +/// except instead FIFO so that slices for fusion can be processed breadth +/// first. +class SliceWorklist { +public: + SliceWorklist() = default; + + /// Push an operation to the end of the worklist. 
This assumes that + /// the given operation is not already on the worklist. + void push(Operation *op); + + /// Pop the an operation from the end of the worklist. Returns nullptr if + /// there are no remaining valid operations. + Operation *pop(); + + /// Remove an operation from the worklist. + void remove(Operation *op); + +protected: + /// The queue of operations. + std::deque list; + + /// A mapping of operations to the number of stale copies in the queue. + DenseMap map; +}; + +void SliceWorklist::push(Operation *op) { + assert(op && "cannot push nullptr to worklist"); + list.push_back(op); + EntryCount newCount = map.lookup(op); + // Because operations are only pushed on creation, valid duplicates are + // never added. + assert((!map.contains(op) || !newCount.isValid) && + "cannot push a duplicate operation"); + map[op] = {/*isValid=*/true, newCount.count + 1}; +} + +Operation *SliceWorklist::pop() { + // Pop the front of the queue until we hit a valid entry. + while (!list.empty()) { + Operation *op = list.front(); + list.pop_front(); + + EntryCount e = map.lookup(op); + // If the entry count is greater than 1 or there is no valid entry, + // this must be a stale entry. Decrement the map entry by one and continue. + if (e.count > 1 || !e.isValid) { + int64_t newCount = e.count - 1; + if (newCount <= 0) + map.erase(op); + else + map[op] = {e.isValid, static_cast(newCount)}; + continue; + } + + map.erase(op); + return op; + } + return nullptr; +} + +// Mark the operation as invalid if present. Removal from the map will +// happen later when popping from the worklist. 
+void SliceWorklist::remove(Operation *op) { + if (!map.contains(op)) + return; + + EntryCount e = map.lookup(op); + map[op] = {/*isValid=*/false, e.count}; +} + +//===----------------------------------------------------------------------===// +// SliceTrackingListener +//===----------------------------------------------------------------------===// + +/// This class is a listener for tracking the insertion and removal of +/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy +/// fusion algorithm to apply cleanup patterns in between fusion steps. +class SliceTrackingListener : public RewriterBase::Listener { +public: + explicit SliceTrackingListener( + std::optional patterns); + SliceTrackingListener() = default; + + /// Adds the given list of operations to the worklist, and if present, applies + /// the list of `patterns` to the newly added operations. This only processes + /// the given operations and any newly inserted ones by the pattern set. + LogicalResult insertAndApplyPatterns(ArrayRef newOps); + + /// Add to the new operation worklist if it is an extract_slice. + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint previous) override; + + /// Remove the operation from the worklist. + void notifyOperationErased(Operation *op) override; + + /// Remove the operation from the worklist. + void notifyOperationReplaced(Operation *op, ValueRange replacement) override; + + /// The worklist for this transformation keeps track of the operations that + /// need to be (re)visited. + SliceWorklist worklist; + +private: + /// Optional pattern set to apply when adding new operations to the worklist. 
+ std::optional patterns = std::nullopt; +}; + +SliceTrackingListener::SliceTrackingListener( + std::optional p) { + patterns = std::move(p); +} + +LogicalResult +SliceTrackingListener::insertAndApplyPatterns(ArrayRef ops) { + for (Operation *op : ops) { + if (isa(op)) + worklist.push(op); + } + + if (!patterns) + return success(); + + GreedyRewriteConfig config; + config.listener = this; + config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; + return applyOpPatternsAndFold(ops, patterns.value(), config); +} + +void SliceTrackingListener::notifyOperationInserted( + Operation *op, OpBuilder::InsertPoint previous) { + if (!isa(op)) + return; + worklist.push(op); +} + +void SliceTrackingListener::notifyOperationErased(Operation *op) { + worklist.remove(op); +} + +void SliceTrackingListener::notifyOperationReplaced(Operation *op, + ValueRange replacement) { + worklist.remove(op); +} +} // namespace + /// Implementation of tile consumer and fuse producer greedily. FailureOr mlir::scf::tileConsumerAndFuseProducersUsingSCF( @@ -1370,33 +1538,33 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( tensor::ExtractSliceOp candidateSlice; SCFTileAndFuseOptions::ControlFnResult controlFnResult; }; - std::deque worklist; - auto addCandidateSlices = [&worklist, &options, - &loops](ArrayRef candidates) { - for (auto candidate : candidates) { - auto sliceOp = dyn_cast(candidate); - if (!sliceOp || sliceOp.use_empty()) - continue; - auto [fusableProducer, destinationInitArg] = - getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops); - if (!fusableProducer) - continue; - std::optional controlFnResult = - options.fusionControlFn(sliceOp, fusableProducer, - destinationInitArg.has_value()); - if (!controlFnResult) - continue; - worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()}); - } - }; + SliceTrackingListener sliceTracker = + SliceTrackingListener(options.cleanupPatterns); - addCandidateSlices(tilingResult->generatedSlices); + if (failed( 
+ sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) { + return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); + } OpBuilder::InsertionGuard g(rewriter); - while (!worklist.empty()) { - // Traverse the slices in BFS fashion. - WorklistItem worklistItem = worklist.front(); - worklist.pop_front(); + while (Operation *next = sliceTracker.worklist.pop()) { + auto candidateSlice = dyn_cast(next); + if (!candidateSlice) + continue; + + auto [fusableProducer, destinationInitArg] = + getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(), + loops); + if (!fusableProducer) + continue; + + std::optional controlFnResult = + options.fusionControlFn(candidateSlice, fusableProducer, + destinationInitArg.has_value()); + if (!controlFnResult) + continue; + + WorklistItem worklistItem = {candidateSlice, controlFnResult.value()}; // The operands of the fused producer might themselved be slices of // values produced by operations that implement the `TilingInterface`. @@ -1407,6 +1575,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( if (!fusedResult) continue; + SmallVector worklistCandidates = fusedResult->generatedSlices; + if (worklistItem.controlFnResult.yieldProducerReplacement) { // Reconstruct and yield all opResult of fusableProducerOp by default. 
The // caller can specific which one to yield by designating optional argument @@ -1421,7 +1591,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( fusableProducerOp, "failed to replacement value for this " "operation from within the tiled loop"); } - addCandidateSlices(newSlices.value()); + worklistCandidates.append(newSlices.value()); for (auto [index, result] : llvm::enumerate(fusableProducerOp->getResults())) { origValToResultNumber[result] = loops.front()->getNumResults() - @@ -1429,12 +1599,15 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( index; } } - addCandidateSlices(fusedResult->generatedSlices); if (Operation *tiledAndFusedOp = fusedResult->tiledAndFusedProducer.getDefiningOp()) { fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); tiledAndFusedOps.insert(tiledAndFusedOp); } + + if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) { + return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); + } } DenseMap replacements; diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index 3a023deb1132f..643171e64ed4f 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -178,3 +178,68 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: func.func @fuse_through_slice +func.func @fuse_through_slice(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.elemwise_unary + // CHECK: linalg.elemwise_binary + // CHECK: return %[[RES]] + %0 = linalg.elemwise_unary ins(%arg0 : tensor) + outs(%arg0: tensor) -> tensor + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg1, %c0 : tensor + %dim1 = tensor.dim %arg1, %c1 : tensor + %1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor to tensor + %2 = linalg.elemwise_binary ins(%1, %arg1 : 
tensor, tensor) + outs(%arg1: tensor) -> tensor + return %2 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_through_slice_and_cast_chain +func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1: tensor) -> tensor { + + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.elemwise_unary + // CHECK: linalg.elemwise_binary + // CHECK: return %[[RES]] + %0 = linalg.elemwise_unary ins(%arg0 : tensor<100x100xf32>) + outs(%arg0: tensor<100x100xf32>) -> tensor<100x100xf32> + %1 = tensor.cast %0 : tensor<100x100xf32> to tensor<100x?xf32> + %2 = tensor.extract_slice %1 [1, 1] [98, 98] [1, 1] : tensor<100x?xf32> to tensor<98x98xf32> + %3 = tensor.cast %2 : tensor<98x98xf32> to tensor + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg1, %c0 : tensor + %dim1 = tensor.dim %arg1, %c1 : tensor + %4 = tensor.extract_slice %3 [1, 1] [%dim0, %dim1] [1, 1] : tensor to tensor + %5 = linalg.elemwise_binary ins(%4, %arg1 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %5 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + : (!transform.any_op) -> 
(!transform.any_op, !transform.op<"scf.for">, !transform.any_op) + transform.yield + } +} From be547c66dac236c4c524464ec8dcfab9d43d21fd Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 27 Sep 2024 00:17:07 -0400 Subject: [PATCH 2/2] Remove worklist class and add negative test --- .../SCF/Transforms/TileUsingInterface.cpp | 133 +++++------------- .../Dialect/Linalg/transform-op-fuse.mlir | 35 +++++ 2 files changed, 67 insertions(+), 101 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 110eba0356706..e2feb10b31454 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1319,95 +1319,6 @@ FailureOr> mlir::scf::yieldReplacementForFusedProducer( namespace { -//===----------------------------------------------------------------------===// -// SliceWorklist -//===----------------------------------------------------------------------===// - -/// Struct for tracking the number of stale entries on the worklist and whether -/// there is a remaining valid entry. -struct EntryCount { - bool isValid = true; - unsigned count = 0; -}; - -/// A FIFO worklist of operations with efficient removal and set semantics. -/// -/// This class maintains a queue of operations and a mapping of operations to -/// positions in the vector, so that operations can be removed efficiently at -/// random. When an operation is removed, it is replaced with nullptr. Such -/// nullptr are skipped when pop'ing elements. -/// -/// This is similar to the worklist used by the GreedyPatternRewriteDriver, -/// except instead FIFO so that slices for fusion can be processed breadth -/// first. -class SliceWorklist { -public: - SliceWorklist() = default; - - /// Push an operation to the end of the worklist. This assumes that - /// the given operation is not already on the worklist. 
- void push(Operation *op); - - /// Pop the an operation from the end of the worklist. Returns nullptr if - /// there are no remaining valid operations. - Operation *pop(); - - /// Remove an operation from the worklist. - void remove(Operation *op); - -protected: - /// The queue of operations. - std::deque list; - - /// A mapping of operations to the number of stale copies in the queue. - DenseMap map; -}; - -void SliceWorklist::push(Operation *op) { - assert(op && "cannot push nullptr to worklist"); - list.push_back(op); - EntryCount newCount = map.lookup(op); - // Because operations are only pushed on creation, valid duplicates are - // never added. - assert((!map.contains(op) || !newCount.isValid) && - "cannot push a duplicate operation"); - map[op] = {/*isValid=*/true, newCount.count + 1}; -} - -Operation *SliceWorklist::pop() { - // Pop the front of the queue until we hit a valid entry. - while (!list.empty()) { - Operation *op = list.front(); - list.pop_front(); - - EntryCount e = map.lookup(op); - // If the entry count is greater than 1 or there is no valid entry, - // this must be a stale entry. Decrement the map entry by one and continue. - if (e.count > 1 || !e.isValid) { - int64_t newCount = e.count - 1; - if (newCount <= 0) - map.erase(op); - else - map[op] = {e.isValid, static_cast(newCount)}; - continue; - } - - map.erase(op); - return op; - } - return nullptr; -} - -// Mark the operation as invalid if present. Removal from the map will -// happen later when popping from the worklist. 
-void SliceWorklist::remove(Operation *op) { - if (!map.contains(op)) - return; - - EntryCount e = map.lookup(op); - map[op] = {/*isValid=*/false, e.count}; -} - //===----------------------------------------------------------------------===// // SliceTrackingListener //===----------------------------------------------------------------------===// @@ -1430,15 +1341,18 @@ class SliceTrackingListener : public RewriterBase::Listener { void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override; + /// Shared helper for operation removal from the worklist. + void removeOp(Operation *op); + /// Remove the operation from the worklist. void notifyOperationErased(Operation *op) override; /// Remove the operation from the worklist. void notifyOperationReplaced(Operation *op, ValueRange replacement) override; - /// The worklist for this transformation keeps track of the operations that - /// need to be (re)visited. - SliceWorklist worklist; + /// The worklist for this transformation keeps track of the slices to visit + /// next for fusion. + std::deque worklist; private: /// Optional pattern set to apply when adding new operations to the worklist. @@ -1453,8 +1367,8 @@ SliceTrackingListener::SliceTrackingListener( LogicalResult SliceTrackingListener::insertAndApplyPatterns(ArrayRef ops) { for (Operation *op : ops) { - if (isa(op)) - worklist.push(op); + if (auto slice = dyn_cast(op)) + worklist.push_back(slice); } if (!patterns) @@ -1468,18 +1382,36 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef ops) { void SliceTrackingListener::notifyOperationInserted( Operation *op, OpBuilder::InsertPoint previous) { + auto slice = dyn_cast(op); + if (!slice) + return; + worklist.push_back(slice); +} + +// Scan the worklist for the given op and remove it if present. The expectation +// is for the worklist to be small and for removal to be relatively rare. 
+void SliceTrackingListener::removeOp(Operation *op) { if (!isa(op)) return; - worklist.push(op); + auto iter = worklist.begin(); + while (iter != worklist.end()) { + if (*iter == op) + break; + iter++; + } + if (iter == worklist.end()) + return; + + worklist.erase(iter); } void SliceTrackingListener::notifyOperationErased(Operation *op) { - worklist.remove(op); + removeOp(op); } void SliceTrackingListener::notifyOperationReplaced(Operation *op, ValueRange replacement) { - worklist.remove(op); + removeOp(op); } } // namespace @@ -1547,10 +1479,9 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); } OpBuilder::InsertionGuard g(rewriter); - while (Operation *next = sliceTracker.worklist.pop()) { - auto candidateSlice = dyn_cast(next); - if (!candidateSlice) - continue; + while (!sliceTracker.worklist.empty()) { + auto candidateSlice = sliceTracker.worklist.front(); + sliceTracker.worklist.pop_front(); auto [fusableProducer, destinationInitArg] = getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(), diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index 643171e64ed4f..ac1ca9319d335 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -243,3 +243,38 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: func.func @fuse_unrelated_slices +func.func @fuse_unrelated_slices(%arg0: tensor, %arg1: tensor) -> (tensor, tensor<10x10xf32>) { + + // CHECK: %[[SLICE1:.+]] = tensor.extract_slice + // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[SLICE1]] + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.elemwise_unary + // CHECK: linalg.elemwise_binary + // CHECK: return %[[RES]], %[[SLICE2]] + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg1, %c0 : tensor 
+ %dim1 = tensor.dim %arg1, %c1 : tensor + %slice1 = tensor.extract_slice %arg0 [1, 1] [%dim0, %dim1] [1, 1] : tensor to tensor + %slice2 = tensor.extract_slice %slice1 [1, 1] [10, 10] [1, 1] : tensor to tensor<10x10xf32> + %0 = linalg.elemwise_unary ins(%arg0 : tensor) + outs(%arg0: tensor) -> tensor + %1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor to tensor + %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %2, %slice2 : tensor, tensor<10x10xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) + transform.yield + } +}