Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
Expand Down Expand Up @@ -194,6 +195,21 @@ struct SCFTileAndFuseOptions {
/// before fusion. This will track deleted and newly inserted
/// `tensor.extract_slice` ops and update the worklist.
std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;

/// Signature of the callback that places a candidate `tensor.extract_slice`
/// op into the queue of slices still to be processed. The position chosen by
/// the callback determines the order in which tiling and fusion happen.
using WorklistInsertFnTy = std::function<void(
    tensor::ExtractSliceOp op, std::deque<tensor::ExtractSliceOp> &worklist)>;
/// Insertion policy used when none is set explicitly: the slice is appended
/// at the back of the queue.
WorklistInsertFnTy worklistInsertFn =
    [](tensor::ExtractSliceOp slice,
       std::deque<tensor::ExtractSliceOp> &queue) { queue.push_back(slice); };
/// Replaces the worklist insertion policy with `insertFn` and returns this
/// options object so that setters can be chained.
SCFTileAndFuseOptions &setWorklistInsertFn(WorklistInsertFnTy insertFn) {
  this->worklistInsertFn = insertFn;
  return *this;
}
};

/// Fuse the producer of the source of `candidateSliceOp` by computing the
Expand Down
19 changes: 14 additions & 5 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,8 @@ namespace {
class SliceTrackingListener : public RewriterBase::Listener {
public:
explicit SliceTrackingListener(
std::optional<FrozenRewritePatternSet> patterns);
std::optional<FrozenRewritePatternSet> patterns,
scf::SCFTileAndFuseOptions::WorklistInsertFnTy worklistInsertFn);
SliceTrackingListener() = default;

/// Adds the given list of operations to the worklist, and if present,
Expand Down Expand Up @@ -1421,18 +1422,22 @@ class SliceTrackingListener : public RewriterBase::Listener {
/// Optional pattern set to apply when adding new operations to the
/// worklist.
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
scf::SCFTileAndFuseOptions::WorklistInsertFnTy worklistInsertFn;
};

/// Builds a listener that applies the optional cleanup patterns `p` and
/// queues every tracked `tensor.extract_slice` through the insertion
/// callback `w`. Both arguments are taken by value and moved into the
/// members, so no extra copy of the pattern set or the callback is made.
SliceTrackingListener::SliceTrackingListener(
    std::optional<FrozenRewritePatternSet> p,
    scf::SCFTileAndFuseOptions::WorklistInsertFnTy w)
    : patterns(std::move(p)), worklistInsertFn(std::move(w)) {}

/// Insert extract_slice ops into the worklist.
LogicalResult
SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
for (Operation *op : ops) {
if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
worklist.push_back(slice);
worklistInsertFn(slice, worklist);
}

if (!patterns)
Expand All @@ -1444,12 +1449,14 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
return applyOpPatternsGreedily(ops, patterns.value(), config);
}

/// Insert extract_slice ops created by cleanup patterns into the worklist.
/// Triggered while applyOpPatternsGreedily() runs in insertAndApplyPatterns().
/// Ops of any other kind are ignored; `previous` is not used.
void SliceTrackingListener::notifyOperationInserted(
    Operation *op, OpBuilder::InsertPoint previous) {
  auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
  if (!slice)
    return;
  // Queue the slice exactly once, through the configured insertion callback,
  // so that a custom ordering is respected. A direct push_back in addition to
  // this call would enqueue the slice twice.
  // NOTE(review): assumes worklistInsertFn is non-empty; a default-constructed
  // listener leaves it unset - confirm that constructor is never used here.
  worklistInsertFn(slice, worklist);
}

// Scan the worklist for the given op and remove it if present. The
Expand Down Expand Up @@ -1580,7 +1587,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
};

SliceTrackingListener sliceTracker =
SliceTrackingListener(options.cleanupPatterns);
SliceTrackingListener(options.cleanupPatterns, options.worklistInsertFn);

