diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc42f71e10ef..4b325aaeab87c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1893,6 +1893,34 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
 /// convert to a `linalg.dot`.
 void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
 
+/// Add a pattern that folds a `linalg.fill` feeding the init of a
+/// `linalg.reduce` into a single fused `linalg.generic`, whose body selects
+/// the fill value instead of the accumulator on the first iteration (index 0)
+/// of the reduction dimension. Only reductions over exactly one dimension
+/// with a single input and a single init are supported. Given:
+///   %empty = tensor.empty() : tensor<i8>
+///   %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>)
+///       -> tensor<i8>
+///   %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
+///       outs(%filled : tensor<i8>) dimensions = [0]
+///     (%in: i8, %init: i8) {
+///       %3 = arith.addi %in, %init : i8
+///       linalg.yield %3 : i8
+///     }
+/// the pattern produces:
+///   %empty = tensor.empty() : tensor<i8>
+///   %reduced = linalg.generic ins(%0 : tensor<147456xi8>)
+///       outs(%empty : tensor<i8>) {
+///     ^bb0(%in: i8, %init: i8):
+///       %cst = arith.constant 0 : index
+///       %index = linalg.index 0 : index
+///       %cmp = arith.cmpi eq, %cst, %index : index
+///       %sum = arith.select %cmp, %c0, %init : i8
+///       %res = arith.addi %in, %sum : i8
+///       linalg.yield %res : i8
+///   } -> tensor<i8>
+void populateFuseFillOpWithReduceOpPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b08413812..cace3dcb6cbfc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ 
-15,6 +15,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   EraseUnusedOperandsAndResults.cpp
   FoldAddIntoDest.cpp
+  FuseFillOpWithReduceOp.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   Generalization.cpp
   Hoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FuseFillOpWithReduceOp.cpp b/mlir/lib/Dialect/Linalg/Transforms/FuseFillOpWithReduceOp.cpp
