Skip to content

Commit acc5603

Browse files
authored
[FXML-5890] Order tiling worklist (#532)
1 parent 065d0c0 commit acc5603

File tree

5 files changed

+210
-12
lines changed

5 files changed

+210
-12
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
1111

1212
#include "mlir/Dialect/SCF/IR/SCF.h"
13+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1314
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1415
#include "mlir/IR/PatternMatch.h"
1516
#include "mlir/Interfaces/LoopLikeInterface.h"
@@ -194,6 +195,21 @@ struct SCFTileAndFuseOptions {
194195
/// before fusion. This will track deleted and newly inserted
195196
/// `tensor.extract_slice` ops and update the worklist.
196197
std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
198+
199+
/// A function to insert a tilable node into a list of nodes to be tiled.
200+
/// This controls the order in which tiling and fusion happen.
201+
using WorklistInsertFnTy = std::function<void(
202+
tensor::ExtractSliceOp op, std::deque<tensor::ExtractSliceOp> &worklist)>;
203+
/// By default, simply append the op at the end of the queue.
204+
WorklistInsertFnTy worklistInsertFn =
205+
[](tensor::ExtractSliceOp op,
206+
std::deque<tensor::ExtractSliceOp> &worklist) {
207+
worklist.push_back(op);
208+
};
209+
SCFTileAndFuseOptions &setWorklistInsertFn(WorklistInsertFnTy insertFn) {
210+
worklistInsertFn = insertFn;
211+
return *this;
212+
}
197213
};
198214

199215
/// Fuse the producer of the source of `candidateSliceOp` by computing the

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,7 +1391,8 @@ namespace {
13911391
class SliceTrackingListener : public RewriterBase::Listener {
13921392
public:
13931393
explicit SliceTrackingListener(
1394-
std::optional<FrozenRewritePatternSet> patterns);
1394+
std::optional<FrozenRewritePatternSet> patterns,
1395+
scf::SCFTileAndFuseOptions::WorklistInsertFnTy worklistInsertFn);
13951396
SliceTrackingListener() = default;
13961397

13971398
/// Adds the given list of operations to the worklist, and if present,
@@ -1421,18 +1422,22 @@ class SliceTrackingListener : public RewriterBase::Listener {
14211422
/// Optional pattern set to apply when adding new operations to the
14221423
/// worklist.
14231424
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1425+
scf::SCFTileAndFuseOptions::WorklistInsertFnTy worklistInsertFn;
14241426
};
14251427

14261428
SliceTrackingListener::SliceTrackingListener(
1427-
std::optional<FrozenRewritePatternSet> p) {
1429+
std::optional<FrozenRewritePatternSet> p,
1430+
scf::SCFTileAndFuseOptions::WorklistInsertFnTy w) {
14281431
patterns = std::move(p);
1432+
worklistInsertFn = w;
14291433
}
14301434

1435+
/// Insert extract_slice ops into the worklist.
14311436
LogicalResult
14321437
SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
14331438
for (Operation *op : ops) {
14341439
if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1435-
worklist.push_back(slice);
1440+
worklistInsertFn(slice, worklist);
14361441
}
14371442

14381443
if (!patterns)
@@ -1444,12 +1449,14 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
14441449
return applyOpPatternsGreedily(ops, patterns.value(), config);
14451450
}
14461451

1452+
/// Insert extract_slice ops created by cleanup patterns into the worklist.
1453+
/// Triggered from applyOpPatternsAndFold() above.
14471454
void SliceTrackingListener::notifyOperationInserted(
14481455
Operation *op, OpBuilder::InsertPoint previous) {
14491456
auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
14501457
if (!slice)
14511458
return;
1452-
worklist.push_back(slice);
1459+
worklistInsertFn(slice, worklist);
14531460
}
14541461

14551462
// Scan the worklist for the given op and remove it if present. The
@@ -1580,7 +1587,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
15801587
};
15811588

