Skip to content

Commit c67ab47

Browse files
committed
Replaced LLVM dialect's BranchWeightOpInterface with WeightedBranchOpInterface.
1 parent ff8a776 commit c67ab47

File tree

9 files changed

+68
-126
lines changed

9 files changed

+68
-126
lines changed

flang/test/Fir/invalid.fir

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,7 +1397,7 @@ fir.local {type = local_init} @x.localizer : f32 init {
13971397
// -----
13981398

13991399
func.func @wrong_weights_number_in_if_then(%cond: i1) {
1400-
// expected-error @below {{number of weights (1) does not match the number of regions (2)}}
1400+
// expected-error @below {{expects number of region weights to match number of regions: 1 vs 2}}
14011401
fir.if %cond weights([50]) {
14021402
}
14031403
return
@@ -1406,7 +1406,7 @@ func.func @wrong_weights_number_in_if_then(%cond: i1) {
14061406
// -----
14071407

14081408
func.func @wrong_weights_number_in_if_then_else(%cond: i1) {
1409-
// expected-error @below {{number of weights (3) does not match the number of regions (2)}}
1409+
// expected-error @below {{expects number of region weights to match number of regions: 3 vs 2}}
14101410
fir.if %cond weights([50, 40, 10]) {
14111411
} else {
14121412
}
@@ -1421,12 +1421,3 @@ func.func @negative_weight_in_if_then(%cond: i1) {
14211421
}
14221422
return
14231423
}
1424-
1425-
// -----
1426-
1427-
func.func @wrong_total_weight_in_if_then(%cond: i1) {
1428-
// expected-error @below {{total weight 101 is not 100}}
1429-
fir.if %cond weights([1, 100]) {
1430-
}
1431-
return
1432-
}

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -168,42 +168,6 @@ def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
168168
];
169169
}
170170

171-
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
172-
let description = [{
173-
An interface for operations that can carry branch weights metadata. It
174-
provides setters and getters for the operation's branch weights attribute.
175-
The default implementation of the interface methods expect the operation to
176-
have an attribute of type DenseI32ArrayAttr named branch_weights.
177-
}];
178-
179-
let cppNamespace = "::mlir::LLVM";
180-
181-
let methods = [
182-
InterfaceMethod<
183-
/*desc=*/ "Returns the branch weights attribute or nullptr",
184-
/*returnType=*/ "::mlir::DenseI32ArrayAttr",
185-
/*methodName=*/ "getBranchWeightsOrNull",
186-
/*args=*/ (ins),
187-
/*methodBody=*/ [{}],
188-
/*defaultImpl=*/ [{
189-
auto op = cast<ConcreteOp>(this->getOperation());
190-
return op.getBranchWeightsAttr();
191-
}]
192-
>,
193-
InterfaceMethod<
194-
/*desc=*/ "Sets the branch weights attribute",
195-
/*returnType=*/ "void",
196-
/*methodName=*/ "setBranchWeights",
197-
/*args=*/ (ins "::mlir::DenseI32ArrayAttr":$attr),
198-
/*methodBody=*/ [{}],
199-
/*defaultImpl=*/ [{
200-
auto op = cast<ConcreteOp>(this->getOperation());
201-
op.setBranchWeightsAttr(attr);
202-
}]
203-
>
204-
];
205-
}
206-
207171
def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
208172
let description = [{
209173
An interface for memory operations that can carry access groups metadata.

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -660,12 +660,12 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc",
660660
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
661661

662662
// Call-related operations.
663-
def LLVM_InvokeOp : LLVM_Op<"invoke", [
664-
AttrSizedOperandSegments,
665-
DeclareOpInterfaceMethods<BranchOpInterface>,
666-
DeclareOpInterfaceMethods<CallOpInterface>,
667-
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
668-
Terminator]> {
663+
def LLVM_InvokeOp
664+
: LLVM_Op<"invoke", [AttrSizedOperandSegments,
665+
DeclareOpInterfaceMethods<BranchOpInterface>,
666+
DeclareOpInterfaceMethods<CallOpInterface>,
667+
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
668+
Terminator]> {
669669
let arguments = (ins
670670
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
671671
OptionalAttr<FlatSymbolRefAttr>:$callee,
@@ -734,12 +734,13 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> {
734734
// CallOp
735735
//===----------------------------------------------------------------------===//
736736

737-
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
738-
[AttrSizedOperandSegments,
739-
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
740-
DeclareOpInterfaceMethods<CallOpInterface>,
741-
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
742-
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
737+
def LLVM_CallOp
738+
: LLVM_MemAccessOpBase<
739+
"call", [AttrSizedOperandSegments,
740+
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
741+
DeclareOpInterfaceMethods<CallOpInterface>,
742+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
743+
DeclareOpInterfaceMethods<WeightedBranchOpInterface>]> {
743744
let summary = "Call to an LLVM function.";
744745
let description = [{
745746
In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
@@ -1047,11 +1048,12 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br",
10471048
LLVM_TerminatorPassthroughOpBuilder
10481049
];
10491050
}
1050-
def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
1051-
[AttrSizedOperandSegments,
1052-
DeclareOpInterfaceMethods<BranchOpInterface>,
1053-
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
1054-
Pure]> {
1051+
def LLVM_CondBrOp
1052+
: LLVM_TerminatorOp<
1053+
"cond_br", [AttrSizedOperandSegments,
1054+
DeclareOpInterfaceMethods<BranchOpInterface>,
1055+
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
1056+
Pure]> {
10551057
let arguments = (ins I1:$condition,
10561058
Variadic<LLVM_Type>:$trueDestOperands,
10571059
Variadic<LLVM_Type>:$falseDestOperands,
@@ -1136,11 +1138,12 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
11361138
}];
11371139
}
11381140

1139-
def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
1140-
[AttrSizedOperandSegments,
1141-
DeclareOpInterfaceMethods<BranchOpInterface>,
1142-
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
1143-
Pure]> {
1141+
def LLVM_SwitchOp
1142+
: LLVM_TerminatorOp<
1143+
"switch", [AttrSizedOperandSegments,
1144+
DeclareOpInterfaceMethods<BranchOpInterface>,
1145+
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
1146+
Pure]> {
11441147
let arguments = (ins
11451148
AnySignlessInteger:$value,
11461149
Variadic<AnyType>:$defaultOperands,

mlir/include/mlir/Interfaces/ControlFlowInterfaces.td

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -385,12 +385,12 @@ def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
385385
operations, i.e. terminator operations with successors.
386386

387387
This interface provides methods for getting/setting integer non-negative
388-
weight of each branch in the range from 0 to 100. The sum of weights
389-
must be 100. The number of weights must match the number of successors
390-
of the operation.
391-
392-
The weights specify the probability (in percents) of taking
393-
a particular branch.
388+
weight of each branch. The probability of executing a branch
389+
is computed as the ratio between the branch's weight and the total
390+
sum of the weights.
391+
The number of weights must match the number of successors of the operation,
392+
with one exception for CallOpInterface operations, which may only
393+
have on weight when they do not have any successors.
394394

395395
The default implementations of the methods expect the operation
396396
to have an attribute of type DenseI32ArrayAttr named branch_weights.
@@ -440,15 +440,16 @@ def WeightedRegionBranchOpInterface
440440
that exhibit branching behavior between held regions.
441441

442442
This interface provides methods for getting/setting integer non-negative
443-
weight of each branch in the range from 0 to 100. The sum of weights
444-
must be 100. The number of weights must match the number of regions
443+
weight of each branch. The probability of executing a region is computed
444+
as the ratio between the region branch's weight and the total sum
445+
of the weights.
446+
The number of weights must match the number of regions
445447
held by the operation (including empty regions).
446448

447-
The weights specify the probability (in percents) of branching
448-
to a particular region when first executing the operation.
449+
The weights specify the probability of branching to a particular
450+
region when first executing the operation.
449451
For example, for loop-like operations with a single region
450452
the weight specifies the probability of entering the loop.
451-
In this case, the weight must be either 0 or 100.
452453

453454
The default implementations of the methods expect the operation
454455
to have an attribute of type DenseI32ArrayAttr named branch_weights.

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ class ModuleTranslation {
189189
llvm::Instruction *inst);
190190

191191
/// Sets LLVM profiling metadata for operations that have branch weights.
192-
void setBranchWeightsMetadata(BranchWeightOpInterface op);
192+
void setBranchWeightsMetadata(WeightedBranchOpInterface op);
193193

194194
/// Sets LLVM loop metadata for branch operations that have a loop annotation
195195
/// attribute.

mlir/lib/Interfaces/ControlFlowInterfaces.cpp

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <utility>
1010

1111
#include "mlir/IR/BuiltinTypes.h"
12+
#include "mlir/Interfaces/CallInterfaces.h"
1213
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1314
#include "llvm/ADT/SmallPtrSet.h"
1415

@@ -84,24 +85,33 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
8485
// WeightedBranchOpInterface
8586
//===----------------------------------------------------------------------===//
8687

87-
LogicalResult detail::verifyBranchWeights(Operation *op) {
88-
auto weights = cast<WeightedBranchOpInterface>(op).getBranchWeightsOrNull();
88+
static LogicalResult verifyWeights(Operation *op, DenseI32ArrayAttr weights,
89+
int64_t weightsNum,
90+
llvm::StringRef weightAnchorName,
91+
llvm::StringRef weightRefName) {
8992
if (weights) {
90-
if (weights.size() != op->getNumSuccessors())
91-
return op->emitError() << "number of weights (" << weights.size()
92-
<< ") does not match the number of successors ("
93-
<< op->getNumSuccessors() << ")";
94-
int32_t total = 0;
95-
for (auto weight : llvm::enumerate(weights.asArrayRef())) {
93+
if (weights.size() != weightsNum)
94+
return op->emitError() << "expects number of " << weightAnchorName
95+
<< " weights to match number of " << weightRefName
96+
<< ": " << weights.size() << " vs " << weightsNum;
97+
98+
for (auto weight : llvm::enumerate(weights.asArrayRef()))
9699
if (weight.value() < 0)
97100
return op->emitError()
98101
<< "weight #" << weight.index() << " must be non-negative";
99-
total += weight.value();
100-
}
101-
if (total != 100)
102-
return op->emitError() << "total weight " << total << " is not 100";
103102
}
104-
return mlir::success();
103+
return success();
104+
}
105+
106+
LogicalResult detail::verifyBranchWeights(Operation *op) {
107+
auto weights = cast<WeightedBranchOpInterface>(op).getBranchWeightsOrNull();
108+
unsigned successorsNum = op->getNumSuccessors();
109+
// CallOpInterface operations without successors may only have
110+
// one weight, though it seems to be redundant and indicate
111+
// 100% probability of calling the callee(s).
112+
int64_t weightsNum =
113+
(successorsNum == 0 && isa<CallOpInterface>(op)) ? 1 : successorsNum;
114+
return verifyWeights(op, weights, weightsNum, "branch", "successors");
105115
}
106116

107117
//===----------------------------------------------------------------------===//
@@ -111,22 +121,7 @@ LogicalResult detail::verifyBranchWeights(Operation *op) {
111121
LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
112122
auto weights =
113123
cast<WeightedRegionBranchOpInterface>(op).getRegionWeightsOrNull();
114-
if (weights) {
115-
if (weights.size() != op->getNumRegions())
116-
return op->emitError() << "number of weights (" << weights.size()
117-
<< ") does not match the number of regions ("
118-
<< op->getNumRegions() << ")";
119-
int32_t total = 0;
120-
for (auto weight : llvm::enumerate(weights.asArrayRef())) {
121-
if (weight.value() < 0)
122-
return op->emitError()
123-
<< "weight #" << weight.index() << " must be non-negative";
124-
total += weight.value();
125-
}
126-
if (total != 100)
127-
return op->emitError() << "total weight " << total << " is not 100";
128-
}
129-
return mlir::success();
124+
return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
130125
}
131126

132127
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
146146
branchWeights.push_back(branchWeight->getZExtValue());
147147
}
148148

149-
if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) {
149+
if (auto iface = dyn_cast<WeightedBranchOpInterface>(op)) {
150150
iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights));
151151
return success();
152152
}

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,7 +1055,7 @@ LogicalResult ModuleTranslation::convertBlockImpl(Block &bb,
10551055
return failure();
10561056

10571057
// Set the branch weight metadata on the translated instruction.
1058-
if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
1058+
if (auto iface = dyn_cast<WeightedBranchOpInterface>(op))
10591059
setBranchWeightsMetadata(iface);
10601060
}
10611061

@@ -2026,7 +2026,7 @@ void ModuleTranslation::setDereferenceableMetadata(
20262026
inst->setMetadata(kindId, derefSizeNode);
20272027
}
20282028

2029-
void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
2029+
void ModuleTranslation::setBranchWeightsMetadata(WeightedBranchOpInterface op) {
20302030
DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
20312031
if (!weightsAttr)
20322032
return;

mlir/test/Dialect/ControlFlow/invalid.mlir

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func.func @switch_missing_default(%flag : i32, %caseOperand : i32) {
7272

7373
// CHECK-LABEL: func @wrong_weights_number
7474
func.func @wrong_weights_number(%cond: i1) {
75-
// expected-error@+1 {{number of weights (1) does not match the number of successors (2)}}
75+
// expected-error@+1 {{expects number of branch weights to match number of successors: 1 vs 2}}
7676
cf.cond_br %cond weights([100]), ^bb1, ^bb2
7777
^bb1:
7878
return
@@ -91,15 +91,3 @@ func.func @wrong_total_weight(%cond: i1) {
9191
^bb2:
9292
return
9393
}
94-
95-
// -----
96-
97-
// CHECK-LABEL: func @wrong_total_weight
98-
func.func @wrong_total_weight(%cond: i1) {
99-
// expected-error@+1 {{total weight 101 is not 100}}
100-
cf.cond_br %cond weights([100, 1]), ^bb1, ^bb2
101-
^bb1:
102-
return
103-
^bb2:
104-
return
105-
}

0 commit comments

Comments
 (0)