Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 66 additions & 34 deletions llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ STATISTIC(NumInstrsHoisted,
STATISTIC(NumInstrsDuplicated,
"Number of instructions cloned into loop preheader");

// Probability that a rotated loop has zero trip count / is never entered.
static constexpr uint32_t ZeroTripCountWeights[] = {1, 127};

namespace {
/// A simple loop rotation transformation.
class LoopRotate {
Expand Down Expand Up @@ -200,7 +197,8 @@ static bool profitableToRotateLoopExitingLatch(Loop *L) {
return false;
}

static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
static void updateBranchWeights(Loop *L, BranchInst &PreHeaderBI,
BranchInst &LoopBI,
bool HasConditionalPreHeader,
bool SuccsSwapped) {
MDNode *WeightMD = getBranchWeightMDNode(PreHeaderBI);
Expand All @@ -218,29 +216,68 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
if (Weights.size() != 2)
return;
uint32_t OrigLoopExitWeight = Weights[0];
uint32_t OrigLoopBackedgeWeight = Weights[1];
uint32_t OrigLoopEnterWeight = Weights[1];

if (SuccsSwapped)
std::swap(OrigLoopExitWeight, OrigLoopBackedgeWeight);
std::swap(OrigLoopExitWeight, OrigLoopEnterWeight);

// For a multiple-exit loop, find the total weight of other exits.
uint32_t OtherLoopExitWeight = 0;
SmallVector<BasicBlock *, 16> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
for (BasicBlock *ExitingBB : ExitingBlocks) {
Instruction *TI = ExitingBB->getTerminator();
if (TI == &LoopBI)
continue;

if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) ||
isa<IndirectBrInst>(TI) || isa<InvokeInst>(TI) ||
isa<CallBrInst>(TI)))
continue;

MDNode *WeightsNode = getValidBranchWeightMDNode(*TI);
if (!WeightsNode)
continue;

SmallVector<uint32_t, 2> Weights;
extractBranchWeights(WeightsNode, Weights);
for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
BasicBlock *Exit = TI->getSuccessor(I);
if (L->contains(Exit))
continue;

OtherLoopExitWeight += Weights[I];
}
}

// Update branch weights. Consider the following edge-counts:
// Adjust OtherLoopExitWeight as it should not be larger than the loop enter
// weight.
if (OtherLoopExitWeight > OrigLoopEnterWeight)
OtherLoopExitWeight = OrigLoopEnterWeight;

uint32_t OrigLoopBackedgeWeight = OrigLoopEnterWeight - OtherLoopExitWeight;

