Skip to content

Commit c8b6770

Browse files
committed
Strengthened constraint on the number of weights for calls.
1 parent 08e9e40 commit c8b6770

File tree

7 files changed

+56
-45
lines changed

7 files changed

+56
-45
lines changed

mlir/include/mlir/Interfaces/ControlFlowInterfaces.td

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -387,11 +387,9 @@ def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
387387
This interface provides methods for getting/setting integer non-negative
388388
weight of each branch. The probability of executing a branch
389389
is computed as the ratio between the branch's weight and the total
390-
sum of the weights.
390+
sum of the weights (which cannot be zero).
391391
The weights are optional. If they are provided, then their number
392-
must match the number of successors of the operation,
393-
with one exception for CallOpInterface operations, which may only
394-
have one weight when they do not have any successors.
392+
must match the number of successors of the operation.
395393

396394
The default implementations of the methods expect the operation
397395
to have an attribute of type DenseI32ArrayAttr named branch_weights.
@@ -445,7 +443,7 @@ def WeightedRegionBranchOpInterface
445443
This interface provides methods for getting/setting integer non-negative
446444
weight of each branch. The probability of executing a region is computed
447445
as the ratio between the region branch's weight and the total sum
448-
of the weights.
446+
of the weights (which cannot be zero).
449447
The weights are optional. If they are provided, then their number
450448
must match the number of regions held by the operation
451449
(including empty regions).

mlir/lib/Interfaces/ControlFlowInterfaces.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,20 +103,17 @@ static LogicalResult verifyWeights(Operation *op,
103103
if (weight < 0)
104104
return op->emitError() << "weight #" << index << " must be non-negative";
105105

106+
if (llvm::all_of(weights, [](int32_t value) { return value == 0; }))
107+
return op->emitError() << "branch weights cannot all be zero";
108+
106109
return success();
107110
}
108111

109112
LogicalResult detail::verifyBranchWeights(Operation *op) {
110113
llvm::ArrayRef<int32_t> weights =
111114
cast<WeightedBranchOpInterface>(op).getWeights();
112-
unsigned successorsNum = op->getNumSuccessors();
113-
// CallOpInterface operations without successors may only have
114-
// one weight, though it seems to be redundant and indicate
115-
// 100% probability of calling the callee(s).
116-
// TODO: maybe we should disallow weights for calls without successors.
117-
std::size_t weightsNum =
118-
(successorsNum == 0 && isa<CallOpInterface>(op)) ? 1 : successorsNum;
119-
return verifyWeights(op, weights, weightsNum, "branch", "successors");
115+
return verifyWeights(op, weights, op->getNumSuccessors(), "branch",
116+
"successors");
120117
}
121118

122119
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,14 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
147147
}
148148

149149
if (auto iface = dyn_cast<WeightedBranchOpInterface>(op)) {
150-
iface.setWeights(branchWeights);
150+
// LLVM allows attaching a single weight to call instructions.
151+
// This is used for carrying the execution count information
152+
// in PGO modes. MLIR WeightedBranchOpInterface does not allow this,
153+
// so we drop the metadata in this case.
154+
// LLVM should probably use the VP form of MD_prof metadata
155+
// for such cases.
156+
if (op->getNumSuccessors() != 0)
157+
iface.setWeights(branchWeights);
151158
return success();
152159
}
153160
return failure();

mlir/test/Dialect/ControlFlow/invalid.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,15 @@ func.func @wrong_total_weight(%cond: i1) {
9191
^bb2:
9292
return
9393
}
94+
95+
// -----
96+
97+
// CHECK-LABEL: func @zero_weights
98+
func.func @wrong_total_weight(%cond: i1) {
99+
// expected-error@+1 {{branch weights cannot all be zero}}
100+
cf.cond_br %cond weights([0, 0]), ^bb1, ^bb2
101+
^bb1:
102+
return
103+
^bb2:
104+
return
105+
}

mlir/test/Target/LLVMIR/Import/metadata-profiling.ll

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,17 @@ bbd:
3636

3737
; // -----
3838

39+
; Verify that a single weight attached to a call is not translated.
40+
; The MLIR WeightedBranchOpInterface does not support this case.
41+
3942
; CHECK: llvm.func @fn()
40-
declare void @fn()
43+
declare i32 @fn()
4144

4245
; CHECK-LABEL: @call_branch_weights
43-
define void @call_branch_weights() {
44-
; CHECK: llvm.call @fn() {branch_weights = array<i32: 42>}
45-
call void @fn(), !prof !0
46-
ret void
46+
define i32 @call_branch_weights() {
47+
; CHECK: llvm.call @fn() : () -> i32
48+
%1 = call i32 @fn(), !prof !0
49+
ret i32 %1
4750
}
4851

4952
!0 = !{!"branch_weights", i32 42}

mlir/test/Target/LLVMIR/llvmir-invalid.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,23 @@ llvm.mlir.global external constant @const() {addr_space = 0 : i32, dso_local} :
448448
}
449449

450450
llvm.func extern_weak @extern_func()
451+
452+
// -----
453+
454+
llvm.func @fn()
455+
456+
llvm.func @call_branch_weights() {
457+
// expected-error @below{{expects number of branch weights to match number of successors: 1 vs 0}}
458+
llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> ()
459+
llvm.return
460+
}
461+
462+
// -----
463+
464+
llvm.func @fn() -> i32
465+
466+
llvm.func @call_branch_weights() {
467+
// expected-error @below{{expects number of branch weights to match number of successors: 1 vs 0}}
468+
%res = llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> i32
469+
llvm.return
470+
}

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1906,32 +1906,6 @@ llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 {
19061906

19071907
// -----
19081908

1909-
llvm.func @fn()
1910-
1911-
// CHECK-LABEL: @call_branch_weights
1912-
llvm.func @call_branch_weights() {
1913-
// CHECK: !prof ![[NODE:[0-9]+]]
1914-
llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> ()
1915-
llvm.return
1916-
}
1917-
1918-
// CHECK: ![[NODE]] = !{!"branch_weights", i32 42}
1919-
1920-
// -----
1921-
1922-
llvm.func @fn() -> i32
1923-
1924-
// CHECK-LABEL: @call_branch_weights
1925-
llvm.func @call_branch_weights() {
1926-
// CHECK: !prof ![[NODE:[0-9]+]]
1927-
%res = llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> i32
1928-
llvm.return
1929-
}
1930-
1931-
// CHECK: ![[NODE]] = !{!"branch_weights", i32 42}
1932-
1933-
// -----
1934-
19351909
llvm.func @foo()
19361910
llvm.func @__gxx_personality_v0(...) -> i32
19371911

0 commit comments

Comments
 (0)