Skip to content

Commit 0303b3c

Browse files
committed
[mlir] Add SelectPass
`SelectPass` allows to dynamically select the pass pipeline based on attribute value attached to some top-level op.
1 parent 7dd5f23 commit 0303b3c

File tree

5 files changed

+184
-0
lines changed

5 files changed

+184
-0
lines changed

mlir/include/mlir/Transforms/Passes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
4646
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
4747
#define GEN_PASS_DECL_TOPOLOGICALSORT
4848
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
49+
#define GEN_PASS_DECL_SELECTPASS
4950
#include "mlir/Transforms/Passes.h.inc"
5051

5152
/// Creates an instance of the Canonicalizer pass, configured with default
@@ -139,6 +140,13 @@ std::unique_ptr<Pass> createCompositeFixedPointPass(
139140
std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
140141
int maxIterations = 10);
141142

143+
/// Creates select pass which allows to run multiple different set of passes
144+
/// based on attribute value on some top-level op.
145+
std::unique_ptr<Pass> createSelectPass(
146+
std::string name, std::string selectCondName,
147+
ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
148+
populateFuncs);
149+
142150
//===----------------------------------------------------------------------===//
143151
// Registration
144152
//===----------------------------------------------------------------------===//

mlir/include/mlir/Transforms/Passes.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,4 +586,23 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
586586
];
587587
}
588588

589+
def SelectPass : Pass<"select-pass"> {
590+
let summary = "Select pass";
591+
let description = [{
592+
Select pass allows to run multiple different set of passes based on
593+
attribute value on some top-level op.
594+
}];
595+
596+
let options = [
597+
Option<"name", "name", "std::string", /*default=*/"\"SelectPass\"",
598+
"Select pass display name">,
599+
Option<"selectCondName", "select-cond-name", "std::string", "\"select\"",
600+
"Attribute name used for condition">,
601+
ListOption<"selectValues", "select-values", "std::string",
602+
"Values used to check select condition">,
603+
ListOption<"selectPipelines", "select-pipelines", "std::string",
604+
"Pipelines, assotiated with corresponding select values">,
605+
];
606+
}
607+
589608
#endif // MLIR_TRANSFORMS_PASSES

