Skip to content

Commit 5b989c5

Browse files
committed
[CIR] Upstream TernaryOp
This patch adds TernaryOp to CIR plus a pass that flattens the operator in FlattenCFG.
1 parent a7402b0 commit 5b989c5

File tree

6 files changed

+281
-8
lines changed

6 files changed

+281
-8
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -610,9 +610,9 @@ def ConditionOp : CIR_Op<"condition", [
610610
//===----------------------------------------------------------------------===//
611611

612612
def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
613-
ParentOneOf<["IfOp", "ScopeOp", "SwitchOp",
614-
"WhileOp", "ForOp", "CaseOp",
615-
"DoWhileOp"]>]> {
613+
ParentOneOf<["CaseOp", "DoWhileOp", "ForOp",
614+
"IfOp", "ScopeOp", "SwitchOp",
615+
"TernaryOp", "WhileOp"]>]> {
616616
let summary = "Represents the default branching behaviour of a region";
617617
let description = [{
618618
The `cir.yield` operation terminates regions on different CIR operations,
@@ -1462,6 +1462,59 @@ def SelectOp : CIR_Op<"select", [Pure,
14621462
}];
14631463
}
14641464

1465+
//===----------------------------------------------------------------------===//
1466+
// TernaryOp
1467+
//===----------------------------------------------------------------------===//
1468+
1469+
def TernaryOp : CIR_Op<"ternary",
1470+
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
1471+
RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]> {
1472+
let summary = "The `cond ? a : b` C/C++ ternary operation";
1473+
let description = [{
1474+
The `cir.ternary` operation represents C/C++ ternary, much like a `select`
1475+
operation. The first argument is a `cir.bool` condition to evaluate, followed
1476+
by two regions to execute (true or false). This is different from `cir.if`
1477+
since each region is one block sized and the `cir.yield` closing the block
1478+
scope should have one argument.
1479+
1480+
Example:
1481+
1482+
```mlir
1483+
// x = cond ? a : b;
1484+
1485+
%x = cir.ternary (%cond, true_region {
1486+
...
1487+
cir.yield %a : i32
1488+
}, false_region {
1489+
...
1490+
cir.yield %b : i32
1491+
}) -> i32
1492+
```
1493+
}];
1494+
let arguments = (ins CIR_BoolType:$cond);
1495+
let regions = (region AnyRegion:$trueRegion,
1496+
AnyRegion:$falseRegion);
1497+
let results = (outs Optional<CIR_AnyType>:$result);
1498+
1499+
let skipDefaultBuilders = 1;
1500+
let builders = [
1501+
OpBuilder<(ins "mlir::Value":$cond,
1502+
"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$trueBuilder,
1503+
"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$falseBuilder)
1504+
>
1505+
];
1506+
1507+
// All constraints already verified elsewhere.
1508+
let hasVerifier = 0;
1509+
1510+
let assemblyFormat = [{
1511+
`(` $cond `,`
1512+
`true` $trueRegion `,`
1513+
`false` $falseRegion
1514+
`)` `:` functional-type(operands, results) attr-dict
1515+
}];
1516+
}
1517+
14651518
//===----------------------------------------------------------------------===//
14661519
// GlobalOp
14671520
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,6 +1187,48 @@ LogicalResult cir::BinOp::verify() {
11871187
return mlir::success();
11881188
}
11891189

1190+
//===----------------------------------------------------------------------===//
1191+
// TernaryOp
1192+
//===----------------------------------------------------------------------===//
1193+
1194+
/// Given the region at `index`, or the parent operation if `index` is None,
1195+
/// return the successor regions. These are the regions that may be selected
1196+
/// during the flow of control. `operands` is a set of optional attributes that
1197+
/// correspond to a constant value for each operand, or null if that operand is
1198+
/// not a constant.
1199+
void cir::TernaryOp::getSuccessorRegions(
1200+
mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1201+
// The `true` and the `false` region branch back to the parent operation.
1202+
if (!point.isParent()) {
1203+
regions.push_back(RegionSuccessor(this->getODSResults(0)));
1204+
return;
1205+
}
1206+
1207+
// If the condition isn't constant, both regions may be executed.
1208+
regions.push_back(RegionSuccessor(&getTrueRegion()));
1209+
regions.push_back(RegionSuccessor(&getFalseRegion()));
1210+
}
1211+
1212+
void cir::TernaryOp::build(
1213+
OpBuilder &builder, OperationState &result, Value cond,
1214+
function_ref<void(OpBuilder &, Location)> trueBuilder,
1215+
function_ref<void(OpBuilder &, Location)> falseBuilder) {
1216+
result.addOperands(cond);
1217+
OpBuilder::InsertionGuard guard(builder);
1218+
Region *trueRegion = result.addRegion();
1219+
Block *block = builder.createBlock(trueRegion);
1220+
trueBuilder(builder, result.location);
1221+
Region *falseRegion = result.addRegion();
1222+
builder.createBlock(falseRegion);
1223+
falseBuilder(builder, result.location);
1224+
1225+
auto yield = dyn_cast<YieldOp>(block->getTerminator());
1226+
assert((yield && yield.getNumOperands() <= 1) &&
1227+
"expected zero or one result type");
1228+
if (yield.getNumOperands() == 1)
1229+
result.addTypes(TypeRange{yield.getOperandTypes().front()});
1230+
}
1231+
11901232
//===----------------------------------------------------------------------===//
11911233
// ShiftOp
11921234
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,61 @@ class CIRLoopOpInterfaceFlattening
254254
}
255255
};
256256

