diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index d2cddfe00ac78..c54eb30842a29 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/LoopLikeInterface.h" @@ -194,6 +195,21 @@ struct SCFTileAndFuseOptions { /// before fusion. This will track deleted and newly inserted /// `tensor.extract_slice` ops and update the worklist. std::optional cleanupPatterns = std::nullopt; + + /// A function to insert a tilable node into a list of nodes to be tiled. + /// This controls the order in which tiling and fusion happen. + using WorklistInsertFnTy = std::function &worklist)>; + /// By default, simply append the op at the end of the queue. 
+ WorklistInsertFnTy worklistInsertFn = + [](tensor::ExtractSliceOp op, + std::deque &worklist) { + worklist.push_back(op); + }; + SCFTileAndFuseOptions &setWorklistInsertFn(WorklistInsertFnTy insertFn) { + worklistInsertFn = insertFn; + return *this; + } }; /// Fuse the producer of the source of `candidateSliceOp` by computing the diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 6419ab9627925..3d25cc4e831f2 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1391,7 +1391,8 @@ namespace { class SliceTrackingListener : public RewriterBase::Listener { public: explicit SliceTrackingListener( - std::optional patterns); + std::optional patterns, + scf::SCFTileAndFuseOptions::WorklistInsertFnTy worklistInsertFn); SliceTrackingListener() = default; /// Adds the given list of operations to the worklist, and if present, @@ -1421,18 +1422,22 @@ class SliceTrackingListener : public RewriterBase::Listener { /// Optional pattern set to apply when adding new operations to the /// worklist. std::optional patterns = std::nullopt; + scf::SCFTileAndFuseOptions::WorklistInsertFnTy worklistInsertFn; }; SliceTrackingListener::SliceTrackingListener( - std::optional p) { + std::optional p, + scf::SCFTileAndFuseOptions::WorklistInsertFnTy w) { patterns = std::move(p); + worklistInsertFn = w; } +/// Insert extract_slice ops into the worklist. LogicalResult SliceTrackingListener::insertAndApplyPatterns(ArrayRef ops) { for (Operation *op : ops) { if (auto slice = dyn_cast(op)) - worklist.push_back(slice); + worklistInsertFn(slice, worklist); } if (!patterns) @@ -1444,12 +1449,14 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef ops) { return applyOpPatternsGreedily(ops, patterns.value(), config); } +/// Insert extract_slice ops created by cleanup patterns into the worklist. 
+/// Triggered from applyOpPatternsGreedily() above. void SliceTrackingListener::notifyOperationInserted( Operation *op, OpBuilder::InsertPoint previous) { auto slice = dyn_cast(op); if (!slice) return; - worklist.push_back(slice); + worklistInsertFn(slice, worklist); } // Scan the worklist for the given op and remove it if present. The @@ -1580,7 +1587,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( }; SliceTrackingListener sliceTracker = - SliceTrackingListener(options.cleanupPatterns); + SliceTrackingListener(options.cleanupPatterns, options.worklistInsertFn); if (failed( sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) { @@ -1596,6 +1603,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( loops); if (!fusableProducer) continue; + LLVM_DEBUG(llvm::dbgs() << "worklist: producer is " + << *(fusableProducer.getOwner()) << "\n"); std::optional controlFnResult = options.fusionControlFn(candidateSlice, fusableProducer, diff --git a/mlir/test/Dialect/Linalg/tile-sort.mlir b/mlir/test/Dialect/Linalg/tile-sort.mlir new file mode 100644 index 0000000000000..1e7e3f5b78d16 --- /dev/null +++ b/mlir/test/Dialect/Linalg/tile-sort.mlir @@ -0,0 +1,83 @@ +// RUN: mlir-opt %s -transform-interpreter -split-input-file -debug-only=tile-using-interface 2>&1 | FileCheck %s + +func.func @tile_order_ceil_then_negf(%arg: tensor<256xf32>) -> tensor<256xf32> { + // Ops are tiled by lower priority: linalg.powf, linalg.ceil (1st operand of powf, priority = 0), + // linalg.negf (2nd operand of powf, priority = 1), linalg.ceil (operand of negf, priority = 0) + %empty = tensor.empty() : tensor<256xf32> + %0 = linalg.ceil {tiling_priority = 0 : i64} ins(%arg: tensor<256xf32>) outs(%empty: tensor<256xf32>) -> tensor<256xf32> + %empty1 = tensor.empty() : tensor<256xf32> + %1 = linalg.negf {tiling_priority = 1 : i64} ins(%0 : tensor<256xf32>) outs(%empty1: tensor<256xf32>) -> tensor<256xf32> + %empty2 = tensor.empty() : tensor<256xf32> + %2 = linalg.powf {tile} ins(%0, %1: 
tensor<256xf32>, tensor<256xf32>) outs(%empty2: tensor<256xf32>) -> tensor<256xf32> + + // The order of these checks is the order in which the ops are actually tiled. + // CHECK: worklist: producer is %{{.*}} = linalg.ceil + // CHECK: worklist: producer is %{{.*}} = linalg.negf + // CHECK: worklist: producer is %{{.*}} = linalg.ceil + + return %2 : tensor<256xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.powf"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops = transform.test.tile_fuse_ordered %0 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func @tile_negf_then_ceil(%arg: tensor<256xf32>) -> tensor<256xf32> { + // Ops are tiled by lower priority: linalg.powf, linalg.negf (2nd operand of powf, priority = 0), + // linalg.ceil (1st operand of powf, priority = 1), linalg.ceil (operand of negf, priority = 1) + %empty = tensor.empty() : tensor<256xf32> + %0 = linalg.ceil {tiling_priority = 1 : i64} ins(%arg: tensor<256xf32>) outs(%empty: tensor<256xf32>) -> tensor<256xf32> + %empty1 = tensor.empty() : tensor<256xf32> + %1 = linalg.negf {tiling_priority = 0 : i64} ins(%0 : tensor<256xf32>) outs(%empty1: tensor<256xf32>) -> tensor<256xf32> + %empty2 = tensor.empty() : tensor<256xf32> + %2 = linalg.powf {tile} ins(%0, %1: tensor<256xf32>, tensor<256xf32>) outs(%empty2: tensor<256xf32>) -> tensor<256xf32> + + // CHECK: worklist: producer is %{{.*}} = linalg.negf + // CHECK: worklist: producer is %{{.*}} = linalg.ceil + // CHECK: worklist: producer is %{{.*}} = linalg.ceil + + return %2 : tensor<256xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.powf"]} attributes 
{"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops = transform.test.tile_fuse_ordered %0 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func @tile_negf_then_ceil_swap_in_powf(%arg: tensor<256xf32>) -> tensor<256xf32> { + // This gives the same tiling order as above regardless of the operand order in the powf + // linalg.powf, linalg.negf (1st operand of powf, priority = 0), + // linalg.ceil (2nd operand of powf, priority = 1), linalg.ceil (operand of negf, priority = 1) + %empty = tensor.empty() : tensor<256xf32> + %0 = linalg.ceil {tiling_priority = 1 : i64} ins(%arg: tensor<256xf32>) outs(%empty: tensor<256xf32>) -> tensor<256xf32> + %empty1 = tensor.empty() : tensor<256xf32> + %1 = linalg.negf {tiling_priority = 0 : i64} ins(%0 : tensor<256xf32>) outs(%empty1: tensor<256xf32>) -> tensor<256xf32> + %empty2 = tensor.empty() : tensor<256xf32> + %2 = linalg.powf {tile} ins(%1, %0: tensor<256xf32>, tensor<256xf32>) outs(%empty2: tensor<256xf32>) -> tensor<256xf32> + + // CHECK: worklist: producer is %{{.*}} = linalg.negf + // CHECK: worklist: producer is %{{.*}} = linalg.ceil + // CHECK: worklist: producer is %{{.*}} = linalg.ceil + + return %2 : tensor<256xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.powf"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops = transform.test.tile_fuse_ordered %0 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 7380b766935ff..9a60b6e657613 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ 
b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -14,13 +14,18 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/TilingInterface.h" +#include +#include +#include #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.h.inc" @@ -54,12 +59,13 @@ static llvm::SmallDenseSet collectTiledAndFusedOps(Operation *op) { /// Apply a tile and fuse transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template -static LogicalResult -applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, - Range &&payloadOps, unsigned numLoops, - ArrayRef tileSizes, - ArrayRef interchange, bool useForall, - TransformResults &transformResults) { +static LogicalResult applyTileAndFuseToAll( + RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, + unsigned numLoops, ArrayRef tileSizes, + ArrayRef interchange, bool useForall, + TransformResults &transformResults, + std::optional + insertIntoWorklist) { SmallVector tiledOps; SmallVector> loopOps(numLoops); @@ -87,6 +93,9 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, } scf::SCFTileAndFuseOptions tileAndFuseOptions; + if (insertIntoWorklist.has_value()) { + tileAndFuseOptions.setWorklistInsertFn(*insertIntoWorklist); + } tileAndFuseOptions.setTilingOptions(tilingOptions); scf::SCFTileAndFuseOptions::ControlFnTy controlFn = @@ -157,7 +166,65 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter 
&rewriter, LogicalResult result = applyTileAndFuseToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr, - tileInterchange, getUseForall(), transformResults); + tileInterchange, getUseForall(), transformResults, {}); + return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() + : DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// TestFuseOrderedOp +//===----------------------------------------------------------------------===// + +static std::optional +getProducerTilingPriority(tensor::ExtractSliceOp op) { + auto *producer = op.getSource().getDefiningOp(); + if (!producer) + return {}; + + if (!producer->hasAttrOfType("tiling_priority")) + return {}; + + auto attr = producer->getAttrOfType("tiling_priority"); + return attr.getInt(); +} + +static void +insertIntoWorklistOrdered(tensor::ExtractSliceOp op, + std::deque &worklist) { + std::optional opTilingOrder = getProducerTilingPriority(op); + if (!opTilingOrder) { + worklist.push_back(op); + return; + } + + auto iterator = worklist.begin(); + for (; iterator != worklist.end(); ++iterator) { + std::optional otherOpTilingOrder = + getProducerTilingPriority(*iterator); + if (!otherOpTilingOrder || *otherOpTilingOrder > *opTilingOrder) + break; + } + worklist.insert(iterator, op); +} + +DiagnosedSilenceableFailure +transform::TestFuseOrderedOp::apply(TransformRewriter &rewriter, + TransformResults &transformResults, + TransformState &state) { + SmallVector tileSizes = + extractFromIntegerArrayAttr(getTileSizes()); + SmallVector tileInterchange; + for (size_t i = 0; i < tileSizes.size(); ++i) { + tileInterchange.push_back(i); + } + + SmallVector tileSizesOfr = + getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); + + LogicalResult result = applyTileAndFuseToAll( + rewriter, getOperation(), state.getPayloadOps(getTarget()), + tileSizes.size() - 
llvm::count(tileSizes, 0), tileSizesOfr, {}, false, + transformResults, insertIntoWorklistOrdered); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index 34b075a5c17f9..acbd9cfdea474 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -49,6 +49,29 @@ def TestFuseAndYieldOp : Op, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Applies tiling and fusion to the operations pointed to by the target handle, + following the order given by each operation's tiling_priority attribute. + + On success returns the tiled operations as well as generated loops. Emits + a definite failure if tiling fails. + }]; + + let arguments = + (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$tile_sizes); + let results = (outs TransformHandleTypeInterface:$transfomed, + Variadic:$loops); + + let assemblyFormat = [{ + $target ($tile_sizes^)? attr-dict `:` functional-type(operands, results) + }]; +} + def TestFuseConsumerOp : Op, DeclareOpInterfaceMethods,