Commits (32 total; changes shown from 29 commits)
c76a8cc  Make fusion work on any LinalgOp (srcarroll, Jun 19, 2025)
fa60fa5  Merge branch 'main' into generalize-fusion (srcarroll, Jun 19, 2025)
20b25f3  format and add test (srcarroll, Jun 19, 2025)
8e471a7  fix typo in test (srcarroll, Jun 19, 2025)
d723913  add same test for other fusion pass -linalg-fuse-elementwise-ops (srcarroll, Jun 19, 2025)
5280b87  fix bug with no output bb args and add test (srcarroll, Jun 19, 2025)
c2f52bc  add linalg.elementwise test (srcarroll, Jun 19, 2025)
8d2e8e0  fix formatting (srcarroll, Jun 19, 2025)
58582bf  use getElementTypeOrSelf for cleanup (srcarroll, Jun 19, 2025)
cf67ab6  switch elementwise test to broadcast version (srcarroll, Jun 19, 2025)
7d402c1  remove block args that were added (hacky) (srcarroll, Jun 20, 2025)
b1d15b2  add requested tests (srcarroll, Jun 21, 2025)
459bcb4  add checks for nontrivial map cases (srcarroll, Jun 21, 2025)
f7e164b  revert unintended change (srcarroll, Jun 21, 2025)
9acb5d0  Merge branch 'main' into generalize-fusion (srcarroll, Jul 31, 2025)
8a375bf  fix weird lit failure (srcarroll, Aug 1, 2025)
5c3bd5f  Merge branch 'main' into generalize-fusion (srcarroll, Aug 1, 2025)
a12b417  Merge branch 'main' into generalize-fusion (srcarroll, Oct 30, 2025)
ce33596  remove hack for linalg.map and update tests (srcarroll, Oct 30, 2025)
1eb3461  fix tests again (srcarroll, Oct 30, 2025)
389a9c4  add a couple tests with `linalg.matmul` (srcarroll, Nov 2, 2025)
55cca22  Merge branch 'main' into generalize-fusion (srcarroll, Nov 2, 2025)
b914037  Merge branch 'main' into generalize-fusion (srcarroll, Nov 3, 2025)
8c906a0  "fix" mysterious check-mlir error (srcarroll, Nov 3, 2025)
047266b  use proper getBlock interface methods (srcarroll, Nov 5, 2025)
880d9c0  Merge branch 'main' into generalize-fusion (srcarroll, Nov 5, 2025)
5588c86  Merge branch 'main' into generalize-fusion (srcarroll, Nov 5, 2025)
95635df  Merge branch 'main' into generalize-fusion (srcarroll, Nov 5, 2025)
b7247b9  update docstrings (srcarroll, Nov 6, 2025)
b639a2f  Merge branch 'main' into generalize-fusion (srcarroll, Nov 6, 2025)
ac0eb5b  Merge branch 'main' into generalize-fusion (srcarroll, Nov 6, 2025)
a3bba1a  define separate `fuseElementwiseLinalgOps` and `fuseElementwiseGeneri… (srcarroll, Nov 7, 2025)
14 changes: 9 additions & 5 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -554,9 +554,11 @@ FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
GenericOp genericOp,
const ControlDropUnitDims &options);

/// Fuse two `linalg.generic` operations that have a producer-consumer
/// Fuse two linalg operations that have a producer-consumer
/// relationship captured through `fusedOperand`. The method expects
/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
/// The resulting fused operation is always a `linalg.generic`.
/// TODO: Support fusing to named ops when possible.
struct ElementwiseOpFusionResult {
Operation *fusedOp;
llvm::DenseMap<Value, Value> replacements;
@@ -569,8 +571,8 @@ fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
GenericOp consumer,
llvm::SmallDenseSet<int> getPreservedProducerResults(LinalgOp producer,
LinalgOp consumer,
OpOperand *fusedOperand);

/// Try to peel and canonicalize loop `op` and return the new result.
@@ -1921,8 +1923,10 @@ using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;

/// Patterns for fusing linalg operation on tensors.

/// Pattern to fuse `linalg.generic` -> `linalg.generic` operations
/// when both operations are fusable elementwise operations.
/// Pattern to fuse two linalg operations when they are fusable.
/// The producer must always be an elementwise operation, and
/// operations are opted into fusion via `controlElementwiseOpFusion`.
void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpFusion);
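The header changes above widen the fusion entry points from GenericOp to LinalgOp while keeping the same pattern-population API. As a point of reference, here is a minimal C++ sketch of how a downstream pass might wire these patterns up with a control callback; the helper name runFusionOnFunc and the single-use policy are illustrative assumptions, not part of this PR, and older MLIR releases spell the greedy driver entry point applyPatternsAndFoldGreedily.

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative helper: run elementwise fusion over a function-like op.
static LogicalResult runFusionOnFunc(Operation *func) {
  RewritePatternSet patterns(func->getContext());

  // With this PR the producer/consumer no longer have to be linalg.generic;
  // any LinalgOp pair accepted by areElementwiseOpsFusable may be fused.
  linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
    Operation *producer = fusedOperand->get().getDefiningOp();
    // Example policy (an assumption): only fuse single-use producers.
    return producer && producer->hasOneUse();
  };
  linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);

  // Named applyPatternsAndFoldGreedily in older MLIR releases.
  return applyPatternsGreedily(func, std::move(patterns));
}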
61 changes: 30 additions & 31 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -77,11 +77,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
GenericOp producer, GenericOp consumer,
LinalgOp producer, LinalgOp consumer,
ArrayRef<OpOperand *> opOperandsToIgnore) {
SmallVector<AffineMap> indexingMaps;

SmallVector<GenericOp> ops = {producer, consumer};
SmallVector<LinalgOp> ops = {producer, consumer};
for (auto &op : ops) {
for (auto &opOperand : op->getOpOperands()) {
if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
@@ -109,8 +109,9 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
llvm::SmallDenseSet<int>
mlir::linalg::getPreservedProducerResults(LinalgOp producer, LinalgOp consumer,
OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
llvm::SmallVector<OpOperand *> opOperandsToIgnore;

@@ -135,15 +136,15 @@
return preservedProducerResults;
}

/// Conditions for elementwise fusion of generic operations.
/// Conditions for elementwise fusion of linalg operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (!fusedOperand)
return false;

auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
auto producer = fusedOperand->get().getDefiningOp<LinalgOp>();
auto consumer = dyn_cast<LinalgOp>(fusedOperand->getOwner());

// Check producer and consumer are generic ops.
// Check producer and consumer are linalg ops.
if (!producer || !consumer)
return false;

@@ -179,7 +180,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
return false;

// Ensure that the fusion does not remove size information required to
// get the loop bounds. For non-reduction generics, this is trivially the
// get the loop bounds. For non-reduction ops, this is trivially the
// case due to the output operand. For reductions, we need to check that after
// the fusion, each loop dimension has at least one input that defines it.
if ((consumer.getNumReductionLoops())) {
@@ -219,13 +220,14 @@ static void generateFusedElementwiseOpRegion(
RewriterBase &rewriter, GenericOp fusedOp,
AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// Build the region of the fused op.
Block &producerBlock = producer->getRegion(0).front();
Block &consumerBlock = consumer->getRegion(0).front();

Block &producerBlock = *producer.getBlock();
Block &consumerBlock = *consumer.getBlock();
OpBuilder::InsertionGuard guard(rewriter);
Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
Review thread on the line above:

Groverkss (Member), Nov 5, 2025:
This is wrong. Please use interface methods to do it correctly:
    /*methodName=*/"getRegionBuilder",
or

srcarroll (Contributor, Author):
getRegionBuilder is wrong since that just gets a function ref, but I will use getBlock. Thanks for pointing it out.

srcarroll (Contributor, Author):
I'm actually unsure how to apply this suggestion. The other one above makes sense, since I'm just getting the blocks that have already been created. However, this one is creating a block in the op's region, and I don't see an interface method for getting a region. If region 0 isn't guaranteed, shouldn't the interface have a method for that?

Anyway, fusedOp doesn't have to be a LinalgOp for these initial changes, since the function that calls this one explicitly creates a GenericOp (see https://github.com/llvm/llvm-project/pull/144922/files#diff-a7543973103a3f3abb605911ca6d141dc4ffd4782b2bc0ad57890d11ab72e2c1R422). So it's probably better to just declare fusedOp as GenericOp for this function and revert this line. Any thoughts?

@rengolin and I had a discussion back when this PR first went up about generating named ops post fusion when possible. I think it was agreed that it makes sense to leave this as a TODO (see #144922 (comment)). So when we get there we can revisit how to do this, unless there's an obvious solution now.

Groverkss (Member):
> Anyway, fusedOp doesn't have to be a LinalgOp for these initial changes [...] Any thoughts?

Probably better to declare it as a GenericOp, yes, since the transformation always (today) returns a GenericOp anyway.

> @rengolin and I had a discussion back when this PR first went up [...] unless there's an obvious solution now.

I think that is a different problem. My main concern is expecting something from an interface when the interface doesn't guarantee it.

srcarroll (Contributor, Author):
It is a separate problem. I just meant that when we make the changes to fuse to named ops (when possible), LinalgOp might be re-introduced here and we would run into this issue again. But it is possible that more refactoring would need to happen anyway, so we wouldn't necessarily run back into this issue again here specifically.

IRMapping mapper;

// 2. Add an index operation for every fused loop dimension and use the
@@ -331,7 +333,7 @@ static void generateFusedElementwiseOpRegion(
YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);

// Sanity checks.
assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
"Ill-formed GenericOp region");
}

@@ -341,8 +343,8 @@
assert(areElementwiseOpsFusable(fusedOperand) &&
"expected elementwise operation pre-conditions to pass");
auto producerResult = cast<OpResult>(fusedOperand->get());
auto producer = cast<GenericOp>(producerResult.getOwner());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
auto producer = cast<LinalgOp>(producerResult.getOwner());
auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
@@ -419,10 +421,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
// Generate the fused op.
auto fusedOp = GenericOp::create(
rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumer.getIteratorTypes(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
fusedOutputOperands, fusedIndexMaps, consumer.getIteratorTypesArray());
if (!fusedOp.getShapesToLoopsMap()) {
// Fused op has invalid indexing maps. Typically this means something is off
// in the input, but going ahead here would result in verification errors.
@@ -460,18 +459,18 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
}

namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
/// Patterns to fuse a linalg op, with the producer of its operands.
class FuseElementwiseOps : public OpInterfaceRewritePattern<LinalgOp> {
public:
FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit),
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlFn(std::move(fun)) {}

LogicalResult matchAndRewrite(GenericOp genericOp,
LogicalResult matchAndRewrite(LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand &opOperand : genericOp->getOpOperands()) {
// Find the first operand that is defined by another linalg op on tensors.
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
if (!areElementwiseOpsFusable(&opOperand))
continue;
if (!controlFn(&opOperand))
@@ -483,7 +482,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
FailureOr<ElementwiseOpFusionResult> fusionResult =
fuseElementwiseOps(rewriter, &opOperand);
if (failed(fusionResult))
return rewriter.notifyMatchFailure(genericOp, "fusion failed");
return rewriter.notifyMatchFailure(linalgOp, "fusion failed");

// Perform the fusion.
for (auto [origVal, replacement] : fusionResult->replacements) {
@@ -492,10 +491,10 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
return use.get().getDefiningOp() != producer;
});
}
rewriter.eraseOp(genericOp);
rewriter.eraseOp(linalgOp);
return success();
}
return failure();
return rewriter.notifyMatchFailure(linalgOp, "no fusable operands");
}

private:
@@ -2279,7 +2278,7 @@ void mlir::linalg::populateCollapseDimensions(

namespace {

/// Pass that fuses generic ops on tensors. Used only for testing.
/// Pass that fuses linalg ops on tensors. Used only for testing.
// TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
// patterns added here heavily depends on the cost function used. Having an
// opinionated pass of this form is not recommended. Deprecate this pass in
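The implementation changes above rework FuseElementwiseOps into an OpInterfaceRewritePattern over LinalgOp, but the public entry point fuseElementwiseOps can also be driven by hand. Below is a minimal, hedged sketch of doing that for a single operand; tryFuseProducerInto and its cleanup policy (erasing the consumer and a now-unused producer directly rather than deferring to DCE) are assumptions for illustration, not code introduced by this PR.

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Illustrative helper: fuse the producer of `fusedOperand` into its consumer.
static LogicalResult tryFuseProducerInto(RewriterBase &rewriter,
                                         OpOperand *fusedOperand) {
  // The precondition check now accepts any LinalgOp producer/consumer pair.
  if (!linalg::areElementwiseOpsFusable(fusedOperand))
    return failure();

  Operation *producer = fusedOperand->get().getDefiningOp();
  Operation *consumer = fusedOperand->getOwner();

  FailureOr<linalg::ElementwiseOpFusionResult> result =
      linalg::fuseElementwiseOps(rewriter, fusedOperand);
  if (failed(result))
    return failure();

  // Redirect uses of the original results to the fused linalg.generic,
  // a simplified version of what the FuseElementwiseOps pattern does.
  for (auto [origVal, replacement] : result->replacements)
    rewriter.replaceAllUsesWith(origVal, replacement);

  rewriter.eraseOp(consumer);
  if (producer->use_empty())
    rewriter.eraseOp(producer);
  return success();
}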
70 changes: 68 additions & 2 deletions mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1017,9 +1017,75 @@ module {

// -----

func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
%fill = tensor.empty() : tensor<8xf32>
%add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
%mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
return %mapped_65 : tensor<8xf32>
}

// CHECK-LABEL: func @map_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
// CHECK-NEXT: linalg.yield %[[SQRT]]
// CHECK-NOT: linalg.generic

// -----

func.func @map_matmul(%in1: tensor<8x10xf32>, %in2: tensor<10x12xf32>) -> tensor<8x12xf32> {
Review thread on this test:

srcarroll (Contributor, Author):
@rengolin Per your request I have added this and the test below. Please let me know if there's more you would like to see.

rengolin (Member):
Yes! That's the kind of thing we need to be testing.

If the matmul has a transpose/broadcast/reduction map on %exp then it shouldn't be fused.

This also applies to contract and elementwise.

srcarroll (Contributor, Author):
> If the matmul has a transpose/broadcast/reduction map on %exp then it shouldn't be fused.

Ah yes, I forgot matmul was extended to allow different indexing maps. Will add more cases involving this. Will also check with contract.

srcarroll (Contributor, Author):
> If the matmul has a transpose/broadcast/reduction map on %exp then it shouldn't be fused.

@rengolin Actually, why not? I think this is only conditionally true. One case I'm thinking of that involves transpose is valid for fusion, for example:

func.func @map_matmul_transpose_a(%in1: tensor<10x8xf32>, %in2: tensor<10x12xf32>) -> tensor<8x12xf32> {
    %fill0 = tensor.empty() : tensor<10x8xf32>
    %exp = linalg.map {math.exp} ins(%in1 : tensor<10x8xf32>) outs(%fill0: tensor<10x8xf32>)
    %fill1 = tensor.empty() : tensor<8x12xf32>
    %matmul = linalg.matmul indexing_maps = [
                       affine_map<(d0, d1, d2) -> (d2, d0)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>
                     ] ins(%exp, %in2 : tensor<10x8xf32>, tensor<10x12xf32>) outs(%fill1 : tensor<8x12xf32>) -> tensor<8x12xf32>
    return %matmul : tensor<8x12xf32>
}

would fuse to

#map = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
  func.func @map_matmul_transpose_a(%arg0: tensor<10x8xf32>, %arg1: tensor<10x12xf32>) -> tensor<8x12xf32> {
    %0 = tensor.empty() : tensor<8x12xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<10x8xf32>, tensor<10x12xf32>) outs(%0 : tensor<8x12xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %2 = math.exp %in : f32
      %3 = arith.mulf %2, %in_0 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
    } -> tensor<8x12xf32>
    return %1 : tensor<8x12xf32>
  }
}

broadcast cases can also be valid, for example

func.func @map_matmul_bcast(%in1: tensor<10xf32>, %in2: tensor<10x12xf32>) -> tensor<8x12xf32> {
  %fill0 = tensor.empty() : tensor<10xf32>
  %exp = linalg.map {math.exp} ins(%in1 : tensor<10xf32>) outs(%fill0: tensor<10xf32>)
  %fill1 = tensor.empty() : tensor<8x12xf32>
  %matmul = linalg.matmul indexing_maps = [
                     affine_map<(d0, d1, d2) -> (d2)>,
                     affine_map<(d0, d1, d2) -> (d2, d1)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>
                   ] ins(%exp, %in2 : tensor<10xf32>, tensor<10x12xf32>) outs(%fill1 : tensor<8x12xf32>) -> tensor<8x12xf32>
  return %matmul : tensor<8x12xf32>
}

fuses to

#map = affine_map<(d0, d1, d2) -> (d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
  func.func @map_matmul_bcast(%arg0: tensor<10xf32>, %arg1: tensor<10x12xf32>) -> tensor<8x12xf32> {
    %0 = tensor.empty() : tensor<8x12xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<10xf32>, tensor<10x12xf32>) outs(%0 : tensor<8x12xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %2 = math.exp %in : f32
      %3 = arith.mulf %2, %in_0 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
    } -> tensor<8x12xf32>
    return %1 : tensor<8x12xf32>
  }
}

I'll still need to think about cases, with valid input IR, that should NOT result in fusion for the elementwise + matmul case, but I think I will need a more complex case than this to show that. I just wanted to make sure it is agreed that the above cases are valid fusion cases. And again, if you have a specific test case in mind that I'm not thinking of, I will certainly investigate/add it.

srcarroll (Contributor, Author):
The above were produced with the builtin linalg-fuse-elementwise-ops pass.

rengolin (Member):
You're right, I should have said "may not be fused" instead.

srcarroll (Contributor, Author), Nov 6, 2025:
Cool, thanks.

> And again, if you have a specific test case in mind that I'm not thinking of, I will certainly investigate/add it.

By the way, it is not my intention to burden you with coming up with test cases for these changes; I completely accept that responsibility myself. So I just mean that if you already have something in mind, please share. This is more about hoping for help than expecting it. I will continue adding more cases that I think are illustrative enough of the scope of the changes here.

%fill0 = tensor.empty() : tensor<8x10xf32>
%exp = linalg.map {math.exp} ins(%in1 : tensor<8x10xf32>) outs(%fill0: tensor<8x10xf32>)
%fill1 = tensor.empty() : tensor<8x12xf32>
%matmul = linalg.matmul ins(%exp, %in2 : tensor<8x10xf32>, tensor<10x12xf32>) outs(%fill1 : tensor<8x12xf32>) -> tensor<8x12xf32>
return %matmul : tensor<8x12xf32>
}

// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @map_matmul
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8x10xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<10x12xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x12xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[EXP:.*]] = math.exp %[[IN0]]
// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[EXP]], %[[IN1]]
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]]
// CHECK-NEXT: linalg.yield %[[ADD]]
// CHECK-NOT: linalg.generic

// -----

func.func @matmul_map(%in1: tensor<8x10xf32>, %in2: tensor<10x12xf32>) -> tensor<8x12xf32> {
%fill1 = tensor.empty() : tensor<8x12xf32>
%matmul = linalg.matmul ins(%in1, %in2 : tensor<8x10xf32>, tensor<10x12xf32>) outs(%fill1 : tensor<8x12xf32>) -> tensor<8x12xf32>
%exp = linalg.map {math.exp} ins(%matmul : tensor<8x12xf32>) outs(%fill1: tensor<8x12xf32>)

return %exp : tensor<8x12xf32>
}

// Should not fuse
// CHECK-LABEL: func @matmul_map
// CHECK-NEXT: tensor.empty
// CHECK-NEXT: linalg.matmul
// CHECK-NEXT: linalg.map
// CHECK-NEXT: return

// -----

// In this test we expect the first two linalg.generic operations to be fused into one, but the third one (the matmul) to remain separate.
// The reason is that when the pattern is applied the 1st time, the fusion of the first two operations produces a fused operation with
// an additional result and ana dditional output indexing map that is not a permutation / not invertible.
// an additional result and an additional output indexing map that is not a permutation / not invertible.
// The fused op will still produce also the original result (and its output indexing map), which is preserved because the new indexing map
// is not invertible. Thus the fused op will have 2 results, but only the 2nd one will be used by the following matmul op as an input argument.
// When trying to apply the fusion pattern again, the matmul op won't be fused because the operand to fuse was not produced with an invertible indexing map.
@@ -1079,4 +1145,4 @@ module {
// CHECK-NOT: linalg.generic
// CHECK: tensor.expand_shape
// CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)