Skip to content

Conversation

scui-ibm
Copy link
Contributor

This patch modifies the function updateBranchWeights to account for loops with multiple exits.

Also, the current heuristic for calculating the new weights of the two exits scales the original weights to match the fixed ZeroTripCountWeights ratio. To avoid scaling the original weights, the weights of the new exits are instead calculated to be proportional to the weights of the corresponding basic blocks.

@scui-ibm scui-ibm self-assigned this Oct 17, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 17, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Shimin Cui (scui-ibm)

Changes

This patch modifies the function updateBranchWeights to account for loops with multiple exits.

Also, the current heuristic for calculating the new weights of the two exits scales the original weights to match the fixed ZeroTripCountWeights ratio. To avoid scaling the original weights, the weights of the new exits are instead calculated to be proportional to the weights of the corresponding basic blocks.


Full diff: https://github.com/llvm/llvm-project/pull/164006.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Utils/LoopRotationUtils.cpp (+66-34)
  • (modified) llvm/test/Transforms/LoopRotate/update-branch-weights.ll (+4-4)
  • (added) llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights (+115)
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 0c8d6fa47b9ae..52ae973935c88 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -46,9 +46,6 @@ STATISTIC(NumInstrsHoisted,
 STATISTIC(NumInstrsDuplicated,
           "Number of instructions cloned into loop preheader");
 
-// Probability that a rotated loop has zero trip count / is never entered.
-static constexpr uint32_t ZeroTripCountWeights[] = {1, 127};
-
 namespace {
 /// A simple loop rotation transformation.
 class LoopRotate {
@@ -200,7 +197,8 @@ static bool profitableToRotateLoopExitingLatch(Loop *L) {
   return false;
 }
 
-static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
+static void updateBranchWeights(Loop *L, BranchInst &PreHeaderBI,
+                                BranchInst &LoopBI,
                                 bool HasConditionalPreHeader,
                                 bool SuccsSwapped) {
   MDNode *WeightMD = getBranchWeightMDNode(PreHeaderBI);
@@ -218,12 +216,49 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
   if (Weights.size() != 2)
     return;
   uint32_t OrigLoopExitWeight = Weights[0];
-  uint32_t OrigLoopBackedgeWeight = Weights[1];
+  uint32_t OrigLoopEnterWeight = Weights[1];
 
   if (SuccsSwapped)
-    std::swap(OrigLoopExitWeight, OrigLoopBackedgeWeight);
+    std::swap(OrigLoopExitWeight, OrigLoopEnterWeight);
+
+  // For a multiple-exit loop, find the total weight of other exits.
+  uint32_t OtherLoopExitWeight = 0;
+  SmallVector<BasicBlock *, 16> ExitingBlocks;
+  L->getExitingBlocks(ExitingBlocks);
+  for (BasicBlock *ExitingBB : ExitingBlocks) {
+    Instruction *TI = ExitingBB->getTerminator();
+    if (TI == &LoopBI)
+      continue;
+
+    if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) ||
+          isa<IndirectBrInst>(TI) || isa<InvokeInst>(TI) ||
+          isa<CallBrInst>(TI)))
+      continue;
+
+    MDNode *WeightsNode = getValidBranchWeightMDNode(*TI);
+    if (!WeightsNode)
+      continue;
+
+    SmallVector<uint32_t, 2> Weights;
+    extractBranchWeights(WeightsNode, Weights);
+    for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
+      BasicBlock *Exit = TI->getSuccessor(I);
+      if (L->contains(Exit))
+        continue;
+
+      OtherLoopExitWeight += Weights[I];
+    }
+  }
 
-  // Update branch weights. Consider the following edge-counts:
+  // Adjust OtherLoopExitWeight as it should not be larger than the loop enter
+  // weight.
+  if (OtherLoopExitWeight > OrigLoopEnterWeight)
+    OtherLoopExitWeight = OrigLoopEnterWeight;
+
+  uint32_t OrigLoopBackedgeWeight = OrigLoopEnterWeight - OtherLoopExitWeight;
+
+  // Update branch weights. Consider the following edge-counts (z for multiple
+  // exit loop):
   //
   //    |  |--------             |
   //    V  V       |             V
@@ -231,16 +266,18 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
   //   |       |   |            |     |
   //  x|      y|   |  becomes:  |   y0|  |-----
   //   V       V   |            |     V  V    |
-  // Exit    Loop  |            |    Loop     |
-  //           |   |            |   Br i1 ... |
+  // Exit <- Loop  |            |    Loop     |
+  //      z    |   |            |   Br i1 ... |
   //           -----            |   |      |  |
   //                          x0| x1|   y1 |  |
   //                            V   V      ----
-  //                            Exit
+  //                            Exit  <----|
+  //                                    z
   //
   // The following must hold:
   //  -  x == x0 + x1        # counts to "exit" must stay the same.
-  //  - y0 == x - x0 == x1   # how often loop was entered at all.
+  //  - y0 == x - x0 + z     # how often loop was entered at all.
+  //       == x1 + z
   //  - y1 == y - y0         # How often loop was repeated (after first iter.).
   //
   // We cannot generally deduce how often we had a zero-trip count loop so we
@@ -255,19 +292,12 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
     if (HasConditionalPreHeader) {
       // Here we cannot know how many 0-trip count loops we have, so we guess:
       if (OrigLoopBackedgeWeight >= OrigLoopExitWeight) {
-        // If the loop count is bigger than the exit count then we set
-        // probabilities as if 0-trip count nearly never happens.
-        ExitWeight0 = ZeroTripCountWeights[0];
-        // Scale up counts if necessary so we can match `ZeroTripCountWeights`
-        // for the `ExitWeight0`:`ExitWeight1` (aka `x0`:`x1` ratio`) ratio.
-        while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) {
-          // ... but don't overflow.
-          uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1);
-          if ((OrigLoopBackedgeWeight & HighBit) != 0 ||
-              (OrigLoopExitWeight & HighBit) != 0)
-            break;
-          OrigLoopBackedgeWeight <<= 1;
-          OrigLoopExitWeight <<= 1;
+        ExitWeight0 =
+            (OrigLoopExitWeight * (OrigLoopExitWeight + OtherLoopExitWeight)) /
+            (OrigLoopExitWeight + OrigLoopEnterWeight);
+        // Minimum ExitWeight0 1
+        if (ExitWeight0 == 0) {
+          ExitWeight0 = 1;
         }
       } else {
         // If there's a higher exit-count than backedge-count then we set
@@ -280,36 +310,38 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
       // weight collected by sampling-based PGO may be not very accurate due to
       // sampling. Therefore this workaround is required here to avoid underflow
       // of unsigned in following update of branch weight.
-      if (OrigLoopExitWeight > OrigLoopBackedgeWeight)
+      if (OrigLoopExitWeight > OrigLoopBackedgeWeight) {
         OrigLoopBackedgeWeight = OrigLoopExitWeight;
+        OrigLoopEnterWeight = OrigLoopBackedgeWeight + OtherLoopExitWeight;
+      }
     }
     assert(OrigLoopExitWeight >= ExitWeight0 && "Bad branch weight");
     ExitWeight1 = OrigLoopExitWeight - ExitWeight0;
-    EnterWeight = ExitWeight1;
-    assert(OrigLoopBackedgeWeight >= EnterWeight && "Bad branch weight");
-    LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight;
+    EnterWeight = ExitWeight1 + OtherLoopExitWeight;
+    assert(OrigLoopEnterWeight >= EnterWeight && "Bad branch weight");
+    LoopBackWeight = OrigLoopEnterWeight - EnterWeight;
   } else if (OrigLoopExitWeight == 0) {
     if (OrigLoopBackedgeWeight == 0) {
       // degenerate case... keep everything zero...
       ExitWeight0 = 0;
       ExitWeight1 = 0;
-      EnterWeight = 0;
+      EnterWeight = OtherLoopExitWeight;
       LoopBackWeight = 0;
     } else {
       // Special case "LoopExitWeight == 0" weights which behaves like an
-      // endless where we don't want loop-enttry (y0) to be the same as
+      // endless where we don't want loop-entry (y0) to be the same as
       // loop-exit (x1).
       ExitWeight0 = 0;
       ExitWeight1 = 0;
-      EnterWeight = 1;
+      EnterWeight = (OtherLoopExitWeight != 0) ? OtherLoopExitWeight : 1;
       LoopBackWeight = OrigLoopBackedgeWeight;
     }
   } else {
     // loop is never entered.
     assert(OrigLoopBackedgeWeight == 0 && "remaining case is backedge zero");
-    ExitWeight0 = 1;
+    ExitWeight0 = OrigLoopExitWeight;
     ExitWeight1 = 1;
-    EnterWeight = 0;
+    EnterWeight = OtherLoopExitWeight;
     LoopBackWeight = 0;
   }
 
@@ -748,7 +780,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
       !isa<ConstantInt>(Cond) ||
       PHBI->getSuccessor(cast<ConstantInt>(Cond)->isZero()) != NewHeader;
 
-  updateBranchWeights(*PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped);
+  updateBranchWeights(L, *PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped);
 
   if (HasConditionalPreHeader) {
     // The conditional branch can't be folded, handle the general case.
diff --git a/llvm/test/Transforms/LoopRotate/update-branch-weights.ll b/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
index 9a1f36ec5ff2b..77157a1f45e8a 100644
--- a/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
+++ b/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
@@ -70,9 +70,9 @@ outer_loop_exit:
 
 ; BFI_AFTER-LABEL: block-frequency-info: func1
 ; BFI_AFTER: - entry: {{.*}} count = 1024
-; BFI_AFTER: - loop_body.lr.ph: {{.*}} count = 1016
+; BFI_AFTER: - loop_body.lr.ph: {{.*}} count = 512
 ; BFI_AFTER: - loop_body: {{.*}} count = 20480
-; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 1016
+; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 512
 ; BFI_AFTER: - loop_exit: {{.*}} count = 1024
 
 ; IR-LABEL: define void @func1
@@ -285,8 +285,8 @@ loop_exit:
 
 ; IR: [[PROF_FUNC0_0]] = !{!"branch_weights", i32 2000, i32 1000}
 ; IR: [[PROF_FUNC0_1]] = !{!"branch_weights", i32 999, i32 1}
-; IR: [[PROF_FUNC1_0]] = !{!"branch_weights", i32 127, i32 1}
-; IR: [[PROF_FUNC1_1]] = !{!"branch_weights", i32 2433, i32 127}
+; IR: [[PROF_FUNC1_0]] = !{!"branch_weights", i32 1, i32 1}
+; IR: [[PROF_FUNC1_1]] = !{!"branch_weights", i32 39, i32 1}
 ; IR: [[PROF_FUNC2_0]] = !{!"branch_weights", i32 9920, i32 320}
 ; IR: [[PROF_FUNC2_1]] = !{!"branch_weights", i32 320, i32 0}
 ; IR: [[PROF_FUNC3_0]] = !{!"branch_weights", i32 0, i32 1}
diff --git a/llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights b/llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights
new file mode 100644
index 0000000000000..1a62a47d65810
--- /dev/null
+++ b/llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights
@@ -0,0 +1,115 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt < %s -passes='loop(loop-rotate)' -S | FileCheck %s
+
+@g = global i64 0
+
+define void @func_branch_weight(i64 %n) !prof !0 {
+; CHECK-LABEL: define void @func_branch_weight(
+; CHECK-SAME: i64 [[N:%.*]]) !prof [[PROF0:![0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i64 0, [[N]]
+; CHECK-NEXT:    br i1 [[CMP1]], label %[[LOOP_EXIT:.*]], label %[[LOOP_BODY_LR_PH:.*]], !prof [[PROF1:![0-9]+]]
+; CHECK:       [[LOOP_BODY_LR_PH]]:
+; CHECK-NEXT:    br label %[[LOOP_BODY:.*]]
+; CHECK:       [[LOOP_HEADER:.*]]:
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_INC:%.*]], %[[LOOP_BODY]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i64 [[I]], [[N]]
+; CHECK-NEXT:    br i1 [[CMP]], label %[[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE:.*]], label %[[LOOP_BODY]], !prof [[PROF2:![0-9]+]]
+; CHECK:       [[LOOP_BODY]]:
+; CHECK-NEXT:    [[I2:%.*]] = phi i64 [ 0, %[[LOOP_BODY_LR_PH]] ], [ [[I]], %[[LOOP_HEADER]] ]
+; CHECK-NEXT:    [[GP:%.*]] = getelementptr inbounds i8, ptr @g, i64 [[I2]]
+; CHECK-NEXT:    [[GI:%.*]] = load i64, ptr [[GP]], align 8
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i64 [[GI]], 0
+; CHECK-NEXT:    [[I_INC]] = add i64 [[I2]], 1
+; CHECK-NEXT:    br i1 [[CMP_NOT]], label %[[LOOP_HEADER]], label %[[LOOP_BODY_LOOP_EXIT_CRIT_EDGE:.*]], !prof [[PROF3:![0-9]+]]
+; CHECK:       [[LOOP_BODY_LOOP_EXIT_CRIT_EDGE]]:
+; CHECK-NEXT:    br label %[[LOOP_EXIT]]
+; CHECK:       [[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE]]:
+; CHECK-NEXT:    br label %[[LOOP_EXIT]]
+; CHECK:       [[LOOP_EXIT]]:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop_header
+
+loop_header:
+  %i = phi i64 [0, %entry], [%i_inc, %if_then]
+  %cmp = icmp slt i64 %i, %n
+  br i1 %cmp, label %loop_exit, label %loop_body, !prof !1
+
+loop_body:
+  %gp = getelementptr inbounds i8, ptr @g, i64 %i
+  %gi = load i64, ptr %gp, align 8
+  %cmp.not = icmp eq i64 %gi, 0
+  br i1 %cmp.not, label %if_then, label %loop_exit, !prof !2
+
+if_then:
+  %i_inc = add i64 %i, 1
+  br label %loop_header
+
+loop_exit:
+  ret void
+}
+
+
+define void @func_zero_backage_weight(i64 %n) !prof !0 {
+; CHECK-LABEL: define void @func_zero_backage_weight(
+; CHECK-SAME: i64 [[N:%.*]]) !prof [[PROF0]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i64 0, [[N]]
+; CHECK-NEXT:    br i1 [[CMP1]], label %[[LOOP_EXIT:.*]], label %[[LOOP_BODY_LR_PH:.*]], !prof [[PROF1]]
+; CHECK:       [[LOOP_BODY_LR_PH]]:
+; CHECK-NEXT:    br label %[[LOOP_BODY:.*]]
+; CHECK:       [[LOOP_HEADER:.*]]:
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_INC:%.*]], %[[LOOP_BODY]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i64 [[I]], [[N]]
+; CHECK-NEXT:    br i1 [[CMP]], label %[[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE:.*]], label %[[LOOP_BODY]], !prof [[PROF4:![0-9]+]]
+; CHECK:       [[LOOP_BODY]]:
+; CHECK-NEXT:    [[I2:%.*]] = phi i64 [ 0, %[[LOOP_BODY_LR_PH]] ], [ [[I]], %[[LOOP_HEADER]] ]
+; CHECK-NEXT:    [[GP:%.*]] = getelementptr inbounds i8, ptr @g, i64 [[I2]]
+; CHECK-NEXT:    [[GI:%.*]] = load i64, ptr [[GP]], align 8
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i64 [[GI]], 0
+; CHECK-NEXT:    [[I_INC]] = add i64 [[I2]], 1
+; CHECK-NEXT:    br i1 [[CMP_NOT]], label %[[LOOP_HEADER]], label %[[LOOP_BODY_LOOP_EXIT_CRIT_EDGE:.*]], !prof [[PROF5:![0-9]+]]
+; CHECK:       [[LOOP_BODY_LOOP_EXIT_CRIT_EDGE]]:
+; CHECK-NEXT:    br label %[[LOOP_EXIT]]
+; CHECK:       [[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE]]:
+; CHECK-NEXT:    br label %[[LOOP_EXIT]]
+; CHECK:       [[LOOP_EXIT]]:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop_header
+
+loop_header:
+  %i = phi i64 [0, %entry], [%i_inc, %if_then]
+  %cmp = icmp slt i64 %i, %n
+  br i1 %cmp, label %loop_exit, label %loop_body, !prof !3
+
+loop_body:
+  %gp = getelementptr inbounds i8, ptr @g, i64 %i
+  %gi = load i64, ptr %gp, align 8
+  %cmp.not = icmp eq i64 %gi, 0
+  br i1 %cmp.not, label %if_then, label %loop_exit, !prof !4
+
+if_then:
+  %i_inc = add i64 %i, 1
+  br label %loop_header
+
+loop_exit:
+  ret void
+}
+
+!0 = !{!"function_entry_count", i64 1000}
+!1 = !{!"branch_weights", i32 200, i32 900}
+!2 = !{!"branch_weights", i32 100, i32 800}
+!3 = !{!"branch_weights", i32 100, i32 900}
+!4 = !{!"branch_weights", i32 0, i32 900}
+;.
+; CHECK: [[PROF0]] = !{!"function_entry_count", i64 1000}
+; CHECK: [[PROF1]] = !{!"branch_weights", i32 100, i32 900}
+; CHECK: [[PROF2]] = !{!"branch_weights", i32 100, i32 0}
+; CHECK: [[PROF3]] = !{!"branch_weights", i32 100, i32 800}
+; CHECK: [[PROF4]] = !{!"branch_weights", i32 1, i32 0}
+; CHECK: [[PROF5]] = !{!"branch_weights", i32 0, i32 900}
+;.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants