Skip to content

Commit 2b6ecd7

Browse files
committed
[CIR] Add cir-simplify pass
This patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect.
1 parent d6dbe77 commit 2b6ecd7

File tree

16 files changed

+378
-15
lines changed

16 files changed

+378
-15
lines changed

clang/include/clang/CIR/CIRToCIRPasses.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ namespace cir {
3232
mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
3333
mlir::MLIRContext &mlirCtx,
3434
clang::ASTContext &astCtx,
35-
bool enableVerifier);
35+
bool enableVerifier,
36+
bool enableCIRSimplify);
3637

3738
} // namespace cir
3839

clang/include/clang/CIR/Dialect/IR/CIRDialect.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def CIR_Dialect : Dialect {
2727
let useDefaultAttributePrinterParser = 0;
2828
let useDefaultTypePrinterParser = 0;
2929

30+
let hasConstantMaterializer = 1;
31+
3032
let extraClassDeclaration = [{
3133
static llvm::StringRef getTripleAttrName() { return "cir.triple"; }
3234

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,6 +1464,8 @@ def SelectOp : CIR_Op<"select", [Pure,
14641464
qualified(type($false_value))
14651465
`)` `->` qualified(type($result)) attr-dict
14661466
}];
1467+
1468+
let hasFolder = 1;
14671469
}
14681470

14691471
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace mlir {
2222

2323
std::unique_ptr<Pass> createCIRCanonicalizePass();
2424
std::unique_ptr<Pass> createCIRFlattenCFGPass();
25+
std::unique_ptr<Pass> createCIRSimplifyPass();
2526
std::unique_ptr<Pass> createHoistAllocasPass();
2627

2728
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);

clang/include/clang/CIR/Dialect/Passes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@ def CIRCanonicalize : Pass<"cir-canonicalize"> {
2929
let dependentDialects = ["cir::CIRDialect"];
3030
}
3131

32+
def CIRSimplify : Pass<"cir-simplify"> {
33+
let summary = "Performs CIR simplification and code optimization";
34+
let description = [{
35+
The pass performs code simplification and optimization on CIR.
36+
37+
Unlike the `cir-canonicalize` pass, this pass contains more aggresive code
38+
transformations that could significantly affect CIR-to-source fidelity.
39+
Example transformations performed in this pass include ternary folding,
40+
code hoisting, etc.
41+
}];
42+
let constructor = "mlir::createCIRSimplifyPass()";
43+
let dependentDialects = ["cir::CIRDialect"];
44+
}
45+
3246
def HoistAllocas : Pass<"cir-hoist-allocas"> {
3347
let summary = "Hoist allocas to the entry of the function";
3448
let description = [{

clang/include/clang/CIR/FrontendAction/CIRGenAction.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class CIRGenAction : public clang::ASTFrontendAction {
4949
public:
5050
~CIRGenAction() override;
5151

52-
OutputType Action;
52+
OutputType action;
5353
};
5454

5555
class EmitCIRAction : public CIRGenAction {

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,6 @@ struct MissingFeatures {
199199
static bool labelOp() { return false; }
200200
static bool ptrDiffOp() { return false; }
201201
static bool ptrStrideOp() { return false; }
202-
static bool selectOp() { return false; }
203202
static bool switchOp() { return false; }
204203
static bool ternaryOp() { return false; }
205204
static bool tryOp() { return false; }

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ void cir::CIRDialect::initialize() {
7979
addInterfaces<CIROpAsmDialectInterface>();
8080
}
8181

82+
Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
83+
mlir::Attribute value,
84+
mlir::Type type,
85+
mlir::Location loc) {
86+
return builder.create<cir::ConstantOp>(loc, type,
87+
mlir::cast<mlir::TypedAttr>(value));
88+
}
89+
8290
//===----------------------------------------------------------------------===//
8391
// Helpers
8492
//===----------------------------------------------------------------------===//
@@ -1261,6 +1269,28 @@ void cir::TernaryOp::build(
12611269
result.addTypes(TypeRange{yield.getOperandTypes().front()});
12621270
}
12631271

1272+
//===----------------------------------------------------------------------===//
1273+
// SelectOp
1274+
//===----------------------------------------------------------------------===//
1275+
1276+
OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
1277+
mlir::Attribute condition = adaptor.getCondition();
1278+
if (condition) {
1279+
bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
1280+
return conditionValue ? getTrueValue() : getFalseValue();
1281+
}
1282+
1283+
// cir.select if %0 then x else x -> x
1284+
mlir::Attribute trueValue = adaptor.getTrueValue();
1285+
mlir::Attribute falseValue = adaptor.getFalseValue();
1286+
if (trueValue == falseValue)
1287+
return trueValue;
1288+
if (getTrueValue() == getFalseValue())
1289+
return getTrueValue();
1290+
1291+
return {};
1292+
}
1293+
12641294
//===----------------------------------------------------------------------===//
12651295
// ShiftOp
12661296
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,13 @@ void CIRCanonicalizePass::runOnOperation() {
121121
getOperation()->walk([&](Operation *op) {
122122
assert(!cir::MissingFeatures::switchOp());
123123
assert(!cir::MissingFeatures::tryOp());
124-
assert(!cir::MissingFeatures::selectOp());
125124
assert(!cir::MissingFeatures::complexCreateOp());
126125
assert(!cir::MissingFeatures::complexRealOp());
127126
assert(!cir::MissingFeatures::complexImagOp());
128127
assert(!cir::MissingFeatures::callOp());
129128
// CastOp and UnaryOp are here to perform a manual `fold` in
130129
// applyOpPatternsGreedily.
131-
if (isa<BrOp, BrCondOp, ScopeOp, CastOp, UnaryOp>(op))
130+
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
132131
ops.push_back(op);
133132
});
134133

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "PassDetail.h"
10+
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/IR/Block.h"
12+
#include "mlir/IR/Operation.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/IR/Region.h"
15+
#include "mlir/Support/LogicalResult.h"
16+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
18+
#include "clang/CIR/Dialect/Passes.h"
19+
#include "llvm/ADT/SmallVector.h"
20+
21+
using namespace mlir;
22+
using namespace cir;
23+
24+
//===----------------------------------------------------------------------===//
25+
// Rewrite patterns
26+
//===----------------------------------------------------------------------===//
27+
28+
namespace {
29+
30+
/// Simplify suitable ternary operations into select operations.
31+
///
32+
/// For now we only simplify those ternary operations whose true and false
33+
/// branches directly yield a value or a constant. That is, both of the true and
34+
/// the false branch must either contain a cir.yield operation as the only
35+
/// operation in the branch, or contain a cir.const operation followed by a
36+
/// cir.yield operation that yields the constant value.
37+
///
38+
/// For example, we will simplify the following ternary operation:
39+
///
40+
/// %0 = cir.ternary (%condition, true {
41+
/// %1 = cir.const ...
42+
/// cir.yield %1
43+
/// } false {
44+
/// cir.yield %2
45+
/// })
46+
///
47+
/// into the following sequence of operations:
48+
///
49+
/// %1 = cir.const ...
50+
/// %0 = cir.select if %condition then %1 else %2
51+
struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
52+
using OpRewritePattern<TernaryOp>::OpRewritePattern;
53+
54+
LogicalResult matchAndRewrite(TernaryOp op,
55+
PatternRewriter &rewriter) const override {
56+
if (op->getNumResults() != 1)
57+
return mlir::failure();
58+
59+
if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
60+
!isSimpleTernaryBranch(op.getFalseRegion()))
61+
return mlir::failure();
62+
63+
cir::YieldOp trueBranchYieldOp =
64+
mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
65+
cir::YieldOp falseBranchYieldOp =
66+
mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
67+
mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
68+
mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
69+
70+
rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
71+
rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
72+
rewriter.eraseOp(trueBranchYieldOp);
73+
rewriter.eraseOp(falseBranchYieldOp);
74+
rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
75+
falseValue);
76+
77+
return mlir::success();
78+
}
79+
80+
private:
81+
bool isSimpleTernaryBranch(mlir::Region &region) const {
82+
if (!region.hasOneBlock())
83+
return false;
84+
85+
mlir::Block &onlyBlock = region.front();
86+
mlir::Block::OpListType &ops = onlyBlock.getOperations();
87+
88+
// The region/block could only contain at most 2 operations.
89+
if (ops.size() > 2)
90+
return false;
91+
92+
if (ops.size() == 1) {
93+
// The region/block only contain a cir.yield operation.
94+
return true;
95+
}
96+
97+
// Check whether the region/block contains a cir.const followed by a
98+
// cir.yield that yields the value.
99+
auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
100+
auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
101+
yieldOp.getArgs()[0].getDefiningOp());
102+
return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
103+
}
104+
};
105+
106+
struct SimplifySelect : public OpRewritePattern<SelectOp> {
107+
using OpRewritePattern<SelectOp>::OpRewritePattern;
108+
109+
LogicalResult matchAndRewrite(SelectOp op,
110+
PatternRewriter &rewriter) const final {
111+
mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
112+
mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
113+
auto trueValueConstOp =
114+
mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp);
115+
auto falseValueConstOp =
116+
mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp);
117+
if (!trueValueConstOp || !falseValueConstOp)
118+
return mlir::failure();
119+
120+
auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());
121+
auto falseValue =
122+
mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());
123+
if (!trueValue || !falseValue)
124+
return mlir::failure();
125+
126+
// cir.select if %0 then #true else #false -> %0
127+
if (trueValue.getValue() && !falseValue.getValue()) {
128+
rewriter.replaceAllUsesWith(op, op.getCondition());
129+
rewriter.eraseOp(op);
130+
return mlir::success();
131+
}
132+
133+
// cir.select if %0 then #false else #true -> cir.unary not %0
134+
if (!trueValue.getValue() && falseValue.getValue()) {
135+
rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
136+
op.getCondition());
137+
return mlir::success();
138+
}
139+
140+
return mlir::failure();
141+
}
142+
};
143+
144+
//===----------------------------------------------------------------------===//
145+
// CIRSimplifyPass
146+
//===----------------------------------------------------------------------===//
147+
148+
struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
149+
using CIRSimplifyBase::CIRSimplifyBase;
150+
151+
void runOnOperation() override;
152+
};
153+
154+
void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
155+
// clang-format off
156+
patterns.add<
157+
SimplifyTernary,
158+
SimplifySelect
159+
>(patterns.getContext());
160+
// clang-format on
161+
}
162+
163+
void CIRSimplifyPass::runOnOperation() {
164+
// Collect rewrite patterns.
165+
RewritePatternSet patterns(&getContext());
166+
populateMergeCleanupPatterns(patterns);
167+
168+
// Collect operations to apply patterns.
169+
llvm::SmallVector<Operation *, 16> ops;
170+
getOperation()->walk([&](Operation *op) {
171+
if (isa<TernaryOp, SelectOp>(op))
172+
ops.push_back(op);
173+
});
174+
175+
// Apply patterns.
176+
if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
177+
signalPassFailure();
178+
}
179+
180+
} // namespace
181+
182+
std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
183+
return std::make_unique<CIRSimplifyPass>();
184+
}

0 commit comments

Comments
 (0)