if (failed(
sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
Expand All @@ -1596,6 +1603,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
loops);
if (!fusableProducer)
continue;
LLVM_DEBUG(llvm::dbgs() << "worklist: producer is "
<< *(fusableProducer.getOwner()) << "\n");

std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
options.fusionControlFn(candidateSlice, fusableProducer,
Expand Down
83 changes: 83 additions & 0 deletions mlir/test/Dialect/Linalg/tile-sort.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file -debug-only=tile-using-interface 2>&1 | FileCheck %s

func.func @tile_order_ceil_then_negf(%arg: tensor<256xf32>) -> tensor<256xf32> {
// Producers are fused in increasing `tiling_priority` order. Starting from
// the tiled linalg.powf, the expected sequence is:
//   1. linalg.ceil (1st operand of powf, priority = 0),
//   2. linalg.negf (2nd operand of powf, priority = 1),
//   3. linalg.ceil (operand of negf, priority = 0).
%empty = tensor.empty() : tensor<256xf32>
%0 = linalg.ceil {tiling_priority = 0 : i64} ins(%arg: tensor<256xf32>) outs(%empty: tensor<256xf32>) -> tensor<256xf32>
%empty1 = tensor.empty() : tensor<256xf32>
%1 = linalg.negf {tiling_priority = 1 : i64} ins(%0 : tensor<256xf32>) outs(%empty1: tensor<256xf32>) -> tensor<256xf32>
%empty2 = tensor.empty() : tensor<256xf32>
%2 = linalg.powf {tile} ins(%0, %1: tensor<256xf32>, tensor<256xf32>) outs(%empty2: tensor<256xf32>) -> tensor<256xf32>

// The order of these checks is the order in which the ops are actually tiled,
// as reported by the "worklist: producer is" debug output.
// CHECK: worklist: producer is %{{.*}} = linalg.ceil
// CHECK: worklist: producer is %{{.*}} = linalg.negf
// CHECK: worklist: producer is %{{.*}} = linalg.ceil

return %2 : tensor<256xf32>
}

// Transform script: match the linalg.powf op carrying the `tile` attribute
// and tile-and-fuse it with tile size 32 via test.tile_fuse_ordered.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.powf"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops = transform.test.tile_fuse_ordered %0 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

func.func @tile_negf_then_ceil(%arg: tensor<256xf32>) -> tensor<256xf32> {
// Producers are fused in increasing `tiling_priority` order. Starting from
// the tiled linalg.powf, the expected sequence is:
//   1. linalg.negf (2nd operand of powf, priority = 0),
//   2. linalg.ceil (1st operand of powf, priority = 1),
//   3. linalg.ceil (operand of negf, priority = 1).
%empty = tensor.empty() : tensor<256xf32>
%0 = linalg.ceil {tiling_priority = 1 : i64} ins(%arg: tensor<256xf32>) outs(%empty: tensor<256xf32>) -> tensor<256xf32>
%empty1 = tensor.empty() : tensor<256xf32>
%1 = linalg.negf {tiling_priority = 0 : i64} ins(%0 : tensor<256xf32>) outs(%empty1: tensor<256xf32>) -> tensor<256xf32>
%empty2 = tensor.empty() : tensor<256xf32>
%2 = linalg.powf {tile} ins(%0, %1: tensor<256xf32>, tensor<256xf32>) outs(%empty2: tensor<256xf32>) -> tensor<256xf32>

// CHECK: worklist: producer is %{{.*}} = linalg.negf
// CHECK: worklist: producer is %{{.*}} = linalg.ceil
// CHECK: worklist: producer is %{{.*}} = linalg.ceil

return %2 : tensor<256xf32>
}

// Transform script: match the linalg.powf op carrying the `tile` attribute
// and tile-and-fuse it with tile size 32 via test.tile_fuse_ordered.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.powf"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops = transform.test.tile_fuse_ordered %0 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

func.func @tile_negf_then_ceil_swap_in_powf(%arg: tensor<256xf32>) -> tensor<256xf32> {
// Same priorities as the previous test, but with the powf operands swapped.
// The tiling order is expected to be unchanged, i.e. independent of the
// operand order in the powf:
//   1. linalg.negf (1st operand of powf, priority = 0),
//   2. linalg.ceil (2nd operand of powf, priority = 1),
//   3. linalg.ceil (operand of negf, priority = 1).
%empty = tensor.empty() : tensor<256xf32>
%0 = linalg.ceil {tiling_priority = 1 : i64} ins(%arg: tensor<256xf32>) outs(%empty: tensor<256xf32>) -> tensor<256xf32>
%empty1 = tensor.empty() : tensor<256xf32>
%1 = linalg.negf {tiling_priority = 0 : i64} ins(%0 : tensor<256xf32>) outs(%empty1: tensor<256xf32>) -> tensor<256xf32>
%empty2 = tensor.empty() : tensor<256xf32>
%2 = linalg.powf {tile} ins(%1, %0: tensor<256xf32>, tensor<256xf32>) outs(%empty2: tensor<256xf32>) -> tensor<256xf32>

// CHECK: worklist: producer is %{{.*}} = linalg.negf
// CHECK: worklist: producer is %{{.*}} = linalg.ceil
// CHECK: worklist: producer is %{{.*}} = linalg.ceil

return %2 : tensor<256xf32>
}

// Transform script: match the linalg.powf op carrying the `tile` attribute
// and tile-and-fuse it with tile size 32 via test.tile_fuse_ordered.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.powf"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops = transform.test.tile_fuse_ordered %0 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/TilingInterface.h"
#include <deque>
#include <functional>
#include <string>

#define GET_OP_CLASSES
#include "TestTilingInterfaceTransformOps.h.inc"
Expand Down Expand Up @@ -54,12 +59,13 @@ static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
/// Apply a tile and fuse transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
static LogicalResult
applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
Range &&payloadOps, unsigned numLoops,
ArrayRef<OpFoldResult> tileSizes,
ArrayRef<int64_t> interchange, bool useForall,
TransformResults &transformResults) {
static LogicalResult applyTileAndFuseToAll(
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
unsigned numLoops, ArrayRef<OpFoldResult> tileSizes,
ArrayRef<int64_t> interchange, bool useForall,
TransformResults &transformResults,
std::optional<scf::SCFTileAndFuseOptions::WorklistInsertFnTy>
insertIntoWorklist) {
SmallVector<Operation *> tiledOps;
SmallVector<SmallVector<Operation *>> loopOps(numLoops);

Expand Down Expand Up @@ -87,6 +93,9 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
}

scf::SCFTileAndFuseOptions tileAndFuseOptions;
if (insertIntoWorklist.has_value()) {
tileAndFuseOptions.setWorklistInsertFn(*insertIntoWorklist);
}
tileAndFuseOptions.setTilingOptions(tilingOptions);

scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
Expand Down Expand Up @@ -157,7 +166,65 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
LogicalResult result = applyTileAndFuseToAll(
rewriter, getOperation(), state.getPayloadOps(getTarget()),
tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
tileInterchange, getUseForall(), transformResults);
tileInterchange, getUseForall(), transformResults, {});
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TestFuseOrderedOp
//===----------------------------------------------------------------------===//

/// Returns the value of the integer `tiling_priority` attribute on the op
/// that defines the source of the slice `op`. Returns std::nullopt when the
/// source has no defining op, or when that op carries no `tiling_priority`
/// attribute of integer type.
static std::optional<int64_t>
getProducerTilingPriority(tensor::ExtractSliceOp op) {
  Operation *producer = op.getSource().getDefiningOp();
  if (!producer)
    return std::nullopt;

  // A single typed lookup both tests for the attribute and fetches it; the
  // result is null when the attribute is absent or not an IntegerAttr.
  if (auto priority = producer->getAttrOfType<IntegerAttr>("tiling_priority"))
    return priority.getInt();
  return std::nullopt;
}

/// Worklist insertion policy that keeps prioritized slices ordered by the
/// `tiling_priority` of their producers (lowest value first).
///
/// A slice whose producer has no priority is appended at the back. Otherwise
/// the slice is placed in front of the first queued entry that either has no
/// priority or has a strictly greater one, so entries with equal priority
/// keep their insertion order.
static void
insertIntoWorklistOrdered(tensor::ExtractSliceOp op,
                          std::deque<tensor::ExtractSliceOp> &worklist) {
  const std::optional<int64_t> newPriority = getProducerTilingPriority(op);

  // Unprioritized slices simply go to the end of the queue.
  if (!newPriority.has_value()) {
    worklist.push_back(op);
    return;
  }

  // Skip every queued slice whose priority is less than or equal to the new
  // one; stop at the first entry that must come after the new slice.
  auto insertPos = worklist.begin();
  while (insertPos != worklist.end()) {
    const std::optional<int64_t> queuedPriority =
        getProducerTilingPriority(*insertPos);
    const bool staysAhead =
        queuedPriority.has_value() && queuedPriority.value() <= newPriority.value();
    if (!staysAhead)
      break;
    ++insertPos;
  }
  worklist.insert(insertPos, op);
}

/// Applies tile-and-fuse to every payload op of the target handle, using the
/// op's `tile_sizes` attribute and `insertIntoWorklistOrdered` as the worklist
/// insertion policy. No loop interchange is requested and `useForall` is
/// false. Returns a definite failure if tiling and fusion fail, success
/// otherwise.
DiagnosedSilenceableFailure
transform::TestFuseOrderedOp::apply(TransformRewriter &rewriter,
                                    TransformResults &transformResults,
                                    TransformState &state) {
  SmallVector<int64_t> tileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<OpFoldResult> tileSizesOfr =
      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);

  // The loop count passed on is the number of non-zero tile sizes.
  LogicalResult result = applyTileAndFuseToAll(
      rewriter, getOperation(), state.getPayloadOps(getTarget()),
      tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
      /*interchange=*/{}, /*useForall=*/false, transformResults,
      insertIntoWorklistOrdered);
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,29 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
}];
}

// Test-only transform op (`transform.test.tile_fuse_ordered`) that tiles and
// fuses its target while ordering producers by their `tiling_priority`
// attribute.
def TestFuseOrderedOp : Op<Transform_Dialect, "test.tile_fuse_ordered",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Applies tiling and fusion to the operations pointed to by the target handle,
following the order given by each operation's tiling_priority attribute.

On success returns the tiled operations as well as generated loops. Emits
a definite failure if tiling fails.
}];

// `tile_sizes` is an optional I64 array attribute that defaults to an empty
// array.
let arguments =
(ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
// NOTE(review): `transfomed` looks like a misspelling of `transformed`;
// renaming it would change the generated accessor name - confirm before
// fixing.
let results = (outs TransformHandleTypeInterface:$transfomed,
Variadic<TransformHandleTypeInterface>:$loops);

// Syntax: the target handle, the tile sizes (printed only when present), the
// attribute dictionary and the functional type of operands and results.
let assemblyFormat = [{
$target ($tile_sizes^)? attr-dict `:` functional-type(operands, results)
}];
}

def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
Expand Down