Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"

#include <utility>
Expand Down Expand Up @@ -194,6 +195,21 @@ struct SCFTileAndFuseOptions {
/// before fusion. This will track deleted and newly inserted
/// `tensor.extract_slice` ops and update the worklist.
std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;

/// Signature of the hook that places a tilable `tensor.extract_slice` into
/// the list of slices still to be tiled. The hook receives the slice (`op`)
/// and a mutable reference to the pending `worklist`, and is free to insert
/// `op` at any position of the deque. Since the position is chosen here,
/// this controls the order in which tiling and fusion happen.
using WorklistInsertFnTy = std::function<void(
tensor::ExtractSliceOp op, std::deque<tensor::ExtractSliceOp> &worklist)>;
/// The insertion hook in effect for these options. By default, simply
/// append the op at the end of the queue.
WorklistInsertFnTy worklistInsertFn =
[](tensor::ExtractSliceOp op,
std::deque<tensor::ExtractSliceOp> &worklist) {
worklist.push_back(op);
};
/// Replaces the worklist insertion hook and returns `*this` so that option
/// setters can be chained. `insertFn` is a by-value sink parameter, so it is
/// moved into the member rather than copied a second time.
SCFTileAndFuseOptions &setWorklistInsertFn(WorklistInsertFnTy insertFn) {
  worklistInsertFn = std::move(insertFn);
  return *this;
}
};

/// Fuse the producer of the source of `candidateSliceOp` by computing the
Expand Down
19 changes: 14 additions & 5 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,8 @@ namespace {
class SliceTrackingListener : public RewriterBase::Listener {
public:
explicit SliceTrackingListener(
std::optional<FrozenRewritePatternSet> patterns);
std::optional<FrozenRewritePatternSet> patterns,
scf::SCFTileAndFuseOptions::WorklistInsertFnTy worklistInsertFn);
SliceTrackingListener() = default;

/// Adds the given list of operations to the worklist, and if present,
Expand Down Expand Up @@ -1421,18 +1422,22 @@ class SliceTrackingListener : public RewriterBase::Listener {
/// Optional pattern set to apply when adding new operations to the
/// worklist.
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
scf::SCFTileAndFuseOptions::WorklistInsertFnTy worklistInsertFn;
};

/// Creates a listener that uses `p` as the optional cleanup pattern set and
/// `w` as the hook that decides where each tracked `tensor.extract_slice`
/// goes in the worklist. Both arguments are by-value sinks and are moved
/// straight into the members, in declaration order.
SliceTrackingListener::SliceTrackingListener(
    std::optional<FrozenRewritePatternSet> p,
    scf::SCFTileAndFuseOptions::WorklistInsertFnTy w)
    : patterns(std::move(p)), worklistInsertFn(std::move(w)) {}

/// Insert extract_slice ops into the worklist.
LogicalResult
SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
for (Operation *op : ops) {
if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
worklist.push_back(slice);
worklistInsertFn(slice, worklist);
}

if (!patterns)
Expand All @@ -1444,12 +1449,14 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
return applyOpPatternsGreedily(ops, patterns.value(), config);
}

/// Insert extract_slice ops created by cleanup patterns into the worklist.
/// Triggered from applyOpPatternsGreedily() above.
void SliceTrackingListener::notifyOperationInserted(
    Operation *op, OpBuilder::InsertPoint previous) {
  // Only tensor.extract_slice ops are tracked; every other op is ignored.
  auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
  if (!slice)
    return;
  // Insert the slice exactly once. The configured hook picks the position;
  // a default-constructed listener carries no hook, so fall back to
  // appending instead of invoking an empty std::function.
  if (worklistInsertFn)
    worklistInsertFn(slice, worklist);
  else
    worklist.push_back(slice);
}

// Scan the worklist for the given op and remove it if present. The
Expand Down Expand Up @@ -1580,7 +1587,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
};

SliceTrackingListener sliceTracker =
SliceTrackingListener(options.cleanupPatterns);
SliceTrackingListener(options.cleanupPatterns, options.worklistInsertFn);

if (failed(
sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
Expand All @@ -1596,6 +1603,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
loops);
if (!fusableProducer)
continue;
LLVM_DEBUG(llvm::dbgs() << "worklist: producer is "
<< *(fusableProducer.getOwner()) << "\n");

std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
options.fusionControlFn(candidateSlice, fusableProducer,
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/Linalg/tile-sort.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file -debug-only=tile-using-interface 2>&1 | FileCheck %s
// The debug stream selected by -debug-only exists only in builds with
// assertions enabled, so restrict the test to those builds.
// REQUIRES: asserts

// Verifies the order in which producers are visited during tile-and-fuse.
// `linalg.powf` (marked with the `tile` attribute) reads %0 (linalg.ceil) and
// %1 (linalg.negf), and linalg.negf itself reads %0, so linalg.ceil is
// expected to be reported twice in the debug output.
func.func @tile_in_op_operand_order(%arg: tensor<256xf32>) -> tensor<256xf32> {
  %empty = tensor.empty() : tensor<256xf32>
  %0 = linalg.ceil ins(%arg: tensor<256xf32>) outs(%empty: tensor<256xf32>) -> tensor<256xf32>
  %empty1 = tensor.empty() : tensor<256xf32>
  %1 = linalg.negf ins(%0 : tensor<256xf32>) outs(%empty1: tensor<256xf32>) -> tensor<256xf32>
  %empty2 = tensor.empty() : tensor<256xf32>
  %2 = linalg.powf {tile} ins(%0, %1: tensor<256xf32>, tensor<256xf32>) outs(%empty2: tensor<256xf32>) -> tensor<256xf32>

  // CHECK: worklist: producer is %{{.*}} = linalg.ceil
  // CHECK: worklist: producer is %{{.*}} = linalg.negf
  // CHECK: worklist: producer is %{{.*}} = linalg.ceil

  return %2 : tensor<256xf32>
}

// Transform script driving the test: it selects the payload op to tile and
// runs the test fuse-and-yield transform on it.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
// Match the `linalg.powf` op that carries the `tile` attribute.
%0 = transform.structured.match ops{["linalg.powf"]} attributes {"tile"} in %arg1 : (!transform.any_op) -> !transform.any_op
// Apply fuse_and_yield to the matched op with the size list [32]; the
// results (%1, %loops) are not used further.
%1, %loops = transform.test.fuse_and_yield %0 [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}