mlir/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_mlir_library(MLIRTransforms
1414
PrintIR.cpp
1515
RemoveDeadValues.cpp
1616
SCCP.cpp
17+
SelectPass.cpp
1718
SROA.cpp
1819
StripDebugInfo.cpp
1920
SymbolDCE.cpp

mlir/lib/Transforms/SelectPass.cpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
//===- SelectPass.cpp - Select pass code ----------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// SelectPass allows to run multiple different set of passes based on attribute
10+
// value on some top-level op.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Transforms/Passes.h"
15+
16+
#include "mlir/Pass/Pass.h"
17+
#include "mlir/Pass/PassManager.h"
18+
19+
namespace mlir {
20+
#define GEN_PASS_DEF_SELECTPASS
21+
#include "mlir/Transforms/Passes.h.inc"
22+
} // namespace mlir
23+
24+
using namespace mlir;
25+
26+
namespace {
27+
struct SelectPass final : public impl::SelectPassBase<SelectPass> {
28+
using SelectPassBase::SelectPassBase;
29+
30+
SelectPass(
31+
std::string name_, std::string selectCondName_,
32+
ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
33+
populateFuncs) {
34+
name = std::move(name_);
35+
selectCondName = std::move(selectCondName_);
36+
37+
SmallVector<std::string> selectVals;
38+
SmallVector<std::string> selectPpls;
39+
selectVals.reserve(populateFuncs.size());
40+
selectPpls.reserve(populateFuncs.size());
41+
selectPassManagers.reserve(populateFuncs.size());
42+
for (auto &&[name, populate] : populateFuncs) {
43+
selectVals.emplace_back(name);
44+
45+
auto &pm = selectPassManagers.emplace_back();
46+
populate(pm);
47+
48+
llvm::raw_string_ostream os(selectPpls.emplace_back());
49+
pm.printAsTextualPipeline(os);
50+
}
51+
52+
selectValues = selectVals;
53+
selectPipelines = selectPpls;
54+
}
55+
56+
LogicalResult initializeOptions(
57+
StringRef options,
58+
function_ref<LogicalResult(const Twine &)> errorHandler) override {
59+
if (failed(SelectPassBase::initializeOptions(options, errorHandler)))
60+
return failure();
61+
62+
if (selectCondName.empty())
63+
return errorHandler("Invalid select-cond-name");
64+
65+
if (selectValues.size() != selectPipelines.size())
66+
return errorHandler("Values and pipelines size mismatch");
67+
68+
selectPassManagers.resize(selectPipelines.size());
69+
70+
for (auto &&[i, pipeline] : llvm::enumerate(selectPipelines)) {
71+
if (failed(parsePassPipeline(pipeline, selectPassManagers[i])))
72+
return errorHandler("Failed to parse pipeline");
73+
}
74+
75+
return success();
76+
}
77+
78+
LogicalResult initialize(MLIRContext *context) override {
79+
condAttrName = StringAttr::get(context, selectCondName);
80+
81+
selectAttrs.reserve(selectAttrs.size());
82+
for (StringRef value : selectValues)
83+
selectAttrs.emplace_back(StringAttr::get(context, value));
84+
85+
return success();
86+
}
87+
88+
void getDependentDialects(DialectRegistry &registry) const override {
89+
for (const OpPassManager &pipeline : selectPassManagers)
90+
pipeline.getDependentDialects(registry);
91+
}
92+
93+
void runOnOperation() override {
94+
Operation *op = getOperation();
95+
Attribute condAttrValue = op->getAttr(condAttrName);
96+
if (!condAttrValue) {
97+
op->emitError("Condition attribute not present: ") << condAttrName;
98+
return signalPassFailure();
99+
}
100+
101+
for (auto &&[value, pm] :
102+
llvm::zip_equal(selectAttrs, selectPassManagers)) {
103+
if (value != condAttrValue)
104+
continue;
105+
106+
if (failed(runPipeline(pm, op)))
107+
return signalPassFailure();
108+
109+
return;
110+
}
111+
112+
op->emitError("Unhandled condition value: ") << condAttrValue;
113+
return signalPassFailure();
114+
}
115+
116+
protected:
117+
StringRef getName() const override { return name; }
118+
119+
private:
120+
StringAttr condAttrName;
121+
SmallVector<Attribute> selectAttrs;
122+
SmallVector<OpPassManager> selectPassManagers;
123+
};
124+
} // namespace
125+
126+
std::unique_ptr<Pass> mlir::createSelectPass(
127+
std::string name, std::string selectCondName,
128+
ArrayRef<std::pair<StringRef, std::function<void(OpPassManager &)>>>
129+
populateFuncs) {
130+
return std::make_unique<SelectPass>(std::move(name),
131+
std::move(selectCondName), populateFuncs);
132+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(gpu.module(select-pass{ \
2+
// RUN: name=TestSelectPass \
3+
// RUN: select-cond-name=test.attr \
4+
// RUN: select-values=rocdl,nvvm \
5+
// RUN: select-pipelines=convert-gpu-to-rocdl,convert-gpu-to-nvvm \
6+
// RUN: }))' -split-input-file | FileCheck %s
7+
8+
gpu.module @rocdl_module attributes {test.attr = "rocdl"} {
9+
// CHECK-LABEL: func @foo()
10+
// CHECK: rocdl.workitem.id.x
11+
func.func @foo() -> index {
12+
%0 = gpu.thread_id x
13+
return %0 : index
14+
}
15+
}
16+
17+
gpu.module @nvvm_module attributes {test.attr = "nvvm"} {
18+
// CHECK-LABEL: func @bar()
19+
// CHECK: nvvm.read.ptx.sreg.tid.x
20+
func.func @bar() -> index {
21+
%0 = gpu.thread_id x
22+
return %0 : index
23+
}
24+
}

0 commit comments

Comments
 (0)