From c76a8ccd542376b2cf00e4fbcc1da3c38c1a1f8e Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 19 Jun 2025 11:02:38 -0500
Subject: [PATCH 01/19] Make fusion work on any LinalgOp

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  4 +-
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 42 ++++++++++---------
 2 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c202..0420edba2b300 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -511,8 +511,8 @@ fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
 /// * There is a chance that the implementation of the transformation does not
 /// agree with the result of this method. This function gives a prediction based
 /// on an optimized fusion.
-llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
-                                                     GenericOp consumer,
+llvm::SmallDenseSet<int> getPreservedProducerResults(LinalgOp producer,
+                                                     LinalgOp consumer,
                                                      OpOperand *fusedOperand);
 
 /// Try to peel and canonicalize loop `op` and return the new result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3a57f368d4425..498563e605fef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -75,11 +75,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
 // of the fused producer & consumer after the fusion can still compute the
 // bounds of the op.
 static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
-    GenericOp producer, GenericOp consumer,
+    LinalgOp producer, LinalgOp consumer,
     ArrayRef<OpOperand *> opOperandsToIgnore) {
   SmallVector<AffineMap> indexingMaps;
 
-  SmallVector<GenericOp> ops = {producer, consumer};
+  SmallVector<LinalgOp> ops = {producer, consumer};
   for (auto &op : ops) {
     for (auto &opOperand : op->getOpOperands()) {
       if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
@@ -108,7 +108,7 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
 /// agree with the result of this method. This function gives a prediction based
 /// on an optimized fusion.
 llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
-    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
+    LinalgOp producer, LinalgOp consumer, OpOperand *fusedOperand) {
   llvm::SmallDenseSet<int> preservedProducerResults;
   llvm::SmallVector<OpOperand *> opOperandsToIgnore;
 
@@ -138,8 +138,8 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
   if (!fusedOperand)
    return false;
 
-  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
-  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+  auto producer = fusedOperand->get().getDefiningOp<LinalgOp>();
+  auto consumer = dyn_cast<LinalgOp>(fusedOperand->getOwner());
 
   // Check producer and consumer are generic ops.
   if (!producer || !consumer)
@@ -213,16 +213,16 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
 static void generateFusedElementwiseOpRegion(
-    RewriterBase &rewriter, GenericOp fusedOp,
+    RewriterBase &rewriter, LinalgOp fusedOp,
     AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
     unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
-  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
-  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+  auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
+  auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
   Block &producerBlock = producer->getRegion(0).front();
   Block &consumerBlock = consumer->getRegion(0).front();
   OpBuilder::InsertionGuard guard(rewriter);
-  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
+  Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
   IRMapping mapper;
 
   // 2. Add an index operation for every fused loop dimension and use the
@@ -329,7 +329,7 @@ static void generateFusedElementwiseOpRegion(
   rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
 
   // Sanity checks.
-  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
+  assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
          "Ill-formed GenericOp region");
 }
 
@@ -339,8 +339,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   assert(areElementwiseOpsFusable(fusedOperand) &&
          "expected elementwise operation pre-conditions to pass");
   auto producerResult = cast<OpResult>(fusedOperand->get());
-  auto producer = cast<GenericOp>(producerResult.getOwner());
-  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+  auto producer = cast<LinalgOp>(producerResult.getOwner());
+  auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // TODO: allow fusing the producer of an output operand.
   assert(consumer.isDpsInput(fusedOperand) &&
          "expected producer of input operand");
@@ -415,12 +415,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   }
 
   // Generate the fused op.
+  // auto fusedOp = cloneWithoutRegions(rewriter, consumer,
+  //                                    fusedResultTypes, fusedInputOperands);
+  // fusedOp.setIndexingMapsAttr(idxMap);
+  // fusedOp.setIteratorTypesAttr(itTp);
   auto fusedOp = rewriter.create<GenericOp>(
       consumer.getLoc(), fusedResultTypes, fusedInputOperands,
-      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-      consumer.getIteratorTypes(),
-      /*doc=*/nullptr,
-      /*library_call=*/nullptr);
+      fusedOutputOperands, fusedIndexMaps,
+      consumer.getIteratorTypesArray());
   if (!fusedOp.getShapesToLoopsMap()) {
     // Fused op has invalid indexing maps. Typically this means something is off
     // in the input, but going ahead here would result in verification errors.
@@ -459,14 +461,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
 
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
-class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
+class FuseElementwiseOps : public OpInterfaceRewritePattern<LinalgOp> {
 public:
   FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                      PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
         controlFn(std::move(fun)) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
     for (OpOperand &opOperand : genericOp->getOpOperands()) {
@@ -493,7 +495,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
       rewriter.eraseOp(genericOp);
       return success();
     }
-    return failure();
+    return rewriter.notifyMatchFailure(genericOp, "no fusable operands");
   }
 
 private:

From 20b25f3b4b75a67fcadb94720fb13b915ce1bc29 Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 19 Jun 2025 11:35:37 -0500
Subject: [PATCH 02/19] format and add test

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 12 ++++--------
 .../Dialect/Linalg/fusion-elementwise.mlir    | 21 +++++++++++++++++++
 2 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 0b5e3d1b123b3..688244f44cbe7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -109,8 +109,9 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
 /// * There is a chance that the implementation of the transformation does not
 /// agree with the result of this method. This function gives a prediction based
 /// on an optimized fusion.
-llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
-    LinalgOp producer, LinalgOp consumer, OpOperand *fusedOperand) {
+llvm::SmallDenseSet<int>
+mlir::linalg::getPreservedProducerResults(LinalgOp producer, LinalgOp consumer,
+                                          OpOperand *fusedOperand) {
   llvm::SmallDenseSet<int> preservedProducerResults;
   llvm::SmallVector<OpOperand *> opOperandsToIgnore;
 
@@ -416,14 +417,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   }
 
   // Generate the fused op.
-  // auto fusedOp = cloneWithoutRegions(rewriter, consumer,
-  //                                    fusedResultTypes, fusedInputOperands);
-  // fusedOp.setIndexingMapsAttr(idxMap);
-  // fusedOp.setIteratorTypesAttr(itTp);
   auto fusedOp = rewriter.create<GenericOp>(
       consumer.getLoc(), fusedResultTypes, fusedInputOperands,
-      fusedOutputOperands, fusedIndexMaps,
-      consumer.getIteratorTypesArray());
+      fusedOutputOperands, fusedIndexMaps, consumer.getIteratorTypesArray());
   if (!fusedOp.getShapesToLoopsMap()) {
     // Fused op has invalid indexing maps. Typically this means something is off
     // in the input, but going ahead here would result in verification errors.
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir index bd9977f1410b9..db24d6d5f027a 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir @@ -59,3 +59,24 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> // CHECK: %[[FUSED_OP:.+]] = linalg.generic // CHECK-SAME: outs(%[[EMPTY]] : // CHECK-NOT: linalg.generic + +// ----- + +func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> { + %fill = tensor.empty() : tensor<8xf32> + %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>) + %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>) + return %mapped_65 : tensor<8xf32> +} + +// CHECK-LABEL: func @map_ops +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32> +// CHECK: %[[FUSED_OP:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.}}) outs(%[[EMPTY]] : +// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] +// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]] +// CHECK-NEXT: linalg.yield %[[SQRT]] +// CHECK-NOT: linalg.generic From 8e471a750a962feea17d99c27bf2bdb17a991ad1 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 19 Jun 2025 13:23:22 -0500 Subject: [PATCH 03/19] fix typo in test --- mlir/test/Dialect/Linalg/fusion-elementwise.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir index db24d6d5f027a..9b5f3d12f3d21 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir @@ -74,7 +74,7 @@ func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> { // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32> // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32> // CHECK: %[[FUSED_OP:.+]] = linalg.generic -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.}}) outs(%[[EMPTY]] : +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] : // CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): // CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] // CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]] From d723913f901841e3f8b6ee7ee4b71ec2e66e30ab Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 19 Jun 2025 13:47:52 -0500 Subject: [PATCH 04/19] add same test for other fusion pass -linalg-fuse-elementwise-ops --- .../Linalg/fusion-elementwise-ops.mlir | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir index 66fc55fadf8fa..b581567cf57a7 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -1014,3 +1014,24 @@ module { // CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]] // CHECK: linalg.yield %[[T3]] : f32 // CHECK: return %[[GENERIC]] + +// ----- + +func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> { + %fill = tensor.empty() : tensor<8xf32> + %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>) + %mapped_65 = linalg.map { math.sqrt } ins(%add : 
tensor<8xf32>) outs(%fill : tensor<8xf32>)
+  return %mapped_65 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT:   %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT:   linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.generic

From 5280b873e345c7976b8deee5f01cdba354d6df28 Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 19 Jun 2025 16:08:02 -0500
Subject: [PATCH 05/19] fix bug with no output bb args and add test

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 23 ++++++++++++++++++
 .../Dialect/Linalg/fusion-elementwise.mlir    | 35 ++++++++++++++++++-
 2 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 688244f44cbe7..fc435b47f5977 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -222,8 +222,31 @@ static void generateFusedElementwiseOpRegion(
   auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
   auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
+
+  // Since some ops, like `linalg.map`, do not have block arguments for init operands
+  // then we first "generalize" the block by adding arguments for init operands when
+  // they aren't present. We detect this case by checking if
+  // `getOpOperandsMatchingBBargs() == getDpsInputOperands();
   Block &producerBlock = producer->getRegion(0).front();
+  if (producer.getOpOperandsMatchingBBargs() ==
+      producer.getDpsInputOperands()) {
+    for (auto init : producer.getDpsInits()) {
+      Type bbType = isa<ShapedType>(init.getType())
+                        ? cast<ShapedType>(init.getType()).getElementType()
+                        : init.getType();
+      producerBlock.addArgument(bbType, producer.getLoc());
+    }
+  }
   Block &consumerBlock = consumer->getRegion(0).front();
+  if (consumer.getOpOperandsMatchingBBargs() ==
+      consumer.getDpsInputOperands()) {
+    for (auto init : consumer.getDpsInits()) {
+      Type bbType = isa<ShapedType>(init.getType())
+                        ? 
cast<ShapedType>(init.getType()).getElementType()
+                        : init.getType();
+      consumerBlock.addArgument(bbType, consumer.getLoc());
+    }
+  }
   OpBuilder::InsertionGuard guard(rewriter);
   Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
   IRMapping mapper;
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 9b5f3d12f3d21..18ca8b42fa79c 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -79,4 +79,37 @@ func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
 // CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
 // CHECK-NEXT:   %[[SQRT:.*]] = math.sqrt %[[ADD]]
 // CHECK-NEXT:   linalg.yield %[[SQRT]]
-// CHECK-NOT: linalg.generic
+// CHECK-NOT: linalg.map
+
+// -----
+
+func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+  %init = tensor.empty() : tensor<8xi1>
+  %initf = tensor.empty() : tensor<8xf32>
+  %0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  %1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  %2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
+    (%in0 : f32, %in1 : f32) {
+      %cmp = arith.cmpf olt, %in0, %in1 : f32
+      linalg.yield %cmp : i1
+    }
+  %3 = linalg.map { arith.select } ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  return %3 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops_mixed_types
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   %[[EXP0:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT:   %[[SQRT0:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT:   %[[EXP1:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT:   %[[SQRT1:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT:   %[[CMP:.*]] = arith.cmpf olt, %[[SQRT1]], %[[EXP1]]
+// CHECK-NEXT:   %[[RES:.*]] = arith.select %[[CMP]], %[[SQRT0]], %[[EXP0]]
+// CHECK-NEXT:   linalg.yield %[[RES]]
+// CHECK-NOT: linalg.map
+

From c2f52bc4154b62281bfcd8521154faf81e04c1f1 Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 19 Jun 2025 18:11:45 -0500
Subject: [PATCH 06/19] add linalg.elementwise test

---
 mlir/test/Dialect/Linalg/fusion-elementwise.mlir | 28 ++++++++++++++++++++++----
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 18ca8b42fa79c..2f9011cd5e52b 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -65,8 +65,8 @@
 func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
   %fill = tensor.empty() : tensor<8xf32>
   %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
-  %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
-  return %mapped_65 : tensor<8xf32>
+  %sqrt = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+  return %sqrt : tensor<8xf32>
 }
@@ -113,3 +113,27 @@ func.func @map_ops_mixed_types(%arg0: 
tensor<8xf32>, %arg1: tensor<8xf32>) -> te
 // CHECK-NEXT:   linalg.yield %[[RES]]
 // CHECK-NOT: linalg.map
+
+// -----
+
+func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+  %fill = tensor.empty() : tensor<8xf32>
+  %add = linalg.elementwise
+    kind=#linalg.elementwise_kind<add>
+    ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>) -> tensor<8xf32>
+  %wqrt = linalg.elementwise
+    kind=#linalg.elementwise_kind<sqrt>
+    ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>) -> tensor<8xf32>
+  return %wqrt : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @elementwise_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT:   %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT:   linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.map

From 8d2e8e0be55a1451e8b9774dddf9199158c98b2d Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 19 Jun 2025 18:13:22 -0500
Subject: [PATCH 07/19] fix formatting

---
 mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index fc435b47f5977..c1fc003d3f05d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -223,10 +223,10 @@ static void generateFusedElementwiseOpRegion(
   auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
 
-  // Since some ops, like `linalg.map`, do not have block arguments for init operands
-  // then we first "generalize" the block by adding arguments for init operands when
-  // they aren't present. We detect this case by checking if
-  // `getOpOperandsMatchingBBargs() == getDpsInputOperands();
+  // Since some ops, like `linalg.map`, do not have block arguments for init
+  // operands then we first "generalize" the block by adding arguments for init
+  // operands when they aren't present. We detect this case by checking if
+  // `getOpOperandsMatchingBBargs() == getDpsInputOperands()
   Block &producerBlock = producer->getRegion(0).front();
   if (producer.getOpOperandsMatchingBBargs() ==
       producer.getDpsInputOperands()) {

From 58582bfd75576e2ea089949207e25727cea7ca69 Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 19 Jun 2025 18:23:33 -0500
Subject: [PATCH 08/19] use getElementTypeOrSelf for cleanup

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 18 ++++++------------
 1 file changed, 6 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index c1fc003d3f05d..6ec13e33055ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -230,22 +230,16 @@ static void generateFusedElementwiseOpRegion(
   Block &producerBlock = producer->getRegion(0).front();
   if (producer.getOpOperandsMatchingBBargs() ==
       producer.getDpsInputOperands()) {
-    for (auto init : producer.getDpsInits()) {
-      Type bbType = isa<ShapedType>(init.getType())
-                        ? 
cast<ShapedType>(init.getType()).getElementType()
-                        : init.getType();
-      producerBlock.addArgument(bbType, producer.getLoc());
-    }
+    for (auto init : producer.getDpsInits())
+      producerBlock.addArgument(getElementTypeOrSelf(init.getType()),
+                                producer.getLoc());
   }
   Block &consumerBlock = consumer->getRegion(0).front();
   if (consumer.getOpOperandsMatchingBBargs() ==
       consumer.getDpsInputOperands()) {
-    for (auto init : consumer.getDpsInits()) {
-      Type bbType = isa<ShapedType>(init.getType())
-                        ? cast<ShapedType>(init.getType()).getElementType()
-                        : init.getType();
-      consumerBlock.addArgument(bbType, consumer.getLoc());
-    }
+    for (auto init : consumer.getDpsInits())
+      consumerBlock.addArgument(getElementTypeOrSelf(init.getType()),
+                                consumer.getLoc());
   }
   OpBuilder::InsertionGuard guard(rewriter);
   Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
   IRMapping mapper;

From cf67ab67bf4da9f8c65137fd627b6ba2d8da0ebb Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 19 Jun 2025 18:30:12 -0500
Subject: [PATCH 09/19] switch elementwise test to broadcast version

---
 .../Dialect/Linalg/fusion-elementwise.mlir    | 20 +++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 2f9011cd5e52b..575f21b8f09f9 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -115,21 +115,25 @@ func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> te
 
 // -----
 
-func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
-  %fill = tensor.empty() : tensor<8xf32>
+#identity = affine_map<(d0, d1) -> (d0, d1)>
+#bcast = affine_map<(d0, d1) -> (d0)>
+func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tensor<8x10xf32> {
+  %fill = tensor.empty() : tensor<8x10xf32>
   %add = linalg.elementwise
     kind=#linalg.elementwise_kind<add>
-    ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>) -> tensor<8xf32>
-  %wqrt = linalg.elementwise
+    indexing_maps = [#bcast, #identity, #identity]
+    ins(%in1, %in2: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) -> tensor<8x10xf32>
+  %sqrt = linalg.elementwise
     kind=#linalg.elementwise_kind<sqrt>
-    ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>) -> tensor<8xf32>
-  return %wqrt : tensor<8xf32>
+    indexing_maps = [#identity, #identity]
+    ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>) -> tensor<8x10xf32>
+  return %sqrt : tensor<8x10xf32>
 }

From 7d402c1f75a09d5b9cda01fd49ae287928a47364 Mon Sep 17 00:00:00 2001
From: Sam
Date: Fri, 20 Jun 2025 14:23:35 -0500
Subject: [PATCH 10/19] remove block args that were added (hacky)

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 22 ++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6ec13e33055ce..c3b5765a5c4ad 100644
--- 
a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -226,17 +226,21 @@ static void generateFusedElementwiseOpRegion( // Since some ops, like `linalg.map`, do not have block arguments for init // operands then we first "generalize" the block by adding arguments for init // operands when they aren't present. We detect this case by checking if - // `getOpOperandsMatchingBBargs() == getDpsInputOperands() + // `getOpOperandsMatchingBBargs() == getDpsInputOperands()`. + // TODO: This is hacky and should not be merged. Keeping for now for testing + // purposes in the meantime, but need a better way Block &producerBlock = producer->getRegion(0).front(); - if (producer.getOpOperandsMatchingBBargs() == - producer.getDpsInputOperands()) { + bool addOutputArgsProducer = + producer.getOpOperandsMatchingBBargs() == producer.getDpsInputOperands(); + if (addOutputArgsProducer) { for (auto init : producer.getDpsInits()) producerBlock.addArgument(getElementTypeOrSelf(init.getType()), producer.getLoc()); } Block &consumerBlock = consumer->getRegion(0).front(); - if (consumer.getOpOperandsMatchingBBargs() == - consumer.getDpsInputOperands()) { + bool addOutputArgsConsumer = + consumer.getOpOperandsMatchingBBargs() == consumer.getDpsInputOperands(); + if (addOutputArgsConsumer) { for (auto init : consumer.getDpsInits()) consumerBlock.addArgument(getElementTypeOrSelf(init.getType()), consumer.getLoc()); @@ -350,6 +354,14 @@ static void generateFusedElementwiseOpRegion( // Sanity checks. assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() && "Ill-formed GenericOp region"); + // Erase added args in case that the ops are still live after fusion. + // TODO: Remove along with hacky code above. 
+  if (addOutputArgsProducer)
+    producerBlock.eraseArguments(producer.getNumDpsInputs(),
+                                 producer.getNumDpsInits());
+  if (addOutputArgsConsumer)
+    consumerBlock.eraseArguments(consumer.getNumDpsInputs(),
+                                 consumer.getNumDpsInits());
 }
 
 FailureOr<ElementwiseOpFusionResult>

From b1d15b2822953882376661a1b66ec8adc5cc01a1 Mon Sep 17 00:00:00 2001
From: Sam
Date: Sat, 21 Jun 2025 11:19:06 -0500
Subject: [PATCH 11/19] add requested tests

---
 .../Dialect/Linalg/fusion-elementwise.mlir    | 63 +++++++++++++++++++
 1 file changed, 63 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 575f21b8f09f9..8aa6974d5f0e4 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -141,3 +141,66 @@ func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tenso
 // CHECK-NEXT:   %[[SQRT:.*]] = math.sqrt %[[ADD]]
 // CHECK-NEXT:   linalg.yield %[[SQRT]]
 // CHECK-NOT: linalg.map
+
+// -----
+
+func.func @map_multi_ops(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
+  %fill = tensor.empty() : tensor<8xf32>
+  %add_exp = linalg.map ins(%arg0, %arg1: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+    (%in0 : f32, %in1 : f32) {
+      %add = arith.addf %in0, %in1 : f32
+      %exp = math.exp %add : f32
+      linalg.yield %exp : f32
+    }
+  %sqrt_mul = linalg.map ins(%add_exp, %arg2 : tensor<8xf32>, tensor<8xf32>) outs(%fill : tensor<8xf32>)
+    (%in0 : f32, %in1 : f32) {
+      %sqrt = math.sqrt %in0 : f32
+      %mul = arith.mulf %sqrt, %in1 : f32
+      linalg.yield %mul : f32
+    }
+  return %sqrt_mul : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_multi_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT:   %[[EXP:.*]] = math.exp %[[ADD]]
+// CHECK-NEXT:   %[[SQRT:.*]] = math.sqrt %[[EXP]]
+// CHECK-NEXT:   %[[MUL:.*]] = arith.mulf %[[SQRT]], %[[IN2]]
+// CHECK-NEXT:   linalg.yield %[[MUL]]
+// CHECK-NOT: linalg.map
+
+// -----
+
+#identity = affine_map<(d0, d1) -> (d0, d1)>
+#bcast = affine_map<(d0, d1) -> (d0)>
+func.func @map_generic_ops(%arg0: tensor<8xf32>, %arg1: tensor<8x10xf32>) -> tensor<8x10xf32> {
+  %fill = tensor.empty() : tensor<8x10xf32>
+  %add = linalg.generic
+      {indexing_maps = [#bcast, #identity, #identity], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) {
+    ^bb0(%in0: f32, %in1: f32, %out: f32):
+      %add = arith.addf %in0, %in1 : f32
+      linalg.yield %add : f32
+  } -> tensor<8x10xf32>
+  %sqrt = linalg.map { math.sqrt } ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>)
+  return %sqrt : tensor<8x10xf32>
+}
+
+// CHECK-LABEL: func @map_generic_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   %[[ADD:.*]] = 
arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT:   %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT:   linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.map

From 459bcb47feea8fa75771ec841eadcad96336a430 Mon Sep 17 00:00:00 2001
From: Sam
Date: Sat, 21 Jun 2025 11:24:29 -0500
Subject: [PATCH 12/19] add checks for nontrivial map cases

---
 mlir/test/Dialect/Linalg/fusion-elementwise.mlir       | 6 ++++++
 .../lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp | 7 ++++++-
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 8aa6974d5f0e4..d4b25eb4be691 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -130,11 +130,14 @@ func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tenso
   return %sqrt : tensor<8x10xf32>
 }
 
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
 // CHECK-LABEL: func @elementwise_ops
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
 // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
 // CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
 // CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
@@ -193,11 +196,14 @@ func.func @map_generic_ops(%arg0: tensor<8xf32>, %arg1: tensor<8x10xf32>) -> tens
   return %sqrt : tensor<8x10xf32>
 }
 
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
 // CHECK-LABEL: func @map_generic_ops
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
 // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
 // CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
 // CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index cb215197253bb..6b9abd34b7781 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -151,9 +151,14 @@ struct TestLinalgElementwiseFusion
     MLIRContext *context = &this->getContext();
     func::FuncOp funcOp = this->getOperation();
 
+    auto controlFn = [](OpOperand *operand) {
+      auto owner = cast<linalg::LinalgOp>(operand->getOwner());
+      auto producer = cast<linalg::LinalgOp>(operand->get().getDefiningOp());
+      return (linalg::isElementwise(owner) && linalg::isElementwise(producer)) && (!isa<linalg::GenericOp>(producer) && !isa<linalg::GenericOp>(owner));
+    };
     if (fuseGenericOps) {
       RewritePatternSet fusionPatterns(context);
-      auto controlFn = [](OpOperand *operand) { return true; };
+      // auto controlFn = [](OpOperand *operand) { return true; };
       linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
       if (failed(applyPatternsGreedily(funcOp.getBody(),
                                        std::move(fusionPatterns))))

From f7e164beebdc8195b6781c2a8ce8bc1bea7757cd Mon Sep 17 00:00:00 2001
From: Sam
Date: Sat, 21 Jun 2025 11:29:26 -0500
Subject: [PATCH 13/19] revert unintended change

---
 .../lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 6b9abd34b7781..cb215197253bb 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -151,14 +151,9 @@ struct TestLinalgElementwiseFusion
     MLIRContext *context = &this->getContext();
     func::FuncOp funcOp = this->getOperation();
 
-    auto controlFn = [](OpOperand *operand) {
-      auto owner = cast<linalg::LinalgOp>(operand->getOwner());
-      auto producer = cast<linalg::LinalgOp>(operand->get().getDefiningOp());
-      return (linalg::isElementwise(owner) && linalg::isElementwise(producer)) && (!isa<linalg::GenericOp>(producer) && !isa<linalg::GenericOp>(owner));
-    };
     if (fuseGenericOps) {
       RewritePatternSet fusionPatterns(context);
-      // auto controlFn = [](OpOperand *operand) { return true; };
+      auto controlFn = [](OpOperand *operand) { return true; };
       linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
       if (failed(applyPatternsGreedily(funcOp.getBody(),
                                        std::move(fusionPatterns))))

From 8a375bfc2422122d9cf07db585ef0d6a41322a40 Mon Sep 17 00:00:00 2001
From: Sam
Date: Fri, 1 Aug 2025 11:41:27 -0500
Subject: [PATCH 14/19] fix lit failure by using $-prefixed FileCheck variables

---
 mlir/test/Dialect/Linalg/fusion-elementwise.mlir | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index d4b25eb4be691..f712b396148fa 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -130,14 +130,14 @@ func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tenso
   return %sqrt : tensor<8x10xf32>
 }
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 // CHECK-LABEL: func @elementwise_ops
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
 // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
 // CHECK: %[[FUSED_OP:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP0]], #[[$MAP0]]]
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
 // CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
@@ -196,14 +196,14 @@ func.func @map_generic_ops(%arg0: tensor<8xf32>, %arg1: tensor<8x10xf32>) -> tens
   return %sqrt : tensor<8x10xf32>
 }
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
 // CHECK-LABEL: func @map_generic_ops
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
 // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
 // CHECK: %[[FUSED_OP:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP0]], #[[$MAP0]]]
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
 // CHECK-NEXT: 
^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
 // CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]

From ce33596b10295f7547d6cb442bbc9102fb49a7ae Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 30 Oct 2025 10:51:14 -0500
Subject: [PATCH 15/19] remove hack for linalg.map and update tests

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 28 -------------------
 .../Dialect/Linalg/fusion-elementwise.mlir    |  4 +--
 2 files changed, 2 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index e7d238dbbac4e..cfdcefc505f39 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -224,28 +224,8 @@ static void generateFusedElementwiseOpRegion(
   auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
 
-  // Since some ops, like `linalg.map`, do not have block arguments for init
-  // operands then we first "generalize" the block by adding arguments for init
-  // operands when they aren't present. We detect this case by checking if
-  // `getOpOperandsMatchingBBargs() == getDpsInputOperands()`.
-  // TODO: This is hacky and should not be merged. Keeping for now for testing
-  // purposes in the meantime, but need a better way
   Block &producerBlock = producer->getRegion(0).front();
-  bool addOutputArgsProducer =
-      producer.getOpOperandsMatchingBBargs() == producer.getDpsInputOperands();
-  if (addOutputArgsProducer) {
-    for (auto init : producer.getDpsInits())
-      producerBlock.addArgument(getElementTypeOrSelf(init.getType()),
-                                producer.getLoc());
-  }
   Block &consumerBlock = consumer->getRegion(0).front();
-  bool addOutputArgsConsumer =
-      consumer.getOpOperandsMatchingBBargs() == consumer.getDpsInputOperands();
-  if (addOutputArgsConsumer) {
-    for (auto init : consumer.getDpsInits())
-      consumerBlock.addArgument(getElementTypeOrSelf(init.getType()),
-                                consumer.getLoc());
-  }
   OpBuilder::InsertionGuard guard(rewriter);
   Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
   IRMapping mapper;
@@ -355,14 +335,6 @@ static void generateFusedElementwiseOpRegion(
   // Sanity checks.
   assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
          "Ill-formed GenericOp region");
-  // Erase added args in case that the ops are still live after fusion.
-  // TODO: Remove along with hacky code above.
-  if (addOutputArgsProducer)
-    producerBlock.eraseArguments(producer.getNumDpsInputs(),
-                                 producer.getNumDpsInits());
-  if (addOutputArgsConsumer)
-    consumerBlock.eraseArguments(consumer.getNumDpsInputs(),
-                                 consumer.getNumDpsInits());
 }
 
 FailureOr<ElementwiseOpFusionResult>
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index f712b396148fa..36180c4c5f5fa 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -89,7 +89,7 @@ func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> te
   %0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
   %1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
   %2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
-    (%in0 : f32, %in1 : f32) {
+    (%in0 : f32, %in1 : f32, %out : f32) {
      %cmp = arith.cmpf olt, %in0, %in1 : f32
      linalg.yield %cmp : i1
    }
@@ -150,7 +150,7 @@ func.func @map_multi_ops(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg2: tens
   %fill = tensor.empty() : tensor<8xf32>
   %add_exp = linalg.map ins(%arg0, %arg1: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
-    (%in0 : f32, %in1 : f32) {
+    (%in0 : f32, %in1 : f32, %out : f32) {
      %add = arith.addf %in0, %in1 : f32
      %exp = math.exp %add : f32
      linalg.yield %exp : f32

From 1eb34618fe52776d8433acf61528a3edb2b6e934 Mon Sep 17 00:00:00 2001
From: Sam
Date: Thu, 30 Oct 2025 11:33:21 -0500
Subject: [PATCH 16/19] fix tests again

---
 mlir/test/Dialect/Linalg/fusion-elementwise.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 36180c4c5f5fa..7946b555e7439 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -89,7 +89,7 @@ func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> te
   %0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
   %1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
   %2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
-    (%in0 : f32, %in1 : f32, %out : f32) {
+    (%in0 : f32, %in1 : f32, %out : i1) {
      %cmp = arith.cmpf olt, %in0, %in1 : f32
      linalg.yield %cmp : i1
    }
@@ -156,7 +156,7 @@ func.func @map_multi_ops(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg2: tens
    }
   %sqrt_mul = linalg.map ins(%add_exp, %arg2 : tensor<8xf32>, tensor<8xf32>) outs(%fill : tensor<8xf32>)
-    (%in0 : f32, %in1 : f32) {
+    (%in0 : f32, %in1 : f32, %out : f32) {
      %sqrt = math.sqrt %in0 : f32
      %mul = arith.mulf %sqrt, %in1 : f32
      linalg.yield %mul : f32

From 389a9c47106b8e9016250be59bdfb7f289195951 Mon Sep 17 00:00:00 2001
From: Sam
Date: Sun, 2 Nov 2025 08:50:49 -0600
Subject: [PATCH 17/19] add a couple tests with `linalg.matmul`

---
 .../Linalg/fusion-elementwise-ops.mlir | 47 ++++++++++++++++++++++++--
 1 file changed, 46 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 0d9f38389e4a9..817fd7e48a22a 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1038,9 +1038,54 @@ func.func @map_ops(%in1: 
tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> { // ----- +func.func @map_matmul(%in1: tensor<8x10xf32>, %in2: tensor<10x12xf32>) -> tensor<8x12xf32> { + %fill0 = tensor.empty() : tensor<8x10xf32> + %exp = linalg.map {math.exp} ins(%in1 : tensor<8x10xf32>) outs(%fill0: tensor<8x10xf32>) + %fill1 = tensor.empty() : tensor<8x12xf32> + %matmul = linalg.matmul ins(%exp, %in2 : tensor<8x10xf32>, tensor<10x12xf32>) outs(%fill1 : tensor<8x12xf32>) -> tensor<8x12xf32> + return %matmul : tensor<8x12xf32> +} + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-LABEL: func @map_matmul +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8x10xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<10x12xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x12xf32> +// CHECK: %[[FUSED_OP:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] : +// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: %[[EXP:.*]] = math.exp %[[IN0]] +// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[EXP]], %[[IN1]] +// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]] +// CHECK-NEXT: linalg.yield %[[ADD]] +// CHECK-NOT: linalg.generic + +// ----- + +func.func @matmul_map(%in1: tensor<8x10xf32>, %in2: tensor<10x12xf32>) -> tensor<8x12xf32> { + %fill1 = tensor.empty() : tensor<8x12xf32> + %matmul = linalg.matmul ins(%in1, %in2 : tensor<8x10xf32>, tensor<10x12xf32>) outs(%fill1 : tensor<8x12xf32>) -> tensor<8x12xf32> + %exp = linalg.map {math.exp} ins(%matmul : tensor<8x12xf32>) outs(%fill1: tensor<8x12xf32>) + + return %exp : tensor<8x12xf32> +} + +// Should not fuse +// CHECK-LABEL: func @matmul_map +// CHECK-NEXT: tensor.empty +// CHECK-NEXT: linalg.matmul +// CHECK-NEXT: linalg.map +// CHECK-NEXT: return + +// ----- + // In this test we expect the first two linalg.generic operations to be fused into one, but the third one (the matmul) to remain separate. // The reason is that when the pattern is applied the 1st time, the fusion of the first two operations produces a fused operation with -// an additional result and ana dditional output indexing map that is not a permutation / not invertible. +// an additional result and an additional output indexing map that is not a permutation / not invertible. // The fused op will still produce also the original result (and its output indexing map), which is preserved because the new indexing map // is not invertible. Thus the fused op will have 2 results, but only the 2nd one will be used by the following matmul op as an input argument. // When trying to apply the fusion pattern again, the matmul op won't be fused because the operand to fuse was not produced with an invertible indexing map. 
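A note for orientation on the comment above (a sketch, not lines from the test files): whether fusion through a given result can proceed comes down to whether that result's output indexing map can be inverted to relate the consumer's loops to the producer's. A map that permutes the loop dimensions can be inverted; one that drops a loop dimension cannot, which is the situation both the preserved-result case and the matmul case above reduce to:

    affine_map<(d0, d1) -> (d1, d0)>   // permutation: invertible, fusion can proceed
    affine_map<(d0, d1) -> (d0)>       // drops d1: not invertible, fusion is rejected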
From 8c906a0dc73cc6fd67d1fbdbfe451cbc2f0a97d4 Mon Sep 17 00:00:00 2001
From: Sam
Date: Mon, 3 Nov 2025 11:48:24 -0600
Subject: [PATCH 18/19] fix check-mlir error by using $-prefixed FileCheck
 variables

---
 mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 817fd7e48a22a..1202a3d198e42 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1046,15 +1046,15 @@ func.func @map_matmul(%in1: tensor<8x10xf32>, %in2: tensor<10x12xf32>) -> tensor
   return %matmul : tensor<8x12xf32>
 }
 
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func @map_matmul
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8x10xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<10x12xf32>
 // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x12xf32>
 // CHECK: %[[FUSED_OP:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
 // CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):

From 047266b3031d3e15bd773e785fa4fa1edf4e5cb7 Mon Sep 17 00:00:00 2001
From: Sam
Date: Wed, 5 Nov 2025 09:22:21 -0600
Subject: [PATCH 19/19] use proper getBlock interface methods

---
 mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index cfdcefc505f39..981c38c3a8dec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -217,15 +217,15 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
 static void generateFusedElementwiseOpRegion(
-    RewriterBase &rewriter, LinalgOp fusedOp,
+    RewriterBase &rewriter, GenericOp fusedOp,
     AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
     unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
   auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
   auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
-  Block &producerBlock = producer->getRegion(0).front();
-  Block &consumerBlock = consumer->getRegion(0).front();
+  Block &producerBlock = *producer.getBlock();
+  Block &consumerBlock = *consumer.getBlock();
   OpBuilder::InsertionGuard guard(rewriter);
   Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
   IRMapping mapper;
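To drive the patterns in this series outside of the lit tests, a minimal driver sketch is below. It mirrors the TestLinalgElementwiseFusion pass touched in PATCH 12 and PATCH 13; the wrapper function `runElementwiseFusion` is a made-up name for illustration, while `ControlFusionFn`, `populateElementwiseOpsFusionPatterns`, and `applyPatternsGreedily` are the real entry points that pass uses:

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    // Greedily fuse elementwise linalg ops inside `funcOp`.
    static LogicalResult runElementwiseFusion(func::FuncOp funcOp) {
      RewritePatternSet patterns(funcOp.getContext());
      // The control function can veto an individual fusion by inspecting the
      // producer/consumer pair reachable from `fusedOperand`; returning true
      // unconditionally fuses everything the patterns accept.
      linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
        return true;
      };
      linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);
      return applyPatternsGreedily(funcOp.getBody(), std::move(patterns));
    }

After this series, the same driver fuses chains of linalg.map and linalg.elementwise ops, not just linalg.generic, since FuseElementwiseOps now matches through the LinalgOp interface.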