Skip to content

Commit f0c82ea

Browse files
committed
Enable callback when tiling and fuse
1 parent 878061b commit f0c82ea

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ struct SCFTileAndFuseResult {
110110
SmallVector<Operation *> tiledAndFusedOps;
111111
SmallVector<scf::ForOp> loops;
112112
};
113+
114+
using checkProducerFn =
115+
std::function<LogicalResult(ArrayRef<Range> rootIterationDomain,
116+
Operation *producer, OpBuilder &builder)>;
117+
113118
struct TileConsumerAndFuseProducersUsingSCFForOp
114119
: public OpInterfaceRewritePattern<TilingInterface> {
115120

@@ -127,7 +132,8 @@ struct TileConsumerAndFuseProducersUsingSCFForOp
127132
/// `matchAndRewrite` implementation that returns the significant transformed
128133
/// pieces of IR.
129134
FailureOr<SCFTileAndFuseResult>
130-
returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
135+
returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter,
136+
checkProducerFn = nullptr) const;
131137

132138
LogicalResult matchAndRewrite(TilingInterface op,
133139
PatternRewriter &rewriter) const override {

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,13 +397,14 @@ static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor,
397397

398398
FailureOr<scf::SCFTileAndFuseResult>
399399
scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
400-
TilingInterface op, PatternRewriter &rewriter) const {
400+
TilingInterface op, PatternRewriter &rewriter, checkProducerFn fn) const {
401401
// This transformation is only valid for ops that return values (i.e. not
402402
// valid to use with operations that have memref operands).
403403
if (!op->getNumResults()) {
404404
return rewriter.notifyMatchFailure(
405405
op, "invalid pattern for op with no results");
406406
}
407+
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
407408

408409
// 1. First tile the consumer.
409410
SCFTileAndFuseResult tileAndFuseResult;
@@ -446,6 +447,10 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
446447
if (!fusableProducer)
447448
continue;
448449

450+
if (fn &&
451+
failed(fn(iterationDomain, fusableProducer->getDefiningOp(), rewriter)))
452+
continue;
453+
449454
// 2c. Generate the tiled implementation of the producer of the source
450455
rewriter.setInsertionPoint(candidateSliceOp);
451456
FailureOr<Value> fusedProducerValue =

0 commit comments

Comments
 (0)