257+
class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
258+
public:
259+
using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
260+
261+
mlir::LogicalResult
262+
matchAndRewrite(cir::TernaryOp op,
263+
mlir::PatternRewriter &rewriter) const override {
264+
Location loc = op->getLoc();
265+
Block *condBlock = rewriter.getInsertionBlock();
266+
Block::iterator opPosition = rewriter.getInsertionPoint();
267+
Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
268+
llvm::SmallVector<mlir::Location, 2> locs;
269+
// Ternary result is optional, make sure to populate the location only
270+
// when relevant.
271+
if (op->getResultTypes().size())
272+
locs.push_back(loc);
273+
auto *continueBlock =
274+
rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
275+
rewriter.create<cir::BrOp>(loc, remainingOpsBlock);
276+
277+
Region &trueRegion = op.getTrueRegion();
278+
Block *trueBlock = &trueRegion.front();
279+
mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
280+
rewriter.setInsertionPointToEnd(&trueRegion.back());
281+
auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);
282+
283+
rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
284+
continueBlock);
285+
rewriter.inlineRegionBefore(trueRegion, continueBlock);
286+
287+
Block *falseBlock = continueBlock;
288+
Region &falseRegion = op.getFalseRegion();
289+
290+
falseBlock = &falseRegion.front();
291+
mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
292+
rewriter.setInsertionPointToEnd(&falseRegion.back());
293+
cir::YieldOp falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
294+
rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(),
295+
continueBlock);
296+
rewriter.inlineRegionBefore(falseRegion, continueBlock);
297+
298+
rewriter.setInsertionPointToEnd(condBlock);
299+
rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock);
300+
301+
rewriter.replaceOp(op, continueBlock->getArguments());
302+
303+
// Ok, we're done!
304+
return mlir::success();
305+
}
306+
};
307+
257308
void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
258-
patterns
259-
.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>(
260-
patterns.getContext());
309+
patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
310+
CIRScopeOpFlattening, CIRTernaryOpFlattening>(
311+
patterns.getContext());
261312
}
262313

263314
void CIRFlattenCFGPass::runOnOperation() {
@@ -269,9 +320,8 @@ void CIRFlattenCFGPass::runOnOperation() {
269320
getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
270321
assert(!cir::MissingFeatures::ifOp());
271322
assert(!cir::MissingFeatures::switchOp());
272-
assert(!cir::MissingFeatures::ternaryOp());
273323
assert(!cir::MissingFeatures::tryOp());
274-
if (isa<IfOp, ScopeOp, LoopOpInterface>(op))
324+
if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
275325
ops.push_back(op);
276326
});
277327

