Skip to content

Commit ff8a776

Browse files
committed
[mlir][flang] Added Weighted[Region]BranchOpInterface's.
The new interfaces provide getters and setters for the weight information about the branches of BranchOpInterface and RegionBranchOpInterface operations. These interfaces are done the same way as LLVM dialect's BranchWeightOpInterface. The plan is to produce this information in Flang, e.g. mark most probably "cold" code as such and allow LLVM to order basic blocks accordingly. An example of such a code is copy loops generated for arrays repacking - we can mark it as "cold" assuming that the copy will not happen dynamically. If the copy actually happens the overhead of the copy is probably high enough so that we may not care about the little overhead of jumping to the "cold" code and fetching it.
1 parent d8118ed commit ff8a776

File tree

14 files changed

+396
-22
lines changed

14 files changed

+396
-22
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2323,9 +2323,13 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
23232323
}];
23242324
}
23252325

2326-
def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
2327-
"getRegionInvocationBounds", "getEntrySuccessorRegions"]>, RecursiveMemoryEffects,
2328-
NoRegionArguments]> {
2326+
def fir_IfOp
2327+
: region_Op<
2328+
"if", [DeclareOpInterfaceMethods<
2329+
RegionBranchOpInterface, ["getRegionInvocationBounds",
2330+
"getEntrySuccessorRegions"]>,
2331+
RecursiveMemoryEffects, NoRegionArguments,
2332+
WeightedRegionBranchOpInterface]> {
23292333
let summary = "if-then-else conditional operation";
23302334
let description = [{
23312335
Used to conditionally execute operations. This operation is the FIR
@@ -2342,7 +2346,8 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
23422346
```
23432347
}];
23442348

2345-
let arguments = (ins I1:$condition);
2349+
let arguments = (ins I1:$condition,
2350+
OptionalAttr<DenseI32ArrayAttr>:$region_weights);
23462351
let results = (outs Variadic<AnyType>:$results);
23472352

23482353
let regions = (region
@@ -2371,6 +2376,11 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
23712376

23722377
void resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
23732378
unsigned resultNum);
2379+
2380+
/// Returns the display name string for the region_weights attribute.
2381+
static constexpr llvm::StringRef getWeightsAttrAssemblyName() {
2382+
return "weights";
2383+
}
23742384
}];
23752385
}
23762386

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4418,6 +4418,19 @@ mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser,
44184418
parser.resolveOperand(cond, i1Type, result.operands))
44194419
return mlir::failure();
44204420

4421+
if (mlir::succeeded(
4422+
parser.parseOptionalKeyword(getWeightsAttrAssemblyName()))) {
4423+
if (parser.parseLParen())
4424+
return mlir::failure();
4425+
mlir::DenseI32ArrayAttr weights;
4426+
if (parser.parseCustomAttributeWithFallback(weights, mlir::Type{}))
4427+
return mlir::failure();
4428+
if (weights)
4429+
result.addAttribute(getRegionWeightsAttrName(result.name), weights);
4430+
if (parser.parseRParen())
4431+
return mlir::failure();
4432+
}
4433+
44214434
if (parser.parseOptionalArrowTypeList(result.types))
44224435
return mlir::failure();
44234436

@@ -4449,6 +4462,11 @@ llvm::LogicalResult fir::IfOp::verify() {
44494462
void fir::IfOp::print(mlir::OpAsmPrinter &p) {
44504463
bool printBlockTerminators = false;
44514464
p << ' ' << getCondition();
4465+
if (auto weights = getRegionWeightsAttr()) {
4466+
p << ' ' << getWeightsAttrAssemblyName() << '(';
4467+
p.printStrippedAttrOrType(weights);
4468+
p << ')';
4469+
}
44524470
if (!getResults().empty()) {
44534471
p << " -> (" << getResultTypes() << ')';
44544472
printBlockTerminators = true;
@@ -4464,7 +4482,8 @@ void fir::IfOp::print(mlir::OpAsmPrinter &p) {
44644482
p.printRegion(otherReg, /*printEntryBlockArgs=*/false,
44654483
printBlockTerminators);
44664484
}
4467-
p.printOptionalAttrDict((*this)->getAttrs());
4485+
p.printOptionalAttrDict((*this)->getAttrs(),
4486+
/*elideAttrs=*/{getRegionWeightsAttrName()});
44684487
}
44694488

44704489
void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,

flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,11 @@ class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
212212
}
213213

214214
rewriter.setInsertionPointToEnd(condBlock);
215-
rewriter.create<mlir::cf::CondBranchOp>(
215+
auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
216216
loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
217217
otherwiseBlock, llvm::ArrayRef<mlir::Value>());
218+
if (auto weights = ifOp.getRegionWeightsOrNull())
219+
branchOp.setBranchWeights(weights);
218220
rewriter.replaceOp(ifOp, continueBlock->getArguments());
219221
return success();
220222
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: fir-opt --split-input-file --cfg-conversion %s | FileCheck %s
2+
3+
func.func private @callee() -> none
4+
5+
// CHECK-LABEL: func.func @if_then(
6+
// CHECK-SAME: %[[ARG0:.*]]: i1) {
7+
// CHECK: cf.cond_br %[[ARG0]] weights([10, 90]), ^bb1, ^bb2
8+
// CHECK: ^bb1:
9+
// CHECK: %[[VAL_0:.*]] = fir.call @callee() : () -> none
10+
// CHECK: cf.br ^bb2
11+
// CHECK: ^bb2:
12+
// CHECK: return
13+
// CHECK: }
14+
func.func @if_then(%cond: i1) {
15+
fir.if %cond weights([10, 90]) {
16+
fir.call @callee() : () -> none
17+
}
18+
return
19+
}
20+
21+
// -----
22+
23+
// CHECK-LABEL: func.func @if_then_else(
24+
// CHECK-SAME: %[[ARG0:.*]]: i1) -> i32 {
25+
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
26+
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
27+
// CHECK: cf.cond_br %[[ARG0]] weights([90, 10]), ^bb1, ^bb2
28+
// CHECK: ^bb1:
29+
// CHECK: cf.br ^bb3(%[[VAL_0]] : i32)
30+
// CHECK: ^bb2:
31+
// CHECK: cf.br ^bb3(%[[VAL_1]] : i32)
32+
// CHECK: ^bb3(%[[VAL_2:.*]]: i32):
33+
// CHECK: cf.br ^bb4
34+
// CHECK: ^bb4:
35+
// CHECK: return %[[VAL_2]] : i32
36+
// CHECK: }
37+
func.func @if_then_else(%cond: i1) -> i32 {
38+
%c0 = arith.constant 0 : i32
39+
%c1 = arith.constant 1 : i32
40+
%result = fir.if %cond weights([90, 10]) -> i32 {
41+
fir.result %c0 : i32
42+
} else {
43+
fir.result %c1 : i32
44+
}
45+
return %result : i32
46+
}

flang/test/Fir/fir-ops.fir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,3 +1015,19 @@ func.func @test_box_total_elements(%arg0: !fir.class<!fir.type<sometype{i:i32}>>
10151015
%6 = arith.addi %2, %5 : index
10161016
return %6 : index
10171017
}
1018+
1019+
// CHECK-LABEL: func.func @test_if_weights(
1020+
// CHECK-SAME: %[[ARG0:.*]]: i1) {
1021+
func.func @test_if_weights(%cond: i1) {
1022+
// CHECK: fir.if %[[ARG0]] weights([99, 1]) {
1023+
// CHECK: }
1024+
fir.if %cond weights([99, 1]) {
1025+
}
1026+
// CHECK: fir.if %[[ARG0]] weights([99, 1]) {
1027+
// CHECK: } else {
1028+
// CHECK: }
1029+
fir.if %cond weights ([99,1]) {
1030+
} else {
1031+
}
1032+
return
1033+
}

flang/test/Fir/invalid.fir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,3 +1393,40 @@ fir.local {type = local_init} @x.localizer : f32 init {
13931393
^bb0(%arg0: f32, %arg1: f32):
13941394
fir.yield(%arg0 : f32)
13951395
}
1396+
1397+
// -----
1398+
1399+
func.func @wrong_weights_number_in_if_then(%cond: i1) {
1400+
// expected-error @below {{number of weights (1) does not match the number of regions (2)}}
1401+
fir.if %cond weights([50]) {
1402+
}
1403+
return
1404+
}
1405+
1406+
// -----
1407+
1408+
func.func @wrong_weights_number_in_if_then_else(%cond: i1) {
1409+
// expected-error @below {{number of weights (3) does not match the number of regions (2)}}
1410+
fir.if %cond weights([50, 40, 10]) {
1411+
} else {
1412+
}
1413+
return
1414+
}
1415+
1416+
// -----
1417+
1418+
func.func @negative_weight_in_if_then(%cond: i1) {
1419+
// expected-error @below {{weight #0 must be non-negative}}
1420+
fir.if %cond weights([-1, 101]) {
1421+
}
1422+
return
1423+
}
1424+
1425+
// -----
1426+
1427+
func.func @wrong_total_weight_in_if_then(%cond: i1) {
1428+
// expected-error @below {{total weight 101 is not 100}}
1429+
fir.if %cond weights([1, 100]) {
1430+
}
1431+
return
1432+
}

mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,11 @@ def BranchOp : CF_Op<"br", [
112112
// CondBranchOp
113113
//===----------------------------------------------------------------------===//
114114

115-
def CondBranchOp : CF_Op<"cond_br",
116-
[AttrSizedOperandSegments,
117-
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
118-
Pure, Terminator]> {
115+
def CondBranchOp
116+
: CF_Op<"cond_br", [AttrSizedOperandSegments,
117+
DeclareOpInterfaceMethods<
118+
BranchOpInterface, ["getSuccessorForOperands"]>,
119+
WeightedBranchOpInterface, Pure, Terminator]> {
119120
let summary = "Conditional branch operation";
120121
let description = [{
121122
The `cf.cond_br` terminator operation represents a conditional branch on a
@@ -144,20 +145,23 @@ def CondBranchOp : CF_Op<"cond_br",
144145
```
145146
}];
146147

147-
let arguments = (ins I1:$condition,
148-
Variadic<AnyType>:$trueDestOperands,
149-
Variadic<AnyType>:$falseDestOperands);
148+
let arguments = (ins I1:$condition, Variadic<AnyType>:$trueDestOperands,
149+
Variadic<AnyType>:$falseDestOperands,
150+
OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
150151
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
151152

152-
let builders = [
153-
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
154-
"ValueRange":$trueOperands, "Block *":$falseDest,
155-
"ValueRange":$falseOperands), [{
156-
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
153+
let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
154+
"ValueRange":$trueOperands,
155+
"Block *":$falseDest,
156+
"ValueRange":$falseOperands),
157+
[{
158+
build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest,
157159
falseDest);
158160
}]>,
159-
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
160-
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
161+
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
162+
"Block *":$falseDest,
163+
CArg<"ValueRange", "{}">:$falseOperands),
164+
[{
161165
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
162166
falseOperands);
163167
}]>];
@@ -216,7 +220,7 @@ def CondBranchOp : CF_Op<"cond_br",
216220