// Update branch weights. Consider the following edge-counts (z for multiple
// exit loop):
//
// | |-------- |
// V V | V
// Br i1 ... | Br i1 ...
// | | | | |
// x| y| | becomes: | y0| |-----
// V V | | V V |
// Exit Loop | | Loop |
// | | | Br i1 ... |
// Exit <- Loop | | Loop |
// z | | | Br i1 ... |
// ----- | | | |
// x0| x1| y1 | |
// V V ----
// Exit
// Exit <----|
// z
//
// The following must hold:
// - x == x0 + x1 # counts to "exit" must stay the same.
// - y0 == x - x0 == x1 # how often loop was entered at all.
// - y0 == x - x0 + z # how often loop was entered at all.
// == x1 + z
// - y1 == y - y0 # How often loop was repeated (after first iter.).
//
// We cannot generally deduce how often we had a zero-trip count loop so we
Expand All @@ -255,19 +292,12 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
if (HasConditionalPreHeader) {
// Here we cannot know how many 0-trip count loops we have, so we guess:
if (OrigLoopBackedgeWeight >= OrigLoopExitWeight) {
// If the loop count is bigger than the exit count then we set
// probabilities as if 0-trip count nearly never happens.
ExitWeight0 = ZeroTripCountWeights[0];
// Scale up counts if necessary so we can match `ZeroTripCountWeights`
// for the `ExitWeight0`:`ExitWeight1` (aka `x0`:`x1` ratio`) ratio.
while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) {
// ... but don't overflow.
uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1);
if ((OrigLoopBackedgeWeight & HighBit) != 0 ||
(OrigLoopExitWeight & HighBit) != 0)
break;
OrigLoopBackedgeWeight <<= 1;
OrigLoopExitWeight <<= 1;
ExitWeight0 =
(OrigLoopExitWeight * (OrigLoopExitWeight + OtherLoopExitWeight)) /
(OrigLoopExitWeight + OrigLoopEnterWeight);
// Minimum ExitWeight0 1
if (ExitWeight0 == 0) {
ExitWeight0 = 1;
}
} else {
// If there's a higher exit-count than backedge-count then we set
Expand All @@ -280,36 +310,38 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
// weight collected by sampling-based PGO may be not very accurate due to
// sampling. Therefore this workaround is required here to avoid underflow
// of unsigned in following update of branch weight.
if (OrigLoopExitWeight > OrigLoopBackedgeWeight)
if (OrigLoopExitWeight > OrigLoopBackedgeWeight) {
OrigLoopBackedgeWeight = OrigLoopExitWeight;
OrigLoopEnterWeight = OrigLoopBackedgeWeight + OtherLoopExitWeight;
}
}
assert(OrigLoopExitWeight >= ExitWeight0 && "Bad branch weight");
ExitWeight1 = OrigLoopExitWeight - ExitWeight0;
EnterWeight = ExitWeight1;
assert(OrigLoopBackedgeWeight >= EnterWeight && "Bad branch weight");
LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight;
EnterWeight = ExitWeight1 + OtherLoopExitWeight;
assert(OrigLoopEnterWeight >= EnterWeight && "Bad branch weight");
LoopBackWeight = OrigLoopEnterWeight - EnterWeight;
} else if (OrigLoopExitWeight == 0) {
if (OrigLoopBackedgeWeight == 0) {
// degenerate case... keep everything zero...
ExitWeight0 = 0;
ExitWeight1 = 0;
EnterWeight = 0;
EnterWeight = OtherLoopExitWeight;
LoopBackWeight = 0;
} else {
// Special case "LoopExitWeight == 0" weights which behaves like an
// endless where we don't want loop-enttry (y0) to be the same as
// endless where we don't want loop-entry (y0) to be the same as
// loop-exit (x1).
ExitWeight0 = 0;
ExitWeight1 = 0;
EnterWeight = 1;
EnterWeight = (OtherLoopExitWeight != 0) ? OtherLoopExitWeight : 1;
LoopBackWeight = OrigLoopBackedgeWeight;
}
} else {
// loop is never entered.
assert(OrigLoopBackedgeWeight == 0 && "remaining case is backedge zero");
ExitWeight0 = 1;
ExitWeight0 = OrigLoopExitWeight;
ExitWeight1 = 1;
EnterWeight = 0;
EnterWeight = OtherLoopExitWeight;
LoopBackWeight = 0;
}

Expand Down Expand Up @@ -748,7 +780,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
!isa<ConstantInt>(Cond) ||
PHBI->getSuccessor(cast<ConstantInt>(Cond)->isZero()) != NewHeader;

updateBranchWeights(*PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped);
updateBranchWeights(L, *PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped);

