diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index 373842c9b03de..f23662930accc 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -89,6 +89,45 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> { ]; } +def LinalgMorphOpsPass : Pass<"linalg-morph-ops"> { + let summary = "Convert named op to category ops or generic and vice-versa"; + + let description = [{ + Convert a linalg op from one representation to another equivalent. + For example, a linalg named op `linalg.add` can also be written as an + category op `linalg.elementwise`, and can also be re-written as + a `linalg.generic`, giving the morphism: + + named-op <--> category_op (elementwise, contraction, ..) <--> generic + + Note that the set of `linalg.generic` subsumes named and category ops + and therefore not all `linalg.genric` can be converted to named or + category op. Similarly, catgory ops subsume named ops. + + Note: + Legacy converters: + `--linalg-generalize-named-ops` is the path `named-op --> generic-op` + `--linalg-specialize-generic-ops` is the path `named-op <-- generic-op` + }]; + let dependentDialects = ["linalg::LinalgDialect"]; + + let options = [ + // named-op <--> category <--> generic + + // Lowering options + Option<"namedToCategory", "named-to-category", "bool", /*default=*/"false", + "convert named ops to category op e.g. `linalg.elementwise`">, + Option<"categoryToGeneric", "category-to-generic", "bool", /*default=*/"false", + "convert category ops e.g. `linalg.elementwise` to `linalg.generic`">, + Option<"namedToGeneric", "named-to-generic", "bool", /*default=*/"false", + "convert named ops e.g. `linalg.add` to `linalg.generic`">, + + // Lifting options + // TODOs: `generic-to-category`, `category-to-named` + Option<"genericToNamed", "generic-to-named", "bool", /*default=*/"false", + "convert linalg.generic to equivalent named ops"> ]; +} + def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> { let summary = "Convert named ops into generic ops"; let dependentDialects = ["linalg::LinalgDialect"]; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index d4ffe0a91fcfe..1e5b5d46de55f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1831,6 +1831,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns); void populateLinalgGenericOpsSpecializationPatterns( RewritePatternSet &patterns); +/// Populates `patterns` that convert linalg named ops e.g. `linalg.add` +/// to equivalent `linalg.elementwise`. +void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns); + /// Populates `patterns` with patterns that fold operations like /// `linalg.transform` into elementwise op map. void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 70f846e5bbd20..6ec2e9fd0be7d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -23,9 +23,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms InlineScalarOperands.cpp Interchange.cpp Loops.cpp + MorphOps.cpp TransposeMatmul.cpp ShardingInterfaceImpl.cpp NamedOpConversions.cpp + NamedToElementwise.cpp BlockPackMatmul.cpp PackAndUnpackPatterns.cpp Padding.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp new file mode 100644 index 0000000000000..f261ccb1415fe --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp @@ -0,0 +1,62 @@ +//===- MorphOps.cpp - conversion between named,category and generic ops ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements conversions between linalg ops: +// named <--> category (elementwise, contraction, ..) <--> generic. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_LINALGMORPHOPSPASS +#include "mlir/Dialect/Linalg/Passes.h.inc" +} // namespace mlir + +#define DEBUG_TYPE "linalg-morphism" + +using namespace mlir; +using namespace mlir::linalg; + +namespace { +struct LinalgMorphOpsPass + : public impl::LinalgMorphOpsPassBase { + + using impl::LinalgMorphOpsPassBase< + LinalgMorphOpsPass>::LinalgMorphOpsPassBase; + + void runOnOperation() override; +}; + +void LinalgMorphOpsPass::runOnOperation() { + + RewritePatternSet patterns(&getContext()); + + // Lowering paths (named -> category -> generic) + if (namedToCategory) { + populateLinalgNamedToElementwisePatterns(patterns); + } + if (namedToGeneric || categoryToGeneric) { + populateLinalgNamedOpsGeneralizationPatterns(patterns); + } + + // Lifting paths (named <- category <- generic) + if (genericToNamed) { + populateLinalgGenericOpsSpecializationPatterns(patterns); + } + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); +} +} // namespace diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp new file mode 100644 index 0000000000000..00a076b6e9746 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp @@ -0,0 +1,98 @@ +//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements rewriting those linalg named ops that are essentially +// elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further +// optimization on `linalg.elementwise` such as folding transpose, broadcast. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::linalg; + +#define DEBUG_TYPE "linalg-named-to-elementwise" + +namespace { +ElementwiseKind getKind(Operation *op) { + return llvm::TypeSwitch(op) + .Case([](SelectOp) { return ElementwiseKind::select; }) + .Case([](AddOp) { return ElementwiseKind::add; }) + .Case([](SubOp) { return ElementwiseKind::sub; }) + .Case([](MulOp) { return ElementwiseKind::mul; }) + .Case([](DivOp) { return ElementwiseKind::div; }) + .Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; }) + .Case([](PowFOp) { return ElementwiseKind::powf; }) + .Case([](ExpOp) { return ElementwiseKind::exp; }) + .Case([](LogOp) { return ElementwiseKind::log; }) + .Case([](AbsOp) { return ElementwiseKind::abs; }) + .Case([](CeilOp) { return ElementwiseKind::ceil; }) + .Case([](FloorOp) { return ElementwiseKind::floor; }) + .Case([](NegFOp) { return ElementwiseKind::negf; }) + .Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; }) + .Case([](RoundOp) { return ElementwiseKind::round; }) + .Case([](SqrtOp) { return ElementwiseKind::sqrt; }) + .Case([](RsqrtOp) { return ElementwiseKind::rsqrt; }) + .Case([](SquareOp) { return ElementwiseKind::square; }) + .Case([](TanhOp) { return ElementwiseKind::tanh; }) + .Case([](ErfOp) { return ElementwiseKind::erf; }) + .Default([&](Operation *op) { + llvm_unreachable("unhandled case in named to elementwise"); + return ElementwiseKind::sub; + }); +} + +template +struct NamedToElementwisePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(NamedOpTy op, + PatternRewriter &rewriter) const override { + SmallVector attrs; + auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op)); + attrs.push_back(rewriter.getNamedAttr("kind", kindAttr)); + attrs.push_back( + rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps())); + + rewriter.replaceOpWithNewOp(op, op.getDpsInputs(), + op.getDpsInits(), attrs); + return success(); + } +}; +} // namespace + +void mlir::linalg::populateLinalgNamedToElementwisePatterns( + RewritePatternSet &patterns) { + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); +} diff --git a/mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir b/mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir new file mode 100644 index 0000000000000..2332b287ace8d --- /dev/null +++ b/mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir @@ -0,0 +1,56 @@ +// RUN: mlir-opt %s -linalg-morph-ops=named-to-category -split-input-file | FileCheck %s + +// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> { +// CHECK: {{.*}} = linalg.elementwise +// CHECK-SAME: kind=#linalg.elementwise_kind +// CHECK-SAME: ins(%[[A]] : tensor<16x8xf32>) +// CHECK-SAME: outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32> +// +func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> { + %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32> + return %exp : tensor<16x8xf32> +} + +// ---- + +// CHECK: @add(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> { +// CHECK: {{.*}} = linalg.elementwise +// CHECK-SAME: kind=#linalg.elementwise_kind +// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>) +// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32> +// +func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> { + %add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32> + return %add : tensor<16x8xf32> +} + +// ---- + +// CHECK: @sub(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> { +// CHECK: {{.*}} = linalg.elementwise +// CHECK-SAME: kind=#linalg.elementwise_kind +// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>) +// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) +// +func.func @sub(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> { + %sub = linalg.sub ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32> + return %sub : tensor<16x8xf32> +} + +// ---- + +// CHECK: @ternary_select(%[[A:.+]]: tensor<4x8x16xi1>, %[[B:.+]]: tensor<4x8x16xf32>, %[[C:.+]]: tensor<4x8x16xf32>) +// CHECK: %[[E:.+]] = tensor.empty() : tensor<4x8x16xf32> +// CHECK: {{.*}} = linalg.elementwise +// CHECK-SAME: kind=#linalg.elementwise_kind