diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index a997502c34299..f9036cf96e9a1 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -295,18 +295,23 @@ def FuseOp : Op:$tile_sizes, - DefaultValuedAttr:$tile_interchange); + DefaultValuedAttr:$tile_interchange, + DefaultValuedAttr:$apply_cleanup); let results = (outs TransformHandleTypeInterface:$transformed, Variadic:$loops); let assemblyFormat = [{ $target ($tile_sizes^)? (`interchange` $tile_interchange^)? - attr-dict `:` functional-type(operands, results) + (`apply_cleanup` `=` $apply_cleanup^)? attr-dict + `:` functional-type(operands, results) }]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 77c812cde7153..9f5f9f3fca97a 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -15,6 +15,7 @@ #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" #include @@ -153,6 +154,11 @@ struct SCFTileAndFuseOptions { fusionControlFn = controlFn; return *this; } + + /// An optional set of rewrite patterns to apply to the results of tiling + /// before fusion. This will track deleted and newly inserted + /// `tensor.extract_slice` ops and update the worklist. + std::optional cleanupPatterns = std::nullopt; }; /// Fuse the producer of the source of `candidateSliceOp` by computing the diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 0b9223013a0f1..8e7621754f76b 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -562,6 +562,15 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; + + if (getApplyCleanup()) { + MLIRContext *context = rewriter.getContext(); + RewritePatternSet patterns(context); + tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context); + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + tileAndFuseOptions.cleanupPatterns = std::move(patterns); + } + LogicalResult result = applyTilingToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), transformResults, diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 50cfd29e6bf90..e2feb10b31454 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -24,6 +24,8 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include @@ -1315,6 +1317,104 @@ FailureOr> mlir::scf::yieldReplacementForFusedProducer( return generatedSlices; } +namespace { + +//===----------------------------------------------------------------------===// +// SliceTrackingListener +//===----------------------------------------------------------------------===// + +/// This class is a listener for tracking the insertion and removal of +/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy +/// fusion algorithm to apply cleanup patterns in between fusion steps. +class SliceTrackingListener : public RewriterBase::Listener { +public: + explicit SliceTrackingListener( + std::optional patterns); + SliceTrackingListener() = default; + + /// Adds the given list of operations to the worklist, and if present, applies + /// the list of `patterns` to the newly added operations. This only processes + /// the given operations and any newly inserted ones by the pattern set. + LogicalResult insertAndApplyPatterns(ArrayRef newOps); + + /// Add to the new operation worklist if it is an extract_slice. + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint previous) override; + + /// Shared helper for operation removal from the worklist. + void removeOp(Operation *op); + + /// Remove the operation from the worklist. + void notifyOperationErased(Operation *op) override; + + /// Remove the operation from the worklist. + void notifyOperationReplaced(Operation *op, ValueRange replacement) override; + + /// The worklist for this transformation keeps track of the slices to visit + /// next for fusion. + std::deque worklist; + +private: + /// Optional pattern set to apply when adding new operations to the worklist. + std::optional patterns = std::nullopt; +}; + +SliceTrackingListener::SliceTrackingListener( + std::optional p) { + patterns = std::move(p); +} + +LogicalResult +SliceTrackingListener::insertAndApplyPatterns(ArrayRef ops) { + for (Operation *op : ops) { + if (auto slice = dyn_cast(op)) + worklist.push_back(slice); + } + + if (!patterns) + return success(); + + GreedyRewriteConfig config; + config.listener = this; + config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; + return applyOpPatternsAndFold(ops, patterns.value(), config); +} + +void SliceTrackingListener::notifyOperationInserted( + Operation *op, OpBuilder::InsertPoint previous) { + auto slice = dyn_cast(op); + if (!slice) + return; + worklist.push_back(slice); +} + +// Scan the worklist for the given op and remove it if present. The expectation +// is for the worklist to be small and for removal to be relatively rare. +void SliceTrackingListener::removeOp(Operation *op) { + if (!isa(op)) + return; + auto iter = worklist.begin(); + while (iter != worklist.end()) { + if (*iter == op) + break; + iter++; + } + if (iter == worklist.end()) + return; + + worklist.erase(iter); +} + +void SliceTrackingListener::notifyOperationErased(Operation *op) { + removeOp(op); +} + +void SliceTrackingListener::notifyOperationReplaced(Operation *op, + ValueRange replacement) { + removeOp(op); +} +} // namespace + /// Implementation of tile consumer and fuse producer greedily. FailureOr mlir::scf::tileConsumerAndFuseProducersUsingSCF( @@ -1370,33 +1470,32 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( tensor::ExtractSliceOp candidateSlice; SCFTileAndFuseOptions::ControlFnResult controlFnResult; }; - std::deque worklist; - auto addCandidateSlices = [&worklist, &options, - &loops](ArrayRef candidates) { - for (auto candidate : candidates) { - auto sliceOp = dyn_cast(candidate); - if (!sliceOp || sliceOp.use_empty()) - continue; - auto [fusableProducer, destinationInitArg] = - getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops); - if (!fusableProducer) - continue; - std::optional controlFnResult = - options.fusionControlFn(sliceOp, fusableProducer, - destinationInitArg.has_value()); - if (!controlFnResult) - continue; - worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()}); - } - }; + SliceTrackingListener sliceTracker = + SliceTrackingListener(options.cleanupPatterns); - addCandidateSlices(tilingResult->generatedSlices); + if (failed( + sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) { + return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); + } OpBuilder::InsertionGuard g(rewriter); - while (!worklist.empty()) { - // Traverse the slices in BFS fashion. - WorklistItem worklistItem = worklist.front(); - worklist.pop_front(); + while (!sliceTracker.worklist.empty()) { + auto candidateSlice = sliceTracker.worklist.front(); + sliceTracker.worklist.pop_front(); + + auto [fusableProducer, destinationInitArg] = + getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(), + loops); + if (!fusableProducer) + continue; + + std::optional controlFnResult = + options.fusionControlFn(candidateSlice, fusableProducer, + destinationInitArg.has_value()); + if (!controlFnResult) + continue; + + WorklistItem worklistItem = {candidateSlice, controlFnResult.value()}; // The operands of the fused producer might themselved be slices of // values produced by operations that implement the `TilingInterface`. @@ -1407,6 +1506,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( if (!fusedResult) continue; + SmallVector worklistCandidates = fusedResult->generatedSlices; + if (worklistItem.controlFnResult.yieldProducerReplacement) { // Reconstruct and yield all opResult of fusableProducerOp by default. The // caller can specific which one to yield by designating optional argument @@ -1421,7 +1522,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( fusableProducerOp, "failed to replacement value for this " "operation from within the tiled loop"); } - addCandidateSlices(newSlices.value()); + worklistCandidates.append(newSlices.value()); for (auto [index, result] : llvm::enumerate(fusableProducerOp->getResults())) { origValToResultNumber[result] = loops.front()->getNumResults() - @@ -1429,12 +1530,15 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( index; } } - addCandidateSlices(fusedResult->generatedSlices); if (Operation *tiledAndFusedOp = fusedResult->tiledAndFusedProducer.getDefiningOp()) { fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); tiledAndFusedOps.insert(tiledAndFusedOp); } + + if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) { + return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); + } } DenseMap replacements; diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index 3a023deb1132f..ac1ca9319d335 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -178,3 +178,103 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: func.func @fuse_through_slice +func.func @fuse_through_slice(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.elemwise_unary + // CHECK: linalg.elemwise_binary + // CHECK: return %[[RES]] + %0 = linalg.elemwise_unary ins(%arg0 : tensor) + outs(%arg0: tensor) -> tensor + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg1, %c0 : tensor + %dim1 = tensor.dim %arg1, %c1 : tensor + %1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor to tensor + %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %2 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_through_slice_and_cast_chain +func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1: tensor) -> tensor { + + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.elemwise_unary + // CHECK: linalg.elemwise_binary + // CHECK: return %[[RES]] + %0 = linalg.elemwise_unary ins(%arg0 : tensor<100x100xf32>) + outs(%arg0: tensor<100x100xf32>) -> tensor<100x100xf32> + %1 = tensor.cast %0 : tensor<100x100xf32> to tensor<100x?xf32> + %2 = tensor.extract_slice %1 [1, 1] [98, 98] [1, 1] : tensor<100x?xf32> to tensor<98x98xf32> + %3 = tensor.cast %2 : tensor<98x98xf32> to tensor + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg1, %c0 : tensor + %dim1 = tensor.dim %arg1, %c1 : tensor + %4 = tensor.extract_slice %3 [1, 1] [%dim0, %dim1] [1, 1] : tensor to tensor + %5 = linalg.elemwise_binary ins(%4, %arg1 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %5 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_unrelated_slice +func.func @fuse_unrelated_slices(%arg0: tensor, %arg1: tensor) -> (tensor, tensor<10x10xf32>) { + + // CHECK: %[[SLICE1:.+]] = tensor.extract_slice + // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[SLICE1]] + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.elemwise_unary + // CHECK: linalg.elemwise_binary + // CHECK: return %[[RES]], %[[SLICE2]] + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg1, %c0 : tensor + %dim1 = tensor.dim %arg1, %c1 : tensor + %slice1 = tensor.extract_slice %arg0 [1, 1] [%dim0, %dim1] [1, 1] : tensor to tensor + %slice2 = tensor.extract_slice %slice1 [1, 1] [10, 10] [1, 1] : tensor to tensor<10x10xf32> + %0 = linalg.elemwise_unary ins(%arg0 : tensor) + outs(%arg0: tensor) -> tensor + %1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor to tensor + %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %2, %slice2 : tensor, tensor<10x10xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) + transform.yield + } +}