Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
Expand Down Expand Up @@ -194,6 +195,21 @@ struct SCFTileAndFuseOptions {
/// before fusion. This will track deleted and newly inserted
/// `tensor.extract_slice` ops and update the worklist.
std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;

/// A function to insert a tilable node into a list of nodes to be tiled.
/// This controls the order in which tiling and fusion happen.
using WorklistInsertFnTy = std::function<void(
tensor::ExtractSliceOp op, std::deque<tensor::ExtractSliceOp> &worklist)>;
/// By default, simply append the op at the end of the queue.
WorklistInsertFnTy worklistInsertFn =
[](tensor::ExtractSliceOp op,
std::deque<tensor::ExtractSliceOp> &worklist) {
worklist.push_back(op);
};
SCFTileAndFuseOptions &setWorklistInsertFn(WorklistInsertFnTy insertFn) {
worklistInsertFn = insertFn;
return *this;
}
};

/// Fuse the producer of the source of `candidateSliceOp` by computing the
Expand Down
19 changes: 14 additions & 5 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,8 @@ namespace {
class SliceTrackingListener : public RewriterBase::Listener {
public:
explicit SliceTrackingListener(
std::optional<FrozenRewritePatternSet> patterns);
std::optional<FrozenRewritePatternSet> patterns,
scf::SCFTileAndFuseOptions::WorklistInsertFnTy worklistInsertFn);
SliceTrackingListener() = default;

/// Adds the given list of operations to the worklist, and if present,
Expand Down Expand Up @@ -1421,18 +1422,22 @@ class SliceTrackingListener : public RewriterBase::Listener {
/// Optional pattern set to apply when adding new operations to the
/// worklist.
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
scf::SCFTileAndFuseOptions::WorklistInsertFnTy worklistInsertFn;
};

SliceTrackingListener::SliceTrackingListener(
std::optional<FrozenRewritePatternSet> p) {
std::optional<FrozenRewritePatternSet> p,
scf::SCFTileAndFuseOptions::WorklistInsertFnTy w) {
patterns = std::move(p);
worklistInsertFn = w;
}

/// Insert extract_slice ops into the worklist.
LogicalResult
SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
for (Operation *op : ops) {
if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
worklist.push_back(slice);
worklistInsertFn(slice, worklist);
}

if (!patterns)
Expand All @@ -1444,12 +1449,14 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
return applyOpPatternsGreedily(ops, patterns.value(), config);
}

/// Insert extract_slice ops created by cleanup patterns into the worklist.
/// Triggered from applyOpPatternsAndFold() above.
void SliceTrackingListener::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) {
auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
if (!slice)
return;
worklist.push_back(slice);
worklistInsertFn(slice, worklist);
}

// Scan the worklist for the given op and remove it if present. The
Expand Down Expand Up @@ -1580,7 +1587,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
};

SliceTrackingListener sliceTracker =
SliceTrackingListener(options.cleanupPatterns);
SliceTrackingListener(options.cleanupPatterns, options.worklistInsertFn);

if (failed(
sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
Expand All @@ -1596,6 +1603,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
loops);
if (!fusableProducer)
continue;
LLVM_DEBUG(llvm::dbgs() << "worklist: producer is "
<< *(fusableProducer.getOwner()) << "\n");

std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
options.fusionControlFn(candidateSlice, fusableProducer,
Expand Down
49 changes: 49 additions & 0 deletions mlir/test/Dialect/Linalg/tile-sort.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file -debug-only=tile-using-interface 2>&1 | FileCheck %s

func.func @tile_in_op_operand_order(%arg: tensor<256xf32>) -> tensor<256xf32> {
%empty = tensor.empty() : tensor<256xf32>
%0 = linalg.ceil ins(%arg: tensor<256xf32>) outs(%empty: tensor<256xf32>) -> tensor<256xf32>
%empty1 = tensor.empty() : tensor<256xf32>
%1 = linalg.negf ins(%0 : tensor<256xf32>) outs(%empty1: tensor<256xf32>) -> tensor<256xf32>
%empty2 = tensor.empty() : tensor<256xf32>
%2 = linalg.powf {tile} ins(%0, %1: tensor<256xf32>, tensor<256xf32>) outs(%empty2: tensor<256xf32>) -> tensor<256xf32>

// CHECK: worklist: producer is %{{.*}} = linalg.ceil
// CHECK: worklist: producer is %{{.*}} = linalg.negf
// CHECK: worklist: producer is %{{.*}} = linalg.ceil

return %2 : tensor<256xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.powf"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops = transform.test.fuse_and_yield %0 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

func.func @tile_in_reverse_op_operand_order(%arg: tensor<256xf32>) -> tensor<256xf32> {
%empty = tensor.empty() : tensor<256xf32>
%0 = linalg.ceil ins(%arg: tensor<256xf32>) outs(%empty: tensor<256xf32>) -> tensor<256xf32>
%empty1 = tensor.empty() : tensor<256xf32>
%1 = linalg.negf ins(%0 : tensor<256xf32>) outs(%empty1: tensor<256xf32>) -> tensor<256xf32>
%empty2 = tensor.empty() : tensor<256xf32>
%2 = linalg.powf {tile} ins(%0, %1: tensor<256xf32>, tensor<256xf32>) outs(%empty2: tensor<256xf32>) -> tensor<256xf32>

// CHECK: worklist: producer is %{{.*}} = linalg.negf
// CHECK: worklist: producer is %{{.*}} = linalg.ceil
// CHECK: worklist: producer is %{{.*}} = linalg.ceil

return %2 : tensor<256xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.powf"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops = transform.test.fuse_and_yield %0 [32] reverse_worklist true : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,11 @@ static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
/// Apply a tile and fuse transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
static LogicalResult
applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
Range &&payloadOps, unsigned numLoops,
ArrayRef<OpFoldResult> tileSizes,
ArrayRef<int64_t> interchange, bool useForall,
TransformResults &transformResults) {
static LogicalResult applyTileAndFuseToAll(
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
unsigned numLoops, ArrayRef<OpFoldResult> tileSizes,
ArrayRef<int64_t> interchange, bool useForall, bool reverseWorkList,
TransformResults &transformResults) {
SmallVector<Operation *> tiledOps;
SmallVector<SmallVector<Operation *>> loopOps(numLoops);

Expand Down Expand Up @@ -87,6 +86,13 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
}

scf::SCFTileAndFuseOptions tileAndFuseOptions;
if (reverseWorkList) {
tileAndFuseOptions.setWorklistInsertFn(
[](tensor::ExtractSliceOp op,
std::deque<tensor::ExtractSliceOp> &worklist) {
worklist.push_front(op);
});
}
tileAndFuseOptions.setTilingOptions(tilingOptions);

scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
Expand Down Expand Up @@ -157,7 +163,7 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
LogicalResult result = applyTileAndFuseToAll(
rewriter, getOperation(), state.getPayloadOps(getTarget()),
tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
tileInterchange, getUseForall(), transformResults);
tileInterchange, getUseForall(), getReverseWorklist(), transformResults);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
(ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
DefaultValuedAttr<BoolAttr, "false">:$use_forall);
DefaultValuedAttr<BoolAttr, "false">:$use_forall,
DefaultValuedAttr<BoolAttr, "false">:$reverse_worklist);
let results = (outs TransformHandleTypeInterface:$transfomed,
Variadic<TransformHandleTypeInterface>:$loops);

let assemblyFormat = [{
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
(`use_forall` $use_forall^)? attr-dict
(`use_forall` $use_forall^)? (`reverse_worklist` $reverse_worklist^)? attr-dict
`:` functional-type(operands, results)
}];
}
Expand Down