if (HasConditionalPreHeader) {
// The conditional branch can't be folded, handle the general case.
Expand Down
8 changes: 4 additions & 4 deletions llvm/test/Transforms/LoopRotate/update-branch-weights.ll
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ outer_loop_exit:

; BFI_AFTER-LABEL: block-frequency-info: func1
; BFI_AFTER: - entry: {{.*}} count = 1024
; BFI_AFTER: - loop_body.lr.ph: {{.*}} count = 1016
; BFI_AFTER: - loop_body.lr.ph: {{.*}} count = 512
; BFI_AFTER: - loop_body: {{.*}} count = 20480
; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 1016
; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 512
; BFI_AFTER: - loop_exit: {{.*}} count = 1024

; IR-LABEL: define void @func1
Expand Down Expand Up @@ -285,8 +285,8 @@ loop_exit:

; IR: [[PROF_FUNC0_0]] = !{!"branch_weights", i32 2000, i32 1000}
; IR: [[PROF_FUNC0_1]] = !{!"branch_weights", i32 999, i32 1}
; IR: [[PROF_FUNC1_0]] = !{!"branch_weights", i32 127, i32 1}
; IR: [[PROF_FUNC1_1]] = !{!"branch_weights", i32 2433, i32 127}
; IR: [[PROF_FUNC1_0]] = !{!"branch_weights", i32 1, i32 1}
; IR: [[PROF_FUNC1_1]] = !{!"branch_weights", i32 39, i32 1}
; IR: [[PROF_FUNC2_0]] = !{!"branch_weights", i32 9920, i32 320}
; IR: [[PROF_FUNC2_1]] = !{!"branch_weights", i32 320, i32 0}
; IR: [[PROF_FUNC3_0]] = !{!"branch_weights", i32 0, i32 1}
Expand Down
115 changes: 115 additions & 0 deletions llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
; RUN: opt < %s -passes='loop(loop-rotate)' -S | FileCheck %s

@g = global i64 0

define void @func_branch_weight(i64 %n) !prof !0 {
; CHECK-LABEL: define void @func_branch_weight(
; CHECK-SAME: i64 [[N:%.*]]) !prof [[PROF0:![0-9]+]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i64 0, [[N]]
; CHECK-NEXT: br i1 [[CMP1]], label %[[LOOP_EXIT:.*]], label %[[LOOP_BODY_LR_PH:.*]], !prof [[PROF1:![0-9]+]]
; CHECK: [[LOOP_BODY_LR_PH]]:
; CHECK-NEXT: br label %[[LOOP_BODY:.*]]
; CHECK: [[LOOP_HEADER:.*]]:
; CHECK-NEXT: [[I:%.*]] = phi i64 [ [[I_INC:%.*]], %[[LOOP_BODY]] ]
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i64 [[I]], [[N]]
; CHECK-NEXT: br i1 [[CMP]], label %[[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE:.*]], label %[[LOOP_BODY]], !prof [[PROF2:![0-9]+]]
; CHECK: [[LOOP_BODY]]:
; CHECK-NEXT: [[I2:%.*]] = phi i64 [ 0, %[[LOOP_BODY_LR_PH]] ], [ [[I]], %[[LOOP_HEADER]] ]
; CHECK-NEXT: [[GP:%.*]] = getelementptr inbounds i8, ptr @g, i64 [[I2]]
; CHECK-NEXT: [[GI:%.*]] = load i64, ptr [[GP]], align 8
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i64 [[GI]], 0
; CHECK-NEXT: [[I_INC]] = add i64 [[I2]], 1
; CHECK-NEXT: br i1 [[CMP_NOT]], label %[[LOOP_HEADER]], label %[[LOOP_BODY_LOOP_EXIT_CRIT_EDGE:.*]], !prof [[PROF3:![0-9]+]]
; CHECK: [[LOOP_BODY_LOOP_EXIT_CRIT_EDGE]]:
; CHECK-NEXT: br label %[[LOOP_EXIT]]
; CHECK: [[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE]]:
; CHECK-NEXT: br label %[[LOOP_EXIT]]
; CHECK: [[LOOP_EXIT]]:
; CHECK-NEXT: ret void
;
entry:
br label %loop_header

loop_header:
%i = phi i64 [0, %entry], [%i_inc, %if_then]
%cmp = icmp slt i64 %i, %n
br i1 %cmp, label %loop_exit, label %loop_body, !prof !1

loop_body:
%gp = getelementptr inbounds i8, ptr @g, i64 %i
%gi = load i64, ptr %gp, align 8
%cmp.not = icmp eq i64 %gi, 0
br i1 %cmp.not, label %if_then, label %loop_exit, !prof !2

if_then:
%i_inc = add i64 %i, 1
br label %loop_header

loop_exit:
ret void
}


define void @func_zero_backage_weight(i64 %n) !prof !0 {
; CHECK-LABEL: define void @func_zero_backage_weight(
; CHECK-SAME: i64 [[N:%.*]]) !prof [[PROF0]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i64 0, [[N]]
; CHECK-NEXT: br i1 [[CMP1]], label %[[LOOP_EXIT:.*]], label %[[LOOP_BODY_LR_PH:.*]], !prof [[PROF1]]
; CHECK: [[LOOP_BODY_LR_PH]]:
; CHECK-NEXT: br label %[[LOOP_BODY:.*]]
; CHECK: [[LOOP_HEADER:.*]]:
; CHECK-NEXT: [[I:%.*]] = phi i64 [ [[I_INC:%.*]], %[[LOOP_BODY]] ]
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i64 [[I]], [[N]]
; CHECK-NEXT: br i1 [[CMP]], label %[[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE:.*]], label %[[LOOP_BODY]], !prof [[PROF4:![0-9]+]]
; CHECK: [[LOOP_BODY]]:
; CHECK-NEXT: [[I2:%.*]] = phi i64 [ 0, %[[LOOP_BODY_LR_PH]] ], [ [[I]], %[[LOOP_HEADER]] ]
; CHECK-NEXT: [[GP:%.*]] = getelementptr inbounds i8, ptr @g, i64 [[I2]]
; CHECK-NEXT: [[GI:%.*]] = load i64, ptr [[GP]], align 8
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i64 [[GI]], 0
; CHECK-NEXT: [[I_INC]] = add i64 [[I2]], 1
; CHECK-NEXT: br i1 [[CMP_NOT]], label %[[LOOP_HEADER]], label %[[LOOP_BODY_LOOP_EXIT_CRIT_EDGE:.*]], !prof [[PROF5:![0-9]+]]
; CHECK: [[LOOP_BODY_LOOP_EXIT_CRIT_EDGE]]:
; CHECK-NEXT: br label %[[LOOP_EXIT]]
; CHECK: [[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE]]:
; CHECK-NEXT: br label %[[LOOP_EXIT]]
; CHECK: [[LOOP_EXIT]]:
; CHECK-NEXT: ret void
;
entry:
br label %loop_header

loop_header:
%i = phi i64 [0, %entry], [%i_inc, %if_then]
%cmp = icmp slt i64 %i, %n
br i1 %cmp, label %loop_exit, label %loop_body, !prof !3

loop_body:
%gp = getelementptr inbounds i8, ptr @g, i64 %i
%gi = load i64, ptr %gp, align 8
%cmp.not = icmp eq i64 %gi, 0
br i1 %cmp.not, label %if_then, label %loop_exit, !prof !4

if_then:
%i_inc = add i64 %i, 1
br label %loop_header

loop_exit:
ret void
}

!0 = !{!"function_entry_count", i64 1000}
!1 = !{!"branch_weights", i32 200, i32 900}
!2 = !{!"branch_weights", i32 100, i32 800}
!3 = !{!"branch_weights", i32 100, i32 900}
!4 = !{!"branch_weights", i32 0, i32 900}
;.
; CHECK: [[PROF0]] = !{!"function_entry_count", i64 1000}
; CHECK: [[PROF1]] = !{!"branch_weights", i32 100, i32 900}
; CHECK: [[PROF2]] = !{!"branch_weights", i32 100, i32 0}
; CHECK: [[PROF3]] = !{!"branch_weights", i32 100, i32 800}
; CHECK: [[PROF4]] = !{!"branch_weights", i32 1, i32 0}
; CHECK: [[PROF5]] = !{!"branch_weights", i32 0, i32 900}
;.