Skip to content

Commit c76a8cc

Browse files
committed
Make fusion work on any LinalgOp
1 parent b270525 commit c76a8cc

File tree

2 files changed

+24
-22
lines changed

2 files changed

+24
-22
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,8 +511,8 @@ fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
511511
/// * There is a chance that the implementation of the transformation does not
512512
/// agree with the result of this method. This function gives a prediction based
513513
/// on an optimized fusion.
514-
llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
515-
GenericOp consumer,
514+
llvm::SmallDenseSet<int> getPreservedProducerResults(LinalgOp producer,
515+
LinalgOp consumer,
516516
OpOperand *fusedOperand);
517517

518518
/// Try to peel and canonicalize loop `op` and return the new result.

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
7575
// of the fused producer & consumer after the fusion can still compute the
7676
// bounds of the op.
7777
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
78-
GenericOp producer, GenericOp consumer,
78+
LinalgOp producer, LinalgOp consumer,
7979
ArrayRef<OpOperand *> opOperandsToIgnore) {
8080
SmallVector<AffineMap> indexingMaps;
8181

82-
SmallVector<GenericOp> ops = {producer, consumer};
82+
SmallVector<LinalgOp> ops = {producer, consumer};
8383
for (auto &op : ops) {
8484
for (auto &opOperand : op->getOpOperands()) {
8585
if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
@@ -108,7 +108,7 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
108108
/// agree with the result of this method. This function gives a prediction based
109109
/// on an optimized fusion.
110110
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
111-
GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
111+
LinalgOp producer, LinalgOp consumer, OpOperand *fusedOperand) {
112112
llvm::SmallDenseSet<int> preservedProducerResults;
113113
llvm::SmallVector<OpOperand *> opOperandsToIgnore;
114114

@@ -138,8 +138,8 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
138138
if (!fusedOperand)
139139
return false;
140140

141-
auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
142-
auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
141+
auto producer = fusedOperand->get().getDefiningOp<LinalgOp>();
142+
auto consumer = dyn_cast<LinalgOp>(fusedOperand->getOwner());
143143

144144
// Check producer and consumer are generic ops.
145145
if (!producer || !consumer)
@@ -213,16 +213,16 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
213213
/// Generate the region of the fused tensor operation. The region of the fused
214214
/// op must be empty.
215215
static void generateFusedElementwiseOpRegion(
216-
RewriterBase &rewriter, GenericOp fusedOp,
216+
RewriterBase &rewriter, LinalgOp fusedOp,
217217
AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
218218
unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
219-
auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
220-
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
219+
auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
220+
auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
221221
// Build the region of the fused op.
222222
Block &producerBlock = producer->getRegion(0).front();
223223
Block &consumerBlock = consumer->getRegion(0).front();
224224
OpBuilder::InsertionGuard guard(rewriter);
225-
Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
225+
Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
226226
IRMapping mapper;
227227

228228
// 2. Add an index operation for every fused loop dimension and use the
@@ -329,7 +329,7 @@ static void generateFusedElementwiseOpRegion(
329329
rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
330330

331331
// Sanity checks.
332-
assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
332+
assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
333333
"Ill-formed GenericOp region");
334334
}
335335

@@ -339,8 +339,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
339339
assert(areElementwiseOpsFusable(fusedOperand) &&
340340
"expected elementwise operation pre-conditions to pass");
341341
auto producerResult = cast<OpResult>(fusedOperand->get());
342-
auto producer = cast<GenericOp>(producerResult.getOwner());
343-
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
342+
auto producer = cast<LinalgOp>(producerResult.getOwner());
343+
auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
344344
// TODO: allow fusing the producer of an output operand.
345345
assert(consumer.isDpsInput(fusedOperand) &&
346346
"expected producer of input operand");
@@ -415,12 +415,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
415415
}
416416

417417
// Generate the fused op.
418+
// auto fusedOp = cloneWithoutRegions(rewriter, consumer,
419+
// fusedResultTypes, fusedInputOperands);
420+
// fusedOp.setIndexingMapsAttr(idxMap);
421+
// fusedOp.setIteratorTypesAttr(itTp);
418422
auto fusedOp = rewriter.create<GenericOp>(
419423
consumer.getLoc(), fusedResultTypes, fusedInputOperands,
420-
fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
421-
consumer.getIteratorTypes(),
422-
/*doc=*/nullptr,
423-
/*library_call=*/nullptr);
424+
fusedOutputOperands, fusedIndexMaps,
425+
consumer.getIteratorTypesArray());
424426
if (!fusedOp.getShapesToLoopsMap()) {
425427
// Fused op has invalid indexing maps. Typically this means something is off
426428
// in the input, but going ahead here would result in verification errors.
@@ -459,14 +461,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
459461

460462
namespace {
461463
/// Patterns to fuse a generic op, with the producer of its operands.
462-
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
464+
class FuseElementwiseOps : public OpInterfaceRewritePattern<LinalgOp> {
463465
public:
464466
FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
465467
PatternBenefit benefit = 1)
466-
: OpRewritePattern<GenericOp>(context, benefit),
468+
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
467469
controlFn(std::move(fun)) {}
468470

469-
LogicalResult matchAndRewrite(GenericOp genericOp,
471+
LogicalResult matchAndRewrite(LinalgOp genericOp,
470472
PatternRewriter &rewriter) const override {
471473
// Find the first operand that is defined by another generic op on tensors.
472474
for (OpOperand &opOperand : genericOp->getOpOperands()) {
@@ -493,7 +495,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
493495
rewriter.eraseOp(genericOp);
494496
return success();
495497
}
496-
return failure();
498+
return rewriter.notifyMatchFailure(genericOp, "no fusable operands");
497499
}
498500

499501
private:

0 commit comments

Comments
 (0)