diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 106f0d79d9792..a997502c34299 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -73,6 +73,17 @@ def ApplyTilingCanonicalizationPatternsOp : Op]> { + let description = [{ + Collects patterns to replace linalg.add when destination passing suffices + for achieving the sum. + }]; + + let assemblyFormat = "attr-dict"; +} + //===----------------------------------------------------------------------===// // BufferizeToAllocationOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 48e657cca96e3..cc12ed7cfa6b5 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1747,6 +1747,10 @@ void populateFoldReshapeOpsByCollapsingPatterns( void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn); +/// Pattern to replace `linalg.add` when destination passing on a contraction op +/// suffices for achieving the sum. +void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns); + /// Pattern to fuse a `tensor.pad` operation with the producer of its source, /// if the producer is a `linalg` operation with all parallel iterator types. void populateFuseTensorPadWithProducerLinalgOpPatterns( diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 46c8510f4ed51..3b7b367d3cf2d 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -248,6 +248,11 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns( linalg::populateLinalgTilingCanonicalizationPatterns(patterns); } +void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateFoldAddIntoDestPatterns(patterns); +} + //===----------------------------------------------------------------------===// // BufferizeToAllocationOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 47af392def94a..b3cd5537aad9b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms ElementwiseToLinalg.cpp EliminateEmptyTensors.cpp EraseUnusedOperandsAndResults.cpp + FoldAddIntoDest.cpp FusePadOpWithLinalgProducer.cpp Fusion.cpp Generalization.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp new file mode 100644 index 0000000000000..e940b0787043e --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp @@ -0,0 +1,150 @@ +//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + +using namespace mlir; + +// Determine whether the value is defined to be zero. +static bool isDefinedAsZero(Value val) { + if (!val) + return false; + + // Check whether val is a constant scalar / vector splat / tensor splat float + // or integer zero. + if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero())) + return true; + + return TypeSwitch(val.getDefiningOp()) + .Case([&](auto op) { + return op && op.getInputs().size() == 1 && + isDefinedAsZero(op.getInputs()[0]); + }) + .Default([&](auto) { return false; }); +} + +/// Replace a linalg.add with one operand the single user of a contraction, +/// which has a zero-filled, "identity-mapped" destination and is dominated by +/// the `other` operand, by the contraction with `other` as its dest. +/// +/// As an example, the following pseudo-code will be rewritten +/// %cst = arith.constant 0.000000e+00 +/// %empty = tensor.empty() +/// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type +/// %C = linalg.matmul ins(%A, %B) outs(%zeroed) +/// %empty2 = tensor.empty() +/// %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type +/// %F = linalg.matmul ins(%D, %E) outs(%zeroed2) +/// %out = linalg.add ins(%C, %F) outs(%empty) +/// to: +/// %cst = arith.constant 0.000000e+00 +/// %empty = tensor.empty() +/// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type +/// %C = linalg.matmul ins(%A, %B) outs(%zeroed) +/// %out = linalg.matmul ins(%D, %E) outs(%C) +/// +struct FoldAddIntoDest final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::AddOp addOp, + PatternRewriter &rewriter) const override { + // For now, pattern only applies on tensor types (memref support is TODO). + if (!addOp.hasPureTensorSemantics()) + return failure(); + + Value dominatingOperand = nullptr; + linalg::LinalgOp dominatedOp = nullptr; + { // We will forget about which operand was left or right after this block. + Value lhs = addOp.getInputs()[0]; + Value rhs = addOp.getInputs()[1]; + + // Can only put one of addOp's operands in the dest/out arg of the other's + // defining op based on suitable dominance. + // TODO: Can be generalized to move ops around as long as that still + // respects use-def chains and doesn't affect side-effects. + if (auto rhsOp = rhs.getDefiningOp()) { + DominanceInfo domInfo(rhsOp); + if (domInfo.properlyDominates(lhs, rhsOp)) { + dominatingOperand = lhs; + dominatedOp = rhsOp; + } + } + if (auto lhsOp = lhs.getDefiningOp()) { + DominanceInfo domInfo(lhsOp); + if (domInfo.properlyDominates(rhs, lhsOp)) { + dominatingOperand = rhs; + dominatedOp = lhsOp; + } + } + if (!dominatingOperand || !dominatedOp) + return failure(); + // NB: As linalg.add's generalisation ignores the out argument in its + // region there is no need to perform checks on addOp's out argument. + } + + // When dominated op is a contraction we know it accumulates on its out arg. + // E.g., AddOp is not a contraction and hence ignores its out arg's value. + // TODO: Generalize check to also pass in case of other LinalgOps that + // accumulate on their out arg but are not (binary) contraction ops. + auto dominatedDestOp = + dyn_cast((Operation *)dominatedOp); + if (dominatedOp->getNumResults() != 1 || + !linalg::isaContractionOpInterface(dominatedOp) || + (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1)) + return rewriter.notifyMatchFailure( + dominatedOp, "expected dominated op to be single-result " + "destination-passing contraction"); + + // To change the contraction's result, `addOp` must be its only user. + if (!dominatedOp->getResult(0).hasOneUse()) + return rewriter.notifyMatchFailure( + dominatedOp, + "expected linalg.add to be single user of contraction's result"); + + // As `dominatedOp` was already accumulating on its out argument, it is only + // safe to no longer use its current out arg when it is the additive ident. + auto *destOperand = dominatedDestOp.getDpsInitOperand(0); + if (!isDefinedAsZero(destOperand->get())) + return rewriter.notifyMatchFailure( + dominatedOp, "expected dominated op's dest to be additive zero"); + // TODO: If the other op is a contraction and has additive ident as dest, we + // can swap the dests and achieve the proper sum, given suitable dominance. + + // As an operand to `addOp`, `dominatingOperand` has an identity affine_map. + // Hence, we can only substitute `dominatingOperand` for the dest of the + // contraction when dest's indexing_map corresponds to an identity map + // w.r.t. just the dimensions of dest, i.e. is an ordered projection. + SmallVector indexMaps = dominatedOp.getIndexingMapsArray(); + int prevDimPos = -1; + for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) { + auto dim = dyn_cast(expr); + if (!dim || prevDimPos > static_cast(dim.getPosition())) + return rewriter.notifyMatchFailure( + dominatedOp, "expected index_map for contraction's dest to be an " + "ordered projection"); + prevDimPos = dim.getPosition(); + } + + // Replace the additive-ident, i.e. zero, out arg of the dominated op by the + // dominating summand. This makes the dominated op's result the sum of both + // of addOp's arguments - therefore we replace addOp and it uses by it. + rewriter.modifyOpInPlace( + dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); }); + rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0)); + return success(); + } +}; + +void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) { + // Replace linalg.add when destination passing suffices for achieving the sum. + patterns.add(patterns.getContext()); +} diff --git a/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir new file mode 100644 index 0000000000000..d8e92e40739dc --- /dev/null +++ b/mlir/test/Dialect/Linalg/fold-add-into-dest.mlir @@ -0,0 +1,329 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +!type = tensor<2048x2048xf32> +func.func @fold_add_on_two_matmuls(%arg0: !type, %arg1: !type) -> !type { + %0 = arith.constant dense<1.111111e+00> : !type + %cst = arith.constant 0.000000e+00 : f32 + %1 = tensor.empty() : !type + %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type + %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type + %4 = tensor.empty() : !type + %5 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type + %6 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%5 : !type) -> !type + %7 = linalg.add ins(%3, %6 : !type, !type) outs(%1 : !type) -> !type + return %7 : !type +} + +// CHECK-LABEL: func.func @fold_add_on_two_matmuls( +// CHECK-SAME: %[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +// CHECK-NEXT: %[[DENSE:.*]] = arith.constant dense<1.11 +// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 +// CHECK-NEXT: %[[EMPTY:.*]] = tensor.empty() +// CHECK-NEXT: %[[FILLED:.*]] = linalg.fill ins(%[[ZERO]] : {{.*}}) outs(%[[EMPTY]] : {{.*}}) +// CHECK-NEXT: %[[ACC:.+]] = linalg.matmul ins(%[[ARG0]], %[[DENSE]] : {{.*}}) outs(%[[FILLED]] : {{.*}}) +// CHECK-NEXT: %[[RES:.+]] = linalg.matmul ins(%[[ARG1]], %[[DENSE]] : {{.*}}) outs(%[[ACC]] : {{.*}}) +// CHECK-NOT: linalg.add +// CHECK-NEXT: return %[[RES]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.linalg.fold_add_into_dest + } : !transform.any_op + transform.yield + } +} +// ----- + +!type = tensor<2048x2048xf32> +func.func @expect_no_fold_of_add_as_orig_dest_not_additive_zero(%arg0: !type, %arg1: !type) -> !type { + %0 = arith.constant dense<1.111111e+00> : !type + %cst = arith.constant 0.000000e+00 : f32 + %1 = tensor.empty() : !type + %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type + %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type + %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%0 : !type) -> !type + %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type + return %5 : !type +} + +// CHECK-LABEL: func.func @expect_no_fold_of_add_as_orig_dest_not_additive_zero +// CHECK: linalg.fill +// CHECK-NEXT: linalg.matmul +// CHECK-NEXT: linalg.matmul +// CHECK-NEXT: linalg.add +// CHECK-NEXT: return + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.linalg.fold_add_into_dest + } : !transform.any_op + transform.yield + } +} + +// ----- + +!type = tensor<2048x2048xf32> +func.func @expect_no_fold_of_add_as_contraction_result_has_multiple_users(%arg0: !type, %arg1: !type) -> (!type, !type) { + %0 = arith.constant dense<1.111111e+00> : !type + %cst = arith.constant 0.000000e+00 : f32 + %1 = tensor.empty() : !type + %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type + %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type + %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%0 : !type) -> !type + %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type + %6 = linalg.mul ins(%4, %arg0 : !type, !type) outs(%1 : !type) -> !type + return %5, %6 : !type, !type +} + +// CHECK-LABEL: func.func @expect_no_fold_of_add_as_contraction_result_has_multiple_users +// CHECK: linalg.fill +// CHECK-NEXT: linalg.matmul +// CHECK-NEXT: linalg.matmul +// CHECK-NEXT: linalg.add +// CHECK-NEXT: linalg.mul +// CHECK-NEXT: return + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.linalg.fold_add_into_dest + } : !transform.any_op + transform.yield + } +} + +// ----- + +!type = tensor<2048x2048xf32> +func.func @fold_add_on_matmul_and_func_arg(%arg0: !type, %arg1: !type) -> !type { + %0 = arith.constant dense<1.111111e+00> : !type + %cst = arith.constant 0.000000e+00 : f32 + %1 = tensor.empty() : !type + %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type + %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type + %5 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type + return %5 : !type +} + +// CHECK-LABEL: func.func @fold_add_on_matmul_and_func_arg +// CHECK: %[[RES:.+]] = linalg.matmul +// CHECK-NOT: linalg.add +// CHECK-NEXT: return %[[RES]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.linalg.fold_add_into_dest + } : !transform.any_op + transform.yield + } +} + +// ----- + +!type = tensor<2048x2048xf32> +func.func @expect_no_fold_of_add_as_operands_do_not_dominate_each_other(%arg0: !type, %arg1: !type) -> !type { + %0 = arith.constant dense<1.111111e+00> : !type + %cst = arith.constant 0.000000e+00 : f32 + %1 = tensor.empty() : !type + %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type + %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type + %4 = linalg.add ins(%3, %3 : !type, !type) outs(%1 : !type) -> !type + return %4 : !type +} + +// CHECK-LABEL: func.func @expect_no_fold_of_add_as_operands_do_not_dominate_each_other +// CHECK: linalg.fill +// CHECK-NEXT: linalg.matmul +// CHECK-NEXT: linalg.add +// CHECK-NEXT: return + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.linalg.fold_add_into_dest + } : !transform.any_op + transform.yield + } +} + +// ----- + +!type = tensor<2048x2048xf32> +func.func @fold_add_on_transposed_matmuls(%arg0: !type, %arg1: !type) -> !type { + %0 = arith.constant dense<1.111111e+00> : !type + %cst = arith.constant 0.000000e+00 : f32 + %1 = tensor.empty() : !type + %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type + %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type + %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type + %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type + return %5 : !type +} + +// CHECK-LABEL: func.func @fold_add_on_transposed_matmuls +// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a +// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]] +// CHECK-NOT: linalg.add +// CHECK-NEXT: return %[[RES]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.linalg.fold_add_into_dest + } : !transform.any_op + transform.yield + } +} + +// ----- + +!type = tensor<2048x2048xf32> +func.func @expect_no_fold_of_add_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type { + %0 = arith.constant dense<1.111111e+00> : !type + %cst = arith.constant 0.000000e+00 : f32 + %1 = tensor.empty() : !type + %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type + %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type + %4 = linalg.sub ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type + %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type + return %5 : !type +} + +// CHECK-LABEL: func.func @expect_no_fold_of_add_as_dominated_op_is_not_a_contraction +// CHECK: linalg.fill +// CHECK-NEXT: linalg.matmul +// CHECK-NEXT: linalg.sub +// CHECK-NEXT: linalg.add +// CHECK-NEXT: return + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.linalg.fold_add_into_dest + } : !transform.any_op + transform.yield + } +} + +// ----- + +#map0 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0)> // NB: not an ordered projection + +!type = tensor<2048x2048xf32> +func.func @expect_no_fold_of_add_as_dest_accumulation_is_not_identity_mapped(%arg0: !type, %arg1: !type) -> !type { + %0 = arith.constant dense<1.111111e+00> : !type + %cst = arith.constant 0.000000e+00 : f32 + %1 = tensor.empty() : !type + %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type + %3 = linalg.generic { indexing_maps = [#map0, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"] } + ins(%arg0, %0: !type, !type) outs(%2: !type) { + ^bb0(%a: f32, %b: f32, %c: f32): + %5 = arith.mulf %a, %b : f32 + %6 = arith.addf %c, %5 : f32 + linalg.yield %6 : f32 + } -> !type + %4 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type + return %4 : !type +} + +// CHECK-LABEL: func.func @expect_no_fold_of_add_as_dest_accumulation_is_not_identity_mapped +// CHECK: linalg.fill +// CHECK-NEXT: linalg.generic +// CHECK: linalg.add +// CHECK-NEXT: return + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.linalg.fold_add_into_dest + } : !transform.any_op + transform.yield + } +} + +// ----- + +#map0 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> // NB: is an ordered projection + +!type = tensor<2048x2048xf32> +func.func @fold_add_on_a_generic_and_an_argument(%arg0: !type, %arg1: !type) -> !type { + %0 = arith.constant dense<1.111111e+00> : !type + %cst = arith.constant 0.000000e+00 : f32 + %1 = tensor.empty() : !type + %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type + %3 = linalg.generic { indexing_maps = [#map0, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"] } + ins(%arg0, %0: !type, !type) outs(%2: !type) { + ^bb0(%a: f32, %b: f32, %c: f32): + %5 = arith.mulf %a, %b : f32 + %6 = arith.addf %c, %5 : f32 + linalg.yield %6 : f32 + } -> !type + %4 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type + return %4 : !type +} + +// CHECK-LABEL: func.func @fold_add_on_a_generic_and_an_argument +// CHECK: linalg.generic +// CHECK-NOT: linalg.add +// CHECK: return + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.linalg.fold_add_into_dest + } : !transform.any_op + transform.yield + } +} + +// ----- + +memref.global "private" constant @big_const : memref<2048x2048xf32> = dense<1.11111104> {alignment = 64 : i64} +func.func @expect_no_fold_due_to_no_memref_support(%arg0: memref<2048x2048xf32>, %arg1: memref<2048x2048xf32>) -> memref<2048x2048xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = memref.get_global @big_const : memref<2048x2048xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<2048x2048xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2048x2048xf32> + linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<2048x2048xf32>) + linalg.matmul ins(%arg0, %0 : memref<2048x2048xf32>, memref<2048x2048xf32>) outs(%alloc_0 : memref<2048x2048xf32>) + linalg.fill ins(%cst : f32) outs(%alloc : memref<2048x2048xf32>) + linalg.matmul ins(%arg1, %0 : memref<2048x2048xf32>, memref<2048x2048xf32>) outs(%alloc : memref<2048x2048xf32>) + linalg.add ins(%alloc_0, %alloc : memref<2048x2048xf32>, memref<2048x2048xf32>) outs(%alloc : memref<2048x2048xf32>) + memref.dealloc %alloc_0 : memref<2048x2048xf32> + return %alloc : memref<2048x2048xf32> +} + +// CHECK-LABEL: func.func @expect_no_fold_due_to_no_memref_support +// CHECK: linalg.matmul +// CHECK: linalg.matmul +// CHECK: linalg.add +// CHECK: return + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.linalg.fold_add_into_dest + } : !transform.any_op + transform.yield + } +}