19 changes: 19 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
@@ -194,6 +195,24 @@ struct SCFTileAndFuseOptions {
/// before fusion. This will track deleted and newly inserted
/// `tensor.extract_slice` ops and update the worklist.
std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;

/// A function to insert a candidate `tensor.extract_slice` op into the
/// worklist of slices considered for fusion. This controls the order in
/// which tiling and fusion happen.
using WorklistInsertFnTy = std::function<void(
tensor::ExtractSliceOp op, std::deque<tensor::ExtractSliceOp> &worklist)>;
/// By default, simply append the op to the end of the worklist.
WorklistInsertFnTy worklistInsertFn =
[](tensor::ExtractSliceOp op,
std::deque<tensor::ExtractSliceOp> &worklist) {
worklist.push_back(op);
};
SCFTileAndFuseOptions &setWorklistInsertFn(WorklistInsertFnTy insertFn) {
worklistInsertFn = insertFn;
return *this;
}
/// Emit a remark with the order in which operations are tiled.
/// This is useful to debug the worklist insert function.
bool printTilingOrder = false;
};

/// Fuse the producer of the source of `candidateSliceOp` by computing the
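To illustrate how the new options above are meant to be used, here is a minimal, hypothetical caller-side sketch (not part of this patch); makeLifoOptions is an invented name, and only the API added in this header is assumed.

#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include <deque>

// Hypothetical sketch: fuse the most recently discovered slice first (LIFO)
// and emit "Fused op in position N" remarks so the order can be inspected.
static mlir::scf::SCFTileAndFuseOptions makeLifoOptions() {
  mlir::scf::SCFTileAndFuseOptions options;
  options.setWorklistInsertFn(
      [](mlir::tensor::ExtractSliceOp op,
         std::deque<mlir::tensor::ExtractSliceOp> &worklist) {
        // Push to the front instead of the default push_back.
        worklist.push_front(op);
      });
  options.printTilingOrder = true;
  return options;
}
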
36 changes: 30 additions & 6 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -31,6 +31,7 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>

#define DEBUG_TYPE "tile-using-interface"
@@ -1391,7 +1392,8 @@ namespace {
class SliceTrackingListener : public RewriterBase::Listener {
public:
explicit SliceTrackingListener(
std::optional<FrozenRewritePatternSet> patterns);
std::optional<FrozenRewritePatternSet> patterns,
scf::SCFTileAndFuseOptions::WorklistInsertFnTy worklistInsertFn);
SliceTrackingListener() = default;

/// Adds the given list of operations to the worklist, and if present,
@@ -1421,18 +1423,25 @@ class SliceTrackingListener : public RewriterBase::Listener {
/// Optional pattern set to apply when adding new operations to the
/// worklist.
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
scf::SCFTileAndFuseOptions::WorklistInsertFnTy worklistInsertFn;
};

SliceTrackingListener::SliceTrackingListener(
std::optional<FrozenRewritePatternSet> p) {
std::optional<FrozenRewritePatternSet> p,
scf::SCFTileAndFuseOptions::WorklistInsertFnTy w) {
patterns = std::move(p);
worklistInsertFn = w;
}

/// Insert extract_slice ops into the worklist.
LogicalResult
SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
for (Operation *op : ops) {
if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
worklist.push_back(slice);
if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op)) {
LLVM_DEBUG(llvm::dbgs()
<< "worklist: insertAndApplyPatterns of " << slice << "\n");
worklistInsertFn(slice, worklist);
}
}

if (!patterns)
@@ -1444,12 +1453,16 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
return applyOpPatternsGreedily(ops, patterns.value(), config);
}

/// Insert extract_slice ops created by cleanup patterns into the worklist.
/// Triggered from applyOpPatternsGreedily() above.
void SliceTrackingListener::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) {
auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
if (!slice)
return;
worklist.push_back(slice);
LLVM_DEBUG(llvm::dbgs() << "worklist: notifyOperationInserted of " << slice
<< "\n");
worklistInsertFn(slice, worklist);
}

// Scan the worklist for the given op and remove it if present. The
@@ -1580,22 +1593,27 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
};

SliceTrackingListener sliceTracker =
SliceTrackingListener(options.cleanupPatterns);
SliceTrackingListener(options.cleanupPatterns, options.worklistInsertFn);

if (failed(
sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
}
OpBuilder::InsertionGuard g(rewriter);

unsigned tilingOrder = 0;
while (!sliceTracker.worklist.empty()) {
auto candidateSlice = sliceTracker.worklist.front();
LLVM_DEBUG(llvm::dbgs() << "worklist: popping " << candidateSlice << "\n");
sliceTracker.worklist.pop_front();

auto [fusableProducer, destinationInitArg] =
getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
loops);
if (!fusableProducer)
continue;
LLVM_DEBUG(llvm::dbgs() << "worklist: producer is "
<< *(fusableProducer.getOwner()) << "\n");

std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
options.fusionControlFn(candidateSlice, fusableProducer,
@@ -1614,6 +1632,12 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
if (!fusedResult)
continue;

if (options.printTilingOrder) {
auto message = llvm::formatv("Fused op in position {}", tilingOrder);
fusableProducer.getOwner()->emitRemark(message);
}
tilingOrder++;

SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;

if (worklistItem.controlFnResult.yieldProducerReplacement) {
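As a further illustration of what the insertion hook allows beyond plain FIFO/LIFO order, a hypothetical helper (not part of this patch; the Linalg include and the function name are assumptions) could prioritize slices whose producer is a linalg op:

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include <deque>

// Hypothetical sketch: slices fed by linalg producers jump to the front of
// the worklist; all other slices keep the default FIFO order.
static void
prioritizeLinalgProducers(mlir::scf::SCFTileAndFuseOptions &options) {
  options.setWorklistInsertFn(
      [](mlir::tensor::ExtractSliceOp op,
         std::deque<mlir::tensor::ExtractSliceOp> &worklist) {
        mlir::Operation *producer = op.getSource().getDefiningOp();
        if (producer && mlir::isa<mlir::linalg::LinalgOp>(producer))
          worklist.push_front(op);
        else
          worklist.push_back(op);
      });
}
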
87 changes: 87 additions & 0 deletions mlir/test/Dialect/Linalg/tile-sort.mlir
@@ -0,0 +1,87 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics

func.func @tile_me(%arg: tensor<256xi32>) -> tensor<256xi32> {
%one = arith.constant 1 : i32
%empty = tensor.empty() : tensor<256xi32>
// expected-remark @below {{Fused op in position 0}}
// expected-remark @below {{Fused op in position 2}}
%0 = linalg.generic {indexing_maps=[affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types=["parallel"]}
ins(%arg: tensor<256xi32>) outs(%empty: tensor<256xi32>) {
^bb0(%arg0: i32, %arg1: i32):
%plusone = arith.addi %arg0, %one : i32
linalg.yield %plusone : i32
} -> tensor<256xi32>

%empty1 = tensor.empty() : tensor<256xi32>
%minustwo = arith.constant -2 : i32

// expected-remark @below {{Fused op in position 1}}
%1 = linalg.generic {indexing_maps=[affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types=["parallel"]}
ins(%0: tensor<256xi32>) outs(%empty: tensor<256xi32>) {
^bb0(%arg0: i32, %arg1: i32):
%opp = arith.muli %arg0, %minustwo : i32
linalg.yield %opp : i32
} -> tensor<256xi32>

%empty2 = tensor.empty() : tensor<256xi32>
%2 = linalg.generic {tile, indexing_maps=[affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types=["parallel"]}
ins(%0, %1: tensor<256xi32>, tensor<256xi32>) outs(%empty: tensor<256xi32>) {
^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
%sum = arith.addi %arg0, %arg1 : i32
linalg.yield %sum : i32
} -> tensor<256xi32>

return %2 : tensor<256xi32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops = transform.test.fuse_and_yield %0 [32] debug_worklist true : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

func.func @tile_me(%arg: tensor<256xi32>) -> tensor<256xi32> {
%one = arith.constant 1 : i32
%empty = tensor.empty() : tensor<256xi32>
// expected-remark @below {{Fused op in position 1}}
// expected-remark @below {{Fused op in position 2}}
%0 = linalg.generic {indexing_maps=[affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types=["parallel"]}
ins(%arg: tensor<256xi32>) outs(%empty: tensor<256xi32>) {
^bb0(%arg0: i32, %arg1: i32):
%plusone = arith.addi %arg0, %one : i32
linalg.yield %plusone : i32
} -> tensor<256xi32>

%empty1 = tensor.empty() : tensor<256xi32>
%minustwo = arith.constant -2 : i32

// expected-remark @below {{Fused op in position 0}}
%1 = linalg.generic {indexing_maps=[affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types=["parallel"]}
ins(%0: tensor<256xi32>) outs(%empty: tensor<256xi32>) {
^bb0(%arg0: i32, %arg1: i32):
%opp = arith.muli %arg0, %minustwo : i32
linalg.yield %opp : i32
} -> tensor<256xi32>

%empty2 = tensor.empty() : tensor<256xi32>
%2 = linalg.generic {tile, indexing_maps=[affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types=["parallel"]}
ins(%0, %1: tensor<256xi32>, tensor<256xi32>) outs(%empty: tensor<256xi32>) {
^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
%sum = arith.addi %arg0, %arg1 : i32
linalg.yield %sum : i32
} -> tensor<256xi32>

return %2 : tensor<256xi32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops = transform.test.fuse_and_yield %0 [32] debug_worklist true reverse_worklist true : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
@@ -54,12 +54,11 @@ static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
/// Apply a tile and fuse transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
static LogicalResult
applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
Range &&payloadOps, unsigned numLoops,
ArrayRef<OpFoldResult> tileSizes,
ArrayRef<int64_t> interchange, bool useForall,
TransformResults &transformResults) {
static LogicalResult applyTileAndFuseToAll(
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
unsigned numLoops, ArrayRef<OpFoldResult> tileSizes,
ArrayRef<int64_t> interchange, bool useForall, bool debugWorkList,
bool reverseWorkList, TransformResults &transformResults) {
SmallVector<Operation *> tiledOps;
SmallVector<SmallVector<Operation *>> loopOps(numLoops);

@@ -87,7 +86,17 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
}

scf::SCFTileAndFuseOptions tileAndFuseOptions;
if (reverseWorkList) {
tileAndFuseOptions.setWorklistInsertFn(
[](tensor::ExtractSliceOp op,
std::deque<tensor::ExtractSliceOp> &worklist) {
worklist.push_front(op);
});
}
tileAndFuseOptions.setTilingOptions(tilingOptions);
if (debugWorkList) {
tileAndFuseOptions.printTilingOrder = true;
}

scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
[&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
@@ -157,7 +166,8 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
LogicalResult result = applyTileAndFuseToAll(
rewriter, getOperation(), state.getPayloadOps(getTarget()),
tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
tileInterchange, getUseForall(), transformResults);
tileInterchange, getUseForall(), getDebugWorklist(), getReverseWorklist(),
transformResults);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}
@@ -38,13 +38,16 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
(ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
DefaultValuedAttr<BoolAttr, "false">:$use_forall);
DefaultValuedAttr<BoolAttr, "false">:$use_forall,
DefaultValuedAttr<BoolAttr, "false">:$debug_worklist,
DefaultValuedAttr<BoolAttr, "false">:$reverse_worklist);
let results = (outs TransformHandleTypeInterface:$transfomed,
Variadic<TransformHandleTypeInterface>:$loops);

let assemblyFormat = [{
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
(`use_forall` $use_forall^)? attr-dict
(`use_forall` $use_forall^)? (`debug_worklist` $debug_worklist^)?
(`reverse_worklist` $reverse_worklist^)? attr-dict
`:` functional-type(operands, results)
}];
}
Expand Down