Skip to content

[mlir][linalg] Morphism across linalg named, category and generic ops. #148424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,45 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
];
}

def LinalgMorphOpsPass : Pass<"linalg-morph-ops"> {
let summary = "Convert named op to category ops or generic and vice-versa";

let description = [{
Convert a linalg op from one representation to another equivalent.
For example, a linalg named op `linalg.add` can also be written as an
category op `linalg.elementwise`, and can also be re-written as
a `linalg.generic`, giving the morphism:

named-op <--> category_op (elementwise, contraction, ..) <--> generic

Note that the set of `linalg.generic` subsumes named and category ops
and therefore not all `linalg.genric` can be converted to named or
category op. Similarly, catgory ops subsume named ops.

Note:
Legacy converters:
`--linalg-generalize-named-ops` is the path `named-op --> generic-op`
`--linalg-specialize-generic-ops` is the path `named-op <-- generic-op`
}];
let dependentDialects = ["linalg::LinalgDialect"];

let options = [
// named-op <--> category <--> generic

// Lowering options
Option<"namedToCategory", "named-to-category", "bool", /*default=*/"false",
"convert named ops to category op e.g. `linalg.elementwise`">,
Option<"categoryToGeneric", "category-to-generic", "bool", /*default=*/"false",
"convert category ops e.g. `linalg.elementwise` to `linalg.generic`">,
Option<"namedToGeneric", "named-to-generic", "bool", /*default=*/"false",
"convert named ops e.g. `linalg.add` to `linalg.generic`">,

// Lifting options
// TODOs: `generic-to-category`, `category-to-named`
Option<"genericToNamed", "generic-to-named", "bool", /*default=*/"false",
"convert linalg.generic to equivalent named ops"> ];
}

def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
let summary = "Convert named ops into generic ops";
let dependentDialects = ["linalg::LinalgDialect"];
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1831,6 +1831,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);

/// Populates `patterns` that convert linalg named ops e.g. `linalg.add`
/// to equivalent `linalg.elementwise`.
void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that fold operations like
/// `linalg.transform` into elementwise op map.
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
MorphOps.cpp
TransposeMatmul.cpp
ShardingInterfaceImpl.cpp
NamedOpConversions.cpp
NamedToElementwise.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
Padding.cpp
Expand Down
62 changes: 62 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/MorphOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//===- MorphOps.cpp - conversion between named,category and generic ops ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements conversions between linalg ops:
// named <--> category (elementwise, contraction, ..) <--> generic.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGMORPHOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-morphism"

using namespace mlir;
using namespace mlir::linalg;

namespace {
struct LinalgMorphOpsPass
: public impl::LinalgMorphOpsPassBase<LinalgMorphOpsPass> {

using impl::LinalgMorphOpsPassBase<
LinalgMorphOpsPass>::LinalgMorphOpsPassBase;

void runOnOperation() override;
};

void LinalgMorphOpsPass::runOnOperation() {

RewritePatternSet patterns(&getContext());

// Lowering paths (named -> category -> generic)
if (namedToCategory) {
populateLinalgNamedToElementwisePatterns(patterns);
}
if (namedToGeneric || categoryToGeneric) {
populateLinalgNamedOpsGeneralizationPatterns(patterns);
}

// Lifting paths (named <- category <- generic)
if (genericToNamed) {
populateLinalgGenericOpsSpecializationPatterns(patterns);
}

if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
} // namespace
98 changes: 98 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting those linalg named ops that are essentially
// elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
// optimization on `linalg.elementwise` such as folding transpose, broadcast.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-named-to-elementwise"

namespace {
ElementwiseKind getKind(Operation *op) {
return llvm::TypeSwitch<Operation *, ElementwiseKind>(op)
.Case([](SelectOp) { return ElementwiseKind::select; })
.Case([](AddOp) { return ElementwiseKind::add; })
.Case([](SubOp) { return ElementwiseKind::sub; })
.Case([](MulOp) { return ElementwiseKind::mul; })
.Case([](DivOp) { return ElementwiseKind::div; })
.Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
.Case([](PowFOp) { return ElementwiseKind::powf; })
.Case([](ExpOp) { return ElementwiseKind::exp; })
.Case([](LogOp) { return ElementwiseKind::log; })
.Case([](AbsOp) { return ElementwiseKind::abs; })
.Case([](CeilOp) { return ElementwiseKind::ceil; })
.Case([](FloorOp) { return ElementwiseKind::floor; })
.Case([](NegFOp) { return ElementwiseKind::negf; })
.Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
.Case([](RoundOp) { return ElementwiseKind::round; })
.Case([](SqrtOp) { return ElementwiseKind::sqrt; })
.Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
.Case([](SquareOp) { return ElementwiseKind::square; })
.Case([](TanhOp) { return ElementwiseKind::tanh; })
.Case([](ErfOp) { return ElementwiseKind::erf; })
.Default([&](Operation *op) {
llvm_unreachable("unhandled case in named to elementwise");
return ElementwiseKind::sub;
});
}

template <typename NamedOpTy>
struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
using OpRewritePattern<NamedOpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(NamedOpTy op,
PatternRewriter &rewriter) const override {
SmallVector<NamedAttribute> attrs;
auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op));
attrs.push_back(rewriter.getNamedAttr("kind", kindAttr));
attrs.push_back(
rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps()));

rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
op.getDpsInits(), attrs);
return success();
}
};
} // namespace

void mlir::linalg::populateLinalgNamedToElementwisePatterns(
RewritePatternSet &patterns) {
patterns.add<NamedToElementwisePattern<SelectOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
}
56 changes: 56 additions & 0 deletions mlir/test/Dialect/Linalg/elementwise/named-to-elementwise.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: mlir-opt %s -linalg-morph-ops=named-to-category -split-input-file | FileCheck %s

// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK: {{.*}} = linalg.elementwise
// CHECK-SAME: kind=#linalg.elementwise_kind<exp>
// CHECK-SAME: ins(%[[A]] : tensor<16x8xf32>)
// CHECK-SAME: outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32>
//
func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
return %exp : tensor<16x8xf32>
}

// ----

// CHECK: @add(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK: {{.*}} = linalg.elementwise
// CHECK-SAME: kind=#linalg.elementwise_kind<add>
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
//
func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
%add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
return %add : tensor<16x8xf32>
}

// ----

// CHECK: @sub(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK: {{.*}} = linalg.elementwise
// CHECK-SAME: kind=#linalg.elementwise_kind<sub>
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>)
//
func.func @sub(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
%sub = linalg.sub ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
return %sub : tensor<16x8xf32>
}

// ----

// CHECK: @ternary_select(%[[A:.+]]: tensor<4x8x16xi1>, %[[B:.+]]: tensor<4x8x16xf32>, %[[C:.+]]: tensor<4x8x16xf32>)
// CHECK: %[[E:.+]] = tensor.empty() : tensor<4x8x16xf32>
// CHECK: {{.*}} = linalg.elementwise
// CHECK-SAME: kind=#linalg.elementwise_kind<select>
// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
// CHECK-SAME: outs(%[[E]] : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
//
func.func @ternary_select(%A: tensor<4x8x16xi1>, %B: tensor<4x8x16xf32>, %C: tensor<4x8x16xf32>)
-> tensor<4x8x16xf32> {
%empty = tensor.empty() : tensor<4x8x16xf32>
%select = linalg.select
ins(%A, %B, %C : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
outs(%empty: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
return %select : tensor<4x8x16xf32>
}
15 changes: 15 additions & 0 deletions mlir/test/Dialect/Linalg/linalg-morph-category-ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Forward path `named -> category -> generic`
// RUN: mlir-opt %s -linalg-morph-ops=named-to-category | FileCheck %s --check-prefix=NAMED_TO_CATEGORY

// RUN: mlir-opt %s -linalg-morph-ops=named-to-category | \
// RUN: mlir-opt %s -linalg-morph-ops=category-to-generic | FileCheck %s --check-prefix=CATEGORY_TO_GENERIC

func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
return %exp : tensor<16x8xf32>
}
// NAMED_TO_CATEGORY: linalg.elementwise
// NAMED_TO_CATEGORY-NOT: linalg.exp

// CATEGORY_TO_GENERIC: linalg.generic
// CATEGORY_TO_GENERIC-NOT: linalg.elementwise
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Linalg/linalg-morph-multi-step.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | FileCheck %s --check-prefix=NAMED_TO_GENERIC
// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | mlir-opt %s -linalg-morph-ops=generic-to-named | \
// RUN: FileCheck %s --check-prefix=ROUND_TRIP

func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
return %exp : tensor<16x8xf32>
}

// NAMED_TO_GENERIC: linalg.generic
// NAMED_TO_GENERIC-NOT: linalg.exp

// ROUND_TRIP: linalg.exp
// ROUND_TRIP-NOT: linalg.generic