clang/test/CIR/IR/ternary.cir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: cir-opt %s | cir-opt | FileCheck %s
2+
!u32i = !cir.int<u, 32>
3+
4+
module {
5+
cir.func @blue(%arg0: !cir.bool) -> !u32i {
6+
%0 = cir.ternary(%arg0, true {
7+
%a = cir.const #cir.int<0> : !u32i
8+
cir.yield %a : !u32i
9+
}, false {
10+
%b = cir.const #cir.int<1> : !u32i
11+
cir.yield %b : !u32i
12+
}) : (!cir.bool) -> !u32i
13+
cir.return %0 : !u32i
14+
}
15+
}
16+
17+
// CHECK: module {
18+
19+
// CHECK: cir.func @blue(%arg0: !cir.bool) -> !u32i {
20+
// CHECK: %0 = cir.ternary(%arg0, true {
21+
// CHECK: %1 = cir.const #cir.int<0> : !u32i
22+
// CHECK: cir.yield %1 : !u32i
23+
// CHECK: }, false {
24+
// CHECK: %1 = cir.const #cir.int<1> : !u32i
25+
// CHECK: cir.yield %1 : !u32i
26+
// CHECK: }) : (!cir.bool) -> !u32i
27+
// CHECK: cir.return %0 : !u32i
28+
// CHECK: }
29+
30+
// CHECK: }
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: cir-translate -cir-to-llvmir --disable-cc-lowering -o %t.ll %s
2+
// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s
3+
4+
!u32i = !cir.int<u, 32>
5+
6+
module {
7+
cir.func @blue(%arg0: !cir.bool) -> !u32i {
8+
%0 = cir.ternary(%arg0, true {
9+
%a = cir.const #cir.int<0> : !u32i
10+
cir.yield %a : !u32i
11+
}, false {
12+
%b = cir.const #cir.int<1> : !u32i
13+
cir.yield %b : !u32i
14+
}) : (!cir.bool) -> !u32i
15+
cir.return %0 : !u32i
16+
}
17+
}
18+
19+
// LLVM-LABEL: define i32 {{.*}}@blue(
20+
// LLVM-SAME: i1 [[PRED:%[[:alnum:]]+]])
21+
// LLVM: br i1 [[PRED]], label %[[B1:[[:alnum:]]+]], label %[[B2:[[:alnum:]]+]]
22+
// LLVM: [[B1]]:
23+
// LLVM: br label %[[M:[[:alnum:]]+]]
24+
// LLVM: [[B2]]:
25+
// LLVM: br label %[[M]]
26+
// LLVM: [[M]]:
27+
// LLVM: [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ]
28+
// LLVM: br label %[[B3:[[:alnum:]]+]]
29+
// LLVM: [[B3]]:
30+
// LLVM: ret i32 [[R]]
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
cir.func @foo(%arg0: !s32i) -> !s32i {
7+
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
8+
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
9+
cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
10+
%2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
11+
%3 = cir.const #cir.int<0> : !s32i
12+
%4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
13+
%5 = cir.ternary(%4, true {
14+
%7 = cir.const #cir.int<3> : !s32i
15+
cir.yield %7 : !s32i
16+
}, false {
17+
%7 = cir.const #cir.int<5> : !s32i
18+
cir.yield %7 : !s32i
19+
}) : (!cir.bool) -> !s32i
20+
cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
21+
%6 = cir.load %1 : !cir.ptr<!s32i>, !s32i
22+
cir.return %6 : !s32i
23+
}
24+
25+
// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i {
26+
// CHECK: %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
27+
// CHECK: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
28+
// CHECK: cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
29+
// CHECK: %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
30+
// CHECK: %3 = cir.const #cir.int<0> : !s32i
31+
// CHECK: %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
32+
// CHECK: cir.brcond %4 ^bb1, ^bb2
33+
// CHECK: ^bb1: // pred: ^bb0
34+
// CHECK: %5 = cir.const #cir.int<3> : !s32i
35+
// CHECK: cir.br ^bb3(%5 : !s32i)
36+
// CHECK: ^bb2: // pred: ^bb0
37+
// CHECK: %6 = cir.const #cir.int<5> : !s32i
38+
// CHECK: cir.br ^bb3(%6 : !s32i)
39+
// CHECK: ^bb3(%7: !s32i): // 2 preds: ^bb1, ^bb2
40+
// CHECK: cir.br ^bb4
41+
// CHECK: ^bb4: // pred: ^bb3
42+
// CHECK: cir.store %7, %1 : !s32i, !cir.ptr<!s32i>
43+
// CHECK: %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i
44+
// CHECK: cir.return %8 : !s32i
45+
// CHECK: }
46+
47+
cir.func @foo2(%arg0: !cir.bool) {
48+
cir.ternary(%arg0, true {
49+
cir.yield
50+
}, false {
51+
cir.yield
52+
}) : (!cir.bool) -> ()
53+
cir.return
54+
}
55+
56+
// CHECK: cir.func @foo2(%arg0: !cir.bool) {
57+
// CHECK: cir.brcond %arg0 ^bb1, ^bb2
58+
// CHECK: ^bb1: // pred: ^bb0
59+
// CHECK: cir.br ^bb3
60+
// CHECK: ^bb2: // pred: ^bb0
61+
// CHECK: cir.br ^bb3
62+
// CHECK: ^bb3: // 2 preds: ^bb1, ^bb2
63+
// CHECK: cir.br ^bb4
64+
// CHECK: ^bb4: // pred: ^bb3
65+
// CHECK: cir.return
66+
// CHECK: }
67+
68+
}

0 commit comments

Comments
 (0)