217221
let hasCanonicalizer = 1;
218222
let assemblyFormat = [{
219-
$condition `,`
223+
$condition (`weights` `(` $branch_weights^ `)` )? `,`
220224
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
221225
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
222226
attr-dict

mlir/include/mlir/Interfaces/ControlFlowInterfaces.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,26 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
142142
const SuccessorOperands &operands);
143143
} // namespace detail
144144

145+
//===----------------------------------------------------------------------===//
146+
// WeightedBranchOpInterface
147+
//===----------------------------------------------------------------------===//
148+
149+
namespace detail {
150+
/// Verify that the branch weights attached to an operation
151+
/// implementing WeightedBranchOpInterface are correct.
152+
LogicalResult verifyBranchWeights(Operation *op);
153+
} // namespace detail
154+
155+
//===----------------------------------------------------------------------===//
156+
// WeightedRegiobBranchOpInterface
157+
//===----------------------------------------------------------------------===//
158+
159+
namespace detail {
160+
/// Verify that the region weights attached to an operation
161+
/// implementing WeightedRegiobBranchOpInterface are correct.
162+
LogicalResult verifyRegionBranchWeights(Operation *op);
163+
} // namespace detail
164+
145165
//===----------------------------------------------------------------------===//
146166
// RegionBranchOpInterface
147167
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)