new file mode 100644
index 0000000000000..6811bbbb63e22
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FuseFillOpWithReduceOp.cpp
@@ -0,0 +1,115 @@
+//===- FuseFillOpWithReduceOp.cpp - Fuse linalg.fill into linalg.reduce ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pattern that folds a linalg.fill producing the init
+// operand of a linalg.reduce into a single fused linalg.generic op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Fold linalg.fill into linalg.reduce by creating a fused linalg.generic
+/// operation. Inside the generic body the fill value is selected instead of
+/// the accumulator on the first index of the reduction dimension. Currently
+/// only one reduction dimension is supported. Given the pattern:
+///   %empty = tensor.empty() : tensor<i8>
+///   %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>)
+///       -> tensor<i8>
+///   %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
+///       outs(%filled : tensor<i8>) dimensions = [0]
+///     (%in: i8, %init: i8) {
+///       %3 = arith.addi %in, %init : i8
+///       linalg.yield %3 : i8
+///     }
+/// The pattern is rewritten into:
+///   %empty = tensor.empty() : tensor<i8>
+///   %reduced = linalg.generic ins(%0 : tensor<147456xi8>)
+///       outs(%empty : tensor<i8>) {
+///     ^bb0(%in: i8, %init: i8):
+///       %cst = arith.constant 0 : index
+///       %index = linalg.index 0 : index
+///       %cmp = arith.cmpi eq, %cst, %index : index
+///       %sum = arith.select %cmp, %c0, %init : i8
+///       %res = arith.addi %in, %sum : i8
+///       linalg.yield %res : i8
+///   } -> tensor<i8>
+struct FoldFillWithReduceOp : public OpRewritePattern<linalg::ReduceOp> {
+  using OpRewritePattern<linalg::ReduceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::ReduceOp reduceOp,
+                                PatternRewriter &rewriter) const override {
+    if (!reduceOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(
+          reduceOp, "skip reduce op with non-pure tensor semantics");
+    if (reduceOp.getDimensions().size() != 1)
+      return rewriter.notifyMatchFailure(
+          reduceOp, "skip reduce op with non-single dimension");
+    if (reduceOp.getNumDpsInputs() != 1 || reduceOp.getNumDpsInits() != 1)
+      return rewriter.notifyMatchFailure(
+          reduceOp, "skip reduce op with multiple number of inputs/results");
+    auto fillOp = reduceOp.getInits()[0].getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return rewriter.notifyMatchFailure(
+          reduceOp,
+          "skip reduce op with inits not directly based on fill operation");
+
+    // Read the reduction dimension up front: on success, `reduceOp` is
+    // replaced with a generic op and is no longer valid.
+    const int64_t dim = reduceOp.getDimensions()[0];
+    FailureOr<linalg::GenericOp> failureOrGenericOp =
+        linalg::generalizeNamedOp(rewriter, reduceOp);
+    if (failed(failureOrGenericOp))
+      return rewriter.notifyMatchFailure(reduceOp,
+                                         "failed to generalize reduce op");
+
+    linalg::GenericOp genericReduceOp = *failureOrGenericOp;
+    // The reduce op was checked to have exactly one init, fed by the fill; the
+    // generalized op is expected to carry it over as its only init operand.
+    OpOperand *initOperand = genericReduceOp.getDpsInitOperand(0);
+    assert(initOperand->get().getDefiningOp() == fillOp &&
+           "fill op not found in reduce op uses");
+
+    Location loc = genericReduceOp.getLoc();
+    BlockArgument accumulator =
+        genericReduceOp.getMatchingBlockArgument(initOperand);
+    // Do not leak the body insertion point to the caller of the pattern.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(genericReduceOp.getBody());
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value index = rewriter.create<linalg::IndexOp>(loc, dim);
+    Value isFirstIteration = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, zero, index);
+    // On the first iteration along `dim` use the fill value in place of the
+    // (now unfilled) accumulator.
+    auto selectOp = rewriter.create<arith::SelectOp>(
+        loc, isFirstIteration, fillOp.value(), accumulator);
+    rewriter.replaceAllUsesExcept(accumulator, selectOp.getResult(), selectOp);
+    // Bypass the fill by reading its destination directly. Route the operand
+    // change through the rewriter so that the driver is notified.
+    rewriter.modifyOpInPlace(genericReduceOp, [&] {
+      initOperand->set(fillOp.getDpsInitOperand(0)->get());
+    });
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::linalg::populateFuseFillOpWithReduceOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FoldFillWithReduceOp>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/fuse_fill_op_with_reduce_op.mlir b/mlir/test/Dialect/Linalg/fuse_fill_op_with_reduce_op.mlir
new file mode 100644
index 0000000000000..7721cfec72ce7
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fuse_fill_op_with_reduce_op.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt -test-linalg-fuse-fill-op-with-reduce-op -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL:   func.func private @test_reduce_sum_kernel(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<147456xi8>) -> tensor<i8> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<i8>
+// CHECK:           %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%[[VAL_0]] : tensor<147456xi8>) outs(%[[VAL_3]] : tensor<i8>) {
+// CHECK:           ^bb0(%[[VAL_5:.*]]: i8, %[[VAL_6:.*]]: i8):
+// CHECK:             %[[VAL_7:.*]] = linalg.index 0 : index
+// CHECK:             %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_1]] : index
+// CHECK:             %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_2]], %[[VAL_6]] : i8
+// CHECK:             %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : i8
+// CHECK:             linalg.yield %[[VAL_10]] : i8
+// CHECK:           } -> tensor<i8>
+// CHECK:           return %[[VAL_4]] : tensor<i8>
+// CHECK:         }
+
+func.func private @test_reduce_sum_kernel(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
+  %1 = tensor.empty() : tensor<i8>
+  %c0_i8 = arith.constant 0 : i8
+  %2 = linalg.fill ins(%c0_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
+  %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%2 : tensor<i8>) dimensions = [0]
+    (%in: i8, %init: i8) {
+      %3 = arith.addi %in, %init : i8
+      linalg.yield %3 : i8
+    }
+  return %reduced : tensor<i8>
+}
+
+// -----
+
+func.func private @test_missing_fill(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
+  %1 = tensor.empty() : tensor<i8>
+  // CHECK: linalg.reduce
+  %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%1 : tensor<i8>) dimensions = [0]
+    (%in: i8, %init: i8) {
+      %3 = arith.addi %in, %init : i8
+      linalg.yield %3 : i8
+    }
+  return %reduced : tensor<i8>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func private @test_reduce_multiply_kernel(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<147456xi8>) -> tensor<i8> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : i8
+// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<i8>
+// CHECK:           %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%[[VAL_0]] : tensor<147456xi8>) outs(%[[VAL_3]] : tensor<i8>) {
+// CHECK:           ^bb0(%[[VAL_5:.*]]: i8, %[[VAL_6:.*]]: i8):
+// CHECK:             %[[VAL_7:.*]] = linalg.index 0 : index
+// CHECK:             %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_1]] : index
+// CHECK:             %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_2]], %[[VAL_6]] : i8
+// CHECK:             %[[VAL_10:.*]] = arith.muli %[[VAL_5]], %[[VAL_9]] : i8
+// CHECK:             linalg.yield %[[VAL_10]] : i8
+// CHECK:           } -> tensor<i8>
+// CHECK:           return %[[VAL_4]] : tensor<i8>
+// CHECK:         }
+
+func.func private @test_reduce_multiply_kernel(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
+  %1 = tensor.empty() : tensor<i8>
+  %c1_i8 = arith.constant 1 : i8
+  %2 = linalg.fill ins(%c1_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
+  %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%2 : tensor<i8>) dimensions = [0]
+    (%in: i8, %init: i8) {
+      %3 = arith.muli %in, %init : i8
+      linalg.yield %3 : i8
+    }
+  return %reduced : tensor<i8>
+}
+
+// -----
+
+func.func private @test_reduce_sum_on_multiple_dims(%arg0: tensor<2x147456xi8>) -> (tensor<i8>) {
+  %1 = tensor.empty() : tensor<i8>
+  %c0_i8 = arith.constant 0 : i8
+  // CHECK: linalg.fill
+  %2 = linalg.fill ins(%c0_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
+  // CHECK: linalg.reduce
+  %reduced = linalg.reduce ins(%arg0 : tensor<2x147456xi8>) outs(%2 : tensor<i8>) dimensions = [0, 1]
+    (%in: i8, %init: i8) {
+      %3 = arith.addi %in, %init : i8
+      linalg.yield %3 : i8
+    }
+  return %reduced : tensor<i8>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index eb6f581252181..2c2cef6042874 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRLinalgTestPasses
   TestLinalgElementwiseFusion.cpp
+  TestLinalgFuseFillOpWithReduceOp.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgRankReduceContractionOps.cpp
   TestLinalgTransforms.cpp
   TestPadFusion.cpp
 
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFuseFillOpWithReduceOp.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFuseFillOpWithReduceOp.cpp
new file mode 100644
index 0000000000000..d5506cae74161
--- /dev/null
+++ 
b/mlir/test/lib/Dialect/Linalg/TestLinalgFuseFillOpWithReduceOp.cpp
@@ -0,0 +1,65 @@
+//===- TestLinalgFuseFillOpWithReduceOp.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing the fusion of linalg.fill with
+// linalg.reduce into a new linalg.generic operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass that greedily applies the fill-into-reduce fusion patterns to the
+/// body of the `func.func` it runs on.
+struct TestLinalgFuseFillOpWithReduceOp
+    : public PassWrapper<TestLinalgFuseFillOpWithReduceOp,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFuseFillOpWithReduceOp)
+
+  TestLinalgFuseFillOpWithReduceOp() = default;
+  TestLinalgFuseFillOpWithReduceOp(
+      const TestLinalgFuseFillOpWithReduceOp &pass) = default;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, linalg::LinalgDialect,
+                    tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-linalg-fuse-fill-op-with-reduce-op";
+  }
+  StringRef getDescription() const final {
+    return "Test fuse linalg fill with linalg reduce into a new linalg generic "
+           "operation";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &this->getContext();
+    func::FuncOp funcOp = this->getOperation();
+
+    RewritePatternSet patterns(context);
+    linalg::populateFuseFillOpWithReduceOpPatterns(patterns);
+    if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgFuseFillOpWithReduceOp() {
+  PassRegistration<TestLinalgFuseFillOpWithReduceOp>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 74007d01347ae..7e92095ff2fae 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -111,6 +111,7 @@ void registerTestLinalgDropUnitDims();
 void registerTestLinalgElementwiseFusion();
+void registerTestLinalgFuseFillOpWithReduceOp();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgRankReduceContractionOps();
 void registerTestLinalgTransforms();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
@@ -251,6 +252,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgElementwiseFusion();
+  mlir::test::registerTestLinalgFuseFillOpWithReduceOp();
   mlir::test::registerTestLinalgGreedyFusion();
   mlir::test::registerTestLinalgRankReduceContractionOps();
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();