Changes from 6 commits
28 commits
c76a8cc  Make fusion work on any LinalgOp (srcarroll, Jun 19, 2025)
fa60fa5  Merge branch 'main' into generalize-fusion (srcarroll, Jun 19, 2025)
20b25f3  format and add test (srcarroll, Jun 19, 2025)
8e471a7  fix typo in test (srcarroll, Jun 19, 2025)
d723913  add same test for other fusion pass -linalg-fuse-elementwise-ops (srcarroll, Jun 19, 2025)
5280b87  fix bug with no output bb args and add test (srcarroll, Jun 19, 2025)
c2f52bc  add linalg.elementwise test (srcarroll, Jun 19, 2025)
8d2e8e0  fix formatting (srcarroll, Jun 19, 2025)
58582bf  use getElementTypeOrSelf for cleanup (srcarroll, Jun 19, 2025)
cf67ab6  switch elementwise test to broadcast version (srcarroll, Jun 19, 2025)
7d402c1  remove block args that were added (hacky) (srcarroll, Jun 20, 2025)
b1d15b2  add requested tests (srcarroll, Jun 21, 2025)
459bcb4  add checks for nontrivial map cases (srcarroll, Jun 21, 2025)
f7e164b  revert unintended change (srcarroll, Jun 21, 2025)
9acb5d0  Merge branch 'main' into generalize-fusion (srcarroll, Jul 31, 2025)
8a375bf  fix weird lit failure (srcarroll, Aug 1, 2025)
5c3bd5f  Merge branch 'main' into generalize-fusion (srcarroll, Aug 1, 2025)
a12b417  Merge branch 'main' into generalize-fusion (srcarroll, Oct 30, 2025)
ce33596  remove hack for linalg.map and update tests (srcarroll, Oct 30, 2025)
1eb3461  fix tests again (srcarroll, Oct 30, 2025)
389a9c4  add a couple tests with `linalg.matmul` (srcarroll, Nov 2, 2025)
55cca22  Merge branch 'main' into generalize-fusion (srcarroll, Nov 2, 2025)
b914037  Merge branch 'main' into generalize-fusion (srcarroll, Nov 3, 2025)
8c906a0  "fix" mysterious check-mlir error (srcarroll, Nov 3, 2025)
047266b  use proper getBlock interface methods (srcarroll, Nov 5, 2025)
880d9c0  Merge branch 'main' into generalize-fusion (srcarroll, Nov 5, 2025)
5588c86  Merge branch 'main' into generalize-fusion (srcarroll, Nov 5, 2025)
95635df  Merge branch 'main' into generalize-fusion (srcarroll, Nov 5, 2025)
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -529,8 +529,8 @@ fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
GenericOp consumer,
llvm::SmallDenseSet<int> getPreservedProducerResults(LinalgOp producer,
LinalgOp consumer,
OpOperand *fusedOperand);

/// Try to peel and canonicalize loop `op` and return the new result.
63 changes: 42 additions & 21 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -77,11 +77,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
GenericOp producer, GenericOp consumer,
LinalgOp producer, LinalgOp consumer,
ArrayRef<OpOperand *> opOperandsToIgnore) {
SmallVector<AffineMap> indexingMaps;

SmallVector<GenericOp> ops = {producer, consumer};
SmallVector<LinalgOp> ops = {producer, consumer};
for (auto &op : ops) {
for (auto &opOperand : op->getOpOperands()) {
if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
@@ -109,8 +109,9 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
llvm::SmallDenseSet<int>
mlir::linalg::getPreservedProducerResults(LinalgOp producer, LinalgOp consumer,
OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
llvm::SmallVector<OpOperand *> opOperandsToIgnore;

@@ -140,8 +141,8 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (!fusedOperand)
return false;

auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
auto producer = fusedOperand->get().getDefiningOp<LinalgOp>();
auto consumer = dyn_cast<LinalgOp>(fusedOperand->getOwner());

// Check producer and consumer are generic ops.
if (!producer || !consumer)
@@ -215,16 +216,39 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
RewriterBase &rewriter, GenericOp fusedOp,
RewriterBase &rewriter, LinalgOp fusedOp,
AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// Build the region of the fused op.

// Some ops, like `linalg.map`, do not have block arguments for init operands,
// so we first "generalize" the block by adding arguments for init operands when
// they aren't present. We detect this case by checking whether
// `getOpOperandsMatchingBBargs() == getDpsInputOperands()`.
Block &producerBlock = producer->getRegion(0).front();
if (producer.getOpOperandsMatchingBBargs() ==
producer.getDpsInputOperands()) {
for (auto init : producer.getDpsInits()) {
Type bbType = isa<ShapedType>(init.getType())
? cast<ShapedType>(init.getType()).getElementType()
: init.getType();
producerBlock.addArgument(bbType, producer.getLoc());
}
}
Block &consumerBlock = consumer->getRegion(0).front();
if (consumer.getOpOperandsMatchingBBargs() ==
consumer.getDpsInputOperands()) {
for (auto init : consumer.getDpsInits()) {
Type bbType = isa<ShapedType>(init.getType())
? cast<ShapedType>(init.getType()).getElementType()
: init.getType();
consumerBlock.addArgument(bbType, consumer.getLoc());
}
}
OpBuilder::InsertionGuard guard(rewriter);
Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
@Groverkss (Member), Nov 5, 2025:

This is wrong. Please use interface methods to do it correctly:

/*methodName=*/"getRegionBuilder",

or

@srcarroll (Contributor, Author):

getRegionBuilder is wrong since that just gets a function ref, but I will use getBlock. Thanks for pointing it out.

@srcarroll (Contributor, Author):

I'm actually unsure how to apply this suggestion. The other one above makes sense since I'm just getting blocks that have already been created. However, this one is creating a block in the op's region, and I don't see an interface method for getting a region. If region 0 isn't guaranteed, shouldn't the interface have a method for that?

Anyway, fusedOp doesn't have to be a LinalgOp for these initial changes since the function that calls this one explicitly creates a GenericOp (see https://github.com/llvm/llvm-project/pull/144922/files#diff-a7543973103a3f3abb605911ca6d141dc4ffd4782b2bc0ad57890d11ab72e2c1R422). So it's probably better to just declare fusedOp as GenericOp for this function and revert this line. Any thoughts?

@rengolin had a discussion back when this PR first went up about generating named ops post fusion when possible. I think it was agreed that it makes sense to leave this as a TODO (see #144922 (comment)). So when we get there we can revisit how to do this, unless there's an obvious solution now.

Member:

> Anyway, fusedOp doesn't have to be a LinalgOp for these initial changes since the function that calls this one explicitly creates a GenericOp (see https://github.com/llvm/llvm-project/pull/144922/files#diff-a7543973103a3f3abb605911ca6d141dc4ffd4782b2bc0ad57890d11ab72e2c1R422). So it's probably better to just declare fusedOp as GenericOp for this function and revert this line. Any thoughts?

Probably better to declare it as a GenericOp, yes, since the transformation always (today) returns a GenericOp anyway.

> @rengolin had a discussion back when this PR first went up about generating named ops post fusion when possible. I think it was agreed that it makes sense to leave this as a TODO (see #144922 (comment)). So when we get there we can revisit how to do this, unless there's an obvious solution now.

I think that is a different problem. My main concern is expecting something from an interface when the interface doesn't guarantee it.

@srcarroll (Contributor, Author):

It is a separate problem. I just meant that when we make the changes to fuse into named ops (when possible), LinalgOp might be reintroduced here and we would run into this issue again. But more refactoring would probably need to happen anyway, so we wouldn't necessarily hit this issue again here specifically.

IRMapping mapper;

// 2. Add an index operation for every fused loop dimension and use the
@@ -330,7 +354,7 @@ static void generateFusedElementwiseOpRegion(
rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

// Sanity checks.
assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
"Ill-formed GenericOp region");
}

@@ -340,8 +364,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
assert(areElementwiseOpsFusable(fusedOperand) &&
"expected elementwise operation pre-conditions to pass");
auto producerResult = cast<OpResult>(fusedOperand->get());
auto producer = cast<GenericOp>(producerResult.getOwner());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
auto producer = cast<LinalgOp>(producerResult.getOwner());
auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
@@ -418,10 +442,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
// Generate the fused op.
auto fusedOp = rewriter.create<GenericOp>(
consumer.getLoc(), fusedResultTypes, fusedInputOperands,
fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumer.getIteratorTypes(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
fusedOutputOperands, fusedIndexMaps, consumer.getIteratorTypesArray());
if (!fusedOp.getShapesToLoopsMap()) {
// Fused op has invalid indexing maps. Typically this means something is off
// in the input, but going ahead here would result in verification errors.
@@ -460,14 +481,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,

namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
class FuseElementwiseOps : public OpInterfaceRewritePattern<LinalgOp> {
public:
FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit),
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlFn(std::move(fun)) {}

LogicalResult matchAndRewrite(GenericOp genericOp,
LogicalResult matchAndRewrite(LinalgOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand &opOperand : genericOp->getOpOperands()) {
@@ -494,7 +515,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
rewriter.eraseOp(genericOp);
return success();
}
return failure();
return rewriter.notifyMatchFailure(genericOp, "no fusable operands");
}

private:
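For illustration, here is a minimal standalone sketch (not part of the diff) of the asymmetry that the block-argument generalization in generateFusedElementwiseOpRegion above works around: the region of a linalg.map only carries block arguments for its inputs, while the equivalent linalg.generic also carries one for the init operand. The function names below are illustrative only.

func.func @map_region_without_init_bbarg(%x: tensor<8xf32>) -> tensor<8xf32> {
  %init = tensor.empty() : tensor<8xf32>
  // One block argument per input; no block argument for %init.
  %r = linalg.map ins(%x : tensor<8xf32>) outs(%init : tensor<8xf32>)
    (%in: f32) {
      %s = math.sqrt %in : f32
      linalg.yield %s : f32
    }
  return %r : tensor<8xf32>
}

#id = affine_map<(d0) -> (d0)>
func.func @generic_region_with_init_bbarg(%x: tensor<8xf32>) -> tensor<8xf32> {
  %init = tensor.empty() : tensor<8xf32>
  // The generalized form has an extra block argument (%out) for the init.
  %r = linalg.generic {indexing_maps = [#id, #id], iterator_types = ["parallel"]}
      ins(%x : tensor<8xf32>) outs(%init : tensor<8xf32>) {
    ^bb0(%in: f32, %out: f32):
      %s = math.sqrt %in : f32
      linalg.yield %s : f32
  } -> tensor<8xf32>
  return %r : tensor<8xf32>
}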
21 changes: 21 additions & 0 deletions mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1014,3 +1014,24 @@ module {
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
// CHECK: linalg.yield %[[T3]] : f32
// CHECK: return %[[GENERIC]]

// -----

func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
%fill = tensor.empty() : tensor<8xf32>
%add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
%mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
return %mapped_65 : tensor<8xf32>
}

// CHECK-LABEL: func @map_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
// CHECK-NEXT: linalg.yield %[[SQRT]]
// CHECK-NOT: linalg.generic
54 changes: 54 additions & 0 deletions mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -59,3 +59,57 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK-NOT: linalg.generic

// -----

func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
%fill = tensor.empty() : tensor<8xf32>
%add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
%mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
return %mapped_65 : tensor<8xf32>
}

// CHECK-LABEL: func @map_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
Member:

Maybe not for this PR, but we have discussed keeping the named ops for as long as possible. Here, since they're both maps, we could fuse into a map still. Technically, they're the same (as discussed in the forum), but if I have a chain of matches and fusers, I'd have to match against all possible representations.

@srcarroll (Contributor, Author):

Yeah, I mentioned in a comment that I wanted to try to do that, but I don't know of a clean way to do it yet. I've done something like that before by just using clone and modifying as a way of generalizing a transform from a named op to that same named op. It seems more complicated here, but maybe not.

@srcarroll (Contributor, Author):

Also that wouldn't help with elementwise+elementwise -> map anyway

Member:

> Also that wouldn't help with elementwise+elementwise -> map anyway

Right, the idea is that ew + ew -> map is still better than to generic. So we only walk up the tree when needed (assuming generic -> map -> ew is the branch we're walking).

@srcarroll (Contributor, Author), Jun 21, 2025:

One thing to note: ew + ew -> map only works if the maps for both elementwise ops are same-rank identities on all operands, since the indexing maps a linalg.map can express are limited.
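To make that restriction concrete, here is a hypothetical example (not one of the tests added in this PR): the producer below broadcasts a 0-d tensor, so the operand coming from it needs a non-identity indexing map in the fused op; the fusion result therefore cannot be expressed as a linalg.map and has to be a linalg.generic.

#id = affine_map<(d0) -> (d0)>
#bcast = affine_map<(d0) -> ()>
func.func @bcast_then_map(%scalar: tensor<f32>, %v: tensor<8xf32>) -> tensor<8xf32> {
  %init = tensor.empty() : tensor<8xf32>
  // Producer broadcasts a 0-d tensor across d0; its indexing map is not the identity.
  %b = linalg.generic {indexing_maps = [#bcast, #id], iterator_types = ["parallel"]}
      ins(%scalar : tensor<f32>) outs(%init : tensor<8xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
  } -> tensor<8xf32>
  // Consumer is a map with implicit identity maps; the fused op still needs
  // the broadcast map for %scalar, so it can only be a linalg.generic.
  %r = linalg.map {arith.addf} ins(%b, %v : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xf32>)
  return %r : tensor<8xf32>
}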

// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
// CHECK-NEXT: linalg.yield %[[SQRT]]
// CHECK-NOT: linalg.map

// -----

func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
%init = tensor.empty() : tensor<8xi1>
%initf = tensor.empty() : tensor<8xf32>
%0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
%1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
%2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
(%in0 : f32, %in1 : f32) {
%cmp = arith.cmpf olt, %in0, %in1 : f32
linalg.yield %cmp : i1
}
%3 = linalg.map { arith.select } ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>)
return %3 : tensor<8xf32>
}

// CHECK-LABEL: func @map_ops_mixed_types
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[EXP0:.*]] = math.exp %[[IN1]]
// CHECK-NEXT: %[[SQRT0:.*]] = math.sqrt %[[IN0]]
// CHECK-NEXT: %[[EXP1:.*]] = math.exp %[[IN1]]
// CHECK-NEXT: %[[SQRT1:.*]] = math.sqrt %[[IN0]]
@srcarroll (Contributor, Author):

I think the duplicate computations are an old artifact. They do go away with CSE, but let me know if this is something that should be looked at.

Member:

This is indeed odd. Looks like a bug in the fuser. Could be related to the map vs generic issue you've seen above.

@srcarroll (Contributor, Author):

I did make a generic version of this and ran the old version of the pass, and got the same results, confirming it's a pre-existing issue.

@srcarroll (Contributor, Author):

Here's the generic version:

#map = affine_map<(d0)->(d0)>
func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg3: tensor<8xf32>) -> tensor<8xf32> {
  %init = tensor.empty() : tensor<8xi1>
  %initf = tensor.empty() : tensor<8xf32>
  %0 = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel"]} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>) {
    ^bb0(%in0 : f32, %out : f32):
        %sqrt = math.sqrt %in0 : f32
        linalg.yield %sqrt : f32
    } -> tensor<8xf32>
  %1 = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel"]} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>) {
    ^bb0(%in0 : f32, %out : f32):
        %sqrt = math.exp %in0 : f32
        linalg.yield %sqrt : f32
    } -> tensor<8xf32>
  %2 = linalg.generic {
      indexing_maps = [#map, #map, #map],
      iterator_types = ["parallel"]} ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xi1>) {
    ^bb0(%in0 : f32, %in1 : f32, %out: i1):
      %cmp = arith.cmpf olt, %in0, %in1 : f32
      linalg.yield %cmp : i1
  } -> tensor<8xi1>
  %3 = linalg.generic {
      indexing_maps = [#map, #map, #map, #map],
      iterator_types = ["parallel"]} ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>) { 
    ^bb0(%in0 : i1, %in1 : f32, %in2 : f32, %out: f32):
      %select = arith.select %in0, %in1, %in2 : f32
      linalg.yield %select : f32
  } -> tensor<8xf32>
  return %3 : tensor<8xf32>
}

// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[SQRT1]], %[[EXP1]]
// CHECK-NEXT: %[[RES:.*]] = arith.select %[[CMP]], %[[SQRT0]], %[[EXP0]]
// CHECK-NEXT: linalg.yield %[[RES]]
// CHECK-NOT: linalg.map
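
For reference, a rough hand-written sketch of the fused op these CHECK lines describe (SSA names and attribute spelling are illustrative, not taken from actual pass output); the duplicated exp/sqrt pairs are the pre-existing artifact discussed in the thread above and disappear after a follow-up CSE.

#id = affine_map<(d0) -> (d0)>
func.func @map_ops_mixed_types_fused(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
  %initf = tensor.empty() : tensor<8xf32>
  %fused = linalg.generic {indexing_maps = [#id, #id, #id], iterator_types = ["parallel"]}
      ins(%arg0, %arg1 : tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>) {
    ^bb0(%in0: f32, %in1: f32, %out: f32):
      // Duplicated computations, removable by CSE.
      %exp0 = math.exp %in1 : f32
      %sqrt0 = math.sqrt %in0 : f32
      %exp1 = math.exp %in1 : f32
      %sqrt1 = math.sqrt %in0 : f32
      %cmp = arith.cmpf olt, %sqrt1, %exp1 : f32
      %res = arith.select %cmp, %sqrt0, %exp0 : f32
      linalg.yield %res : f32
  } -> tensor<8xf32>
  return %fused : tensor<8xf32>
}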