15821589
SliceTrackingListener sliceTracker =
1583-
SliceTrackingListener(options.cleanupPatterns);
1590+
SliceTrackingListener(options.cleanupPatterns, options.worklistInsertFn);
15841591

15851592
if (failed(
15861593
sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
@@ -1596,6 +1603,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
15961603
loops);
15971604
if (!fusableProducer)
15981605
continue;
1606+
LLVM_DEBUG(llvm::dbgs() << "worklist: producer is "
1607+
<< *(fusableProducer.getOwner()) << "\n");
15991608

16001609
std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
16011610
options.fusionControlFn(candidateSlice, fusableProducer,
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// RUN: mlir-opt %s -transform-interpreter -split-input-file -debug-only=tile-using-interface 2>&1 | FileCheck %s
2+
3+
func.func @tile_order_ceil_then_negf(%arg: tensor<256xf32>) -> tensor<256xf32> {
4+
// Ops are tiled by lower priority: linalg.powf, linalg.ceil (1st operand of powf, priority = 0),
5+
// linalg.negf (2nd operand of powf, priority = 1), linalg.ceil (operand of negf, priority = 0)
6+
%empty = tensor.empty() : tensor<256xf32>
7+
%0 = linalg.ceil {tiling_priority = 0 : i64} ins(%arg: tensor<256xf32>) outs(%empty: tensor<256xf32>) -> tensor<256xf32>
8+
%empty1 = tensor.empty() : tensor<256xf32>
9+
%1 = linalg.negf {tiling_priority = 1 : i64} ins(%0 : tensor<256xf32>) outs(%empty1: tensor<256xf32>) -> tensor<256xf32>
10+
%empty2 = tensor.empty() : tensor<256xf32>
11+
%2 = linalg.powf {tile} ins(%0, %1: tensor<256xf32>, tensor<256xf32>) outs(%empty2: tensor<256xf32>) -> tensor<256xf32>
12+
13+
// The order of these checks is the order in which the ops are actually tiled.
14+
// CHECK: worklist: producer is %{{.*}} = linalg.ceil
15+
// CHECK: worklist: producer is %{{.*}} = linalg.negf
16+
// CHECK: worklist: producer is %{{.*}} = linalg.ceil
17+
18+
return %2 : tensor<256xf32>
19+
}
20+
21+
module attributes {transform.with_named_sequence} {
22+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
23+
%0 = transform.structured.match ops{["linalg.powf"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op
24+
%1, %loops = transform.test.tile_fuse_ordered %0 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
25+
transform.yield
26+
}
27+
}
28+
29+
// -----
30+
31+
func.func @tile_negf_then_ceil(%arg: tensor<256xf32>) -> tensor<256xf32> {
32+
// Ops are tiled by lower priority: linalg.powf, linalg.negf (2nd operand of powf, priority = 0),
33+
// linalg.ceil (1st oprand of powf, priority = 1), linalg.ceil (operand of negf, priority = 1)
34+
%empty = tensor.empty() : tensor<256xf32>
35+
%0 = linalg.ceil {tiling_priority = 1 : i64} ins(%arg: tensor<256xf32>) outs(%empty: tensor<256xf32>) -> tensor<256xf32>
36+
%empty1 = tensor.empty() : tensor<256xf32>
37+
%1 = linalg.negf {tiling_priority = 0 : i64} ins(%0 : tensor<256xf32>) outs(%empty1: tensor<256xf32>) -> tensor<256xf32>
38+
%empty2 = tensor.empty() : tensor<256xf32>
39+
%2 = linalg.powf {tile} ins(%0, %1: tensor<256xf32>, tensor<256xf32>) outs(%empty2: tensor<256xf32>) -> tensor<256xf32>
40+
41+
// CHECK: worklist: producer is %{{.*}} = linalg.negf
42+
// CHECK: worklist: producer is %{{.*}} = linalg.ceil
43+
// CHECK: worklist: producer is %{{.*}} = linalg.ceil
44+
45+
return %2 : tensor<256xf32>
46+
}
47+
48+
module attributes {transform.with_named_sequence} {
49+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
50+
%0 = transform.structured.match ops{["linalg.powf"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op
51+
%1, %loops = transform.test.tile_fuse_ordered %0 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
52+
transform.yield
53+
}
54+
}
55+
56+
// -----
57+
58+
func.func @tile_negf_then_ceil_swap_in_powf(%arg: tensor<256xf32>) -> tensor<256xf32> {
59+
// This gives the same tiling order as above regardless of the operand order in the powf
60+
// linalg.powf, linalg.negf (1st operand of powf, priority = 0),
61+
// linalg.ceil (2nd oprand of powf, priority = 1), linalg.ceil (operand of negf, priority = 1)
62+
%empty = tensor.empty() : tensor<256xf32>
63+
%0 = linalg.ceil {tiling_priority = 1 : i64} ins(%arg: tensor<256xf32>) outs(%empty: tensor<256xf32>) -> tensor<256xf32>
64+
%empty1 = tensor.empty() : tensor<256xf32>
65+
%1 = linalg.negf {tiling_priority = 0 : i64} ins(%0 : tensor<256xf32>) outs(%empty1: tensor<256xf32>) -> tensor<256xf32>
66+
%empty2 = tensor.empty() : tensor<256xf32>
67+
%2 = linalg.powf {tile} ins(%1, %0: tensor<256xf32>, tensor<256xf32>) outs(%empty2: tensor<256xf32>) -> tensor<256xf32>
68+
69+
// CHECK: worklist: producer is %{{.*}} = linalg.negf
70+
// CHECK: worklist: producer is %{{.*}} = linalg.ceil
71+
// CHECK: worklist: producer is %{{.*}} = linalg.ceil
72+
73+
return %2 : tensor<256xf32>
74+
}
75+
76+
module attributes {transform.with_named_sequence} {
77+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
78+
%0 = transform.structured.match ops{["linalg.powf"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op
79+
%1, %loops = transform.test.tile_fuse_ordered %0 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
80+
transform.yield
81+
}
82+
}
83+

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,18 @@
1414
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1515
#include "mlir/Dialect/Index/IR/IndexDialect.h"
1616
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
17+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1718
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
1819
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1920
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
2021
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2122
#include "mlir/IR/Dominance.h"
23+
#include "mlir/IR/OpDefinition.h"
2224
#include "mlir/IR/OpImplementation.h"
2325
#include "mlir/Interfaces/TilingInterface.h"
26+
#include <deque>
27+
#include <functional>
28+
#include <string>
2429

2530
#define GET_OP_CLASSES
2631
#include "TestTilingInterfaceTransformOps.h.inc"
@@ -54,12 +59,13 @@ static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
5459
/// Apply a tile and fuse transformation to all payload ops and store both the
5560
/// tiled operation as well as the created tile loops.
5661
template <typename Range>
57-
static LogicalResult
58-
applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
59-
Range &&payloadOps, unsigned numLoops,
60-
ArrayRef<OpFoldResult> tileSizes,
61-
ArrayRef<int64_t> interchange, bool useForall,
62-
TransformResults &transformResults) {
62+
static LogicalResult applyTileAndFuseToAll(
63+
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
64+
unsigned numLoops, ArrayRef<OpFoldResult> tileSizes,
65+
ArrayRef<int64_t> interchange, bool useForall,
66+
TransformResults &transformResults,
67+
std::optional<scf::SCFTileAndFuseOptions::WorklistInsertFnTy>
68+
insertIntoWorklist) {
6369
SmallVector<Operation *> tiledOps;
6470
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
6571

@@ -87,6 +93,9 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
8793
}
8894

8995
scf::SCFTileAndFuseOptions tileAndFuseOptions;
96+
if (insertIntoWorklist.has_value()) {
97+
tileAndFuseOptions.setWorklistInsertFn(*insertIntoWorklist);
98+
}
9099
tileAndFuseOptions.setTilingOptions(tilingOptions);
91100

92101
scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
@@ -157,7 +166,65 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
157166
LogicalResult result = applyTileAndFuseToAll(
158167
rewriter, getOperation(), state.getPayloadOps(getTarget()),
159168
tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
160-
tileInterchange, getUseForall(), transformResults);
169+
tileInterchange, getUseForall(), transformResults, {});
170+
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
171+
: DiagnosedSilenceableFailure::success();
172+
}
173+
174+
//===----------------------------------------------------------------------===//
175+
// TestFuseOrderedOp
176+
//===----------------------------------------------------------------------===//
177+
178+
static std::optional<int64_t>
179+
getProducerTilingPriority(tensor::ExtractSliceOp op) {
180+
auto *producer = op.getSource().getDefiningOp();
181+
if (!producer)
182+
return {};
183+
184+
if (!producer->hasAttrOfType<IntegerAttr>("tiling_priority"))
185+
return {};
186+
187+
auto attr = producer->getAttrOfType<IntegerAttr>("tiling_priority");
188+
return attr.getInt();
189+
}
190+
191+
static void
192+
insertIntoWorklistOrdered(tensor::ExtractSliceOp op,
193+
std::deque<tensor::ExtractSliceOp> &worklist) {
194+
std::optional<int64_t> opTilingOrder = getProducerTilingPriority(op);
195+
if (!opTilingOrder) {
196+
worklist.push_back(op);
197+
return;
198+
}
199+
200+
auto iterator = worklist.begin();
201+
for (; iterator != worklist.end(); ++iterator) {
202+
std::optional<int64_t> otherOpTilingOrder =
203+
getProducerTilingPriority(*iterator);
204+
if (!otherOpTilingOrder || *otherOpTilingOrder > *opTilingOrder)
205+
break;
206+
}
207+
worklist.insert(iterator, op);
208+
}
209+
210+
DiagnosedSilenceableFailure
211+
transform::TestFuseOrderedOp::apply(TransformRewriter &rewriter,
212+
TransformResults &transformResults,
213+
TransformState &state) {
214+
SmallVector<int64_t> tileSizes =
215+
extractFromIntegerArrayAttr<int64_t>(getTileSizes());
216+
SmallVector<int64_t> tileInterchange;
217+
for (size_t i = 0; i < tileSizes.size(); ++i) {
218+
tileInterchange.push_back(i);
219+
}
220+
221+
SmallVector<OpFoldResult> tileSizesOfr =
222+
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
223+
224+
LogicalResult result = applyTileAndFuseToAll(
225+
rewriter, getOperation(), state.getPayloadOps(getTarget()),
226+
tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr, {}, false,
227+
transformResults, insertIntoWorklistOrdered);
161228
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
162229
: DiagnosedSilenceableFailure::success();
163230
}

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,29 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
4949
}];
5050
}
5151

52+
def TestFuseOrderedOp : Op<Transform_Dialect, "test.tile_fuse_ordered",
53+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
54+
DeclareOpInterfaceMethods<TransformOpInterface>,
55+
ReportTrackingListenerFailuresOpTrait]> {
56+
let description = [{
57+
Applies tiling and fusion to the operations pointed to by the target handle,
58+
following the order given by each operation's tiling_priority attribute.
59+
60+
On success returns the tiled operations as well as generated loops. Emits
61+
a definite failure if tiling fails.
62+
}];
63+
64+
let arguments =
65+
(ins TransformHandleTypeInterface:$target,
66+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
67+
let results = (outs TransformHandleTypeInterface:$transfomed,
68+
Variadic<TransformHandleTypeInterface>:$loops);
69+
70+
let assemblyFormat = [{
71+
$target ($tile_sizes^)? attr-dict `:` functional-type(operands, results)
72+
}];
73+
}
74+
5275
def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
5376
[DeclareOpInterfaceMethods<TransformOpInterface>,
5477
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

0 commit comments

Comments
 (0)