Skip to content

Conversation

@mtrofin
Copy link
Member

@mtrofin mtrofin commented Nov 11, 2025

No description provided.

Copy link
Member Author

mtrofin commented Nov 11, 2025

@mtrofin mtrofin marked this pull request as ready for review November 11, 2025 22:44
@llvmbot
Copy link
Member

llvmbot commented Nov 11, 2025

@llvm/pr-subscribers-llvm-ir

Author: Mircea Trofin (mtrofin)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/167593.diff

1 Files Affected:

  • (modified) llvm/include/llvm/IR/ProfDataUtils.h (+8-4)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index a7bcbf010d1bf..fade7a2dbac2b 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -18,6 +18,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/Support/Compiler.h"
+#include <type_traits>
 
 namespace llvm {
 struct MDProfLabels {
@@ -216,9 +217,12 @@ LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
 /// branch weights B1 and B2, respectively. In both B1 and B2, the first
 /// position (index 0) is for the 'true' branch, and the second position (index
 /// 1) is for the 'false' branch.
+template <typename T1, typename T2,
+          typename = typename std::enable_if<std::is_arithmetic_v<T1> &&
+                                             std::is_arithmetic_v<T2>>>
 inline SmallVector<uint64_t, 2>
-getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
-                      const SmallVector<uint32_t, 2> &B2) {
+getDisjunctionWeights(const SmallVector<T1, 2> &B1,
+                      const SmallVector<T2, 2> &B2) {
   // For the first conditional branch, the probability the "true" case is taken
   // is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is
   // p(not b1) = B1[1] / (B1[0] + B1[1]).
@@ -235,8 +239,8 @@ getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
   // the product of sums, the subtracted one cancels out).
   assert(B1.size() == 2);
   assert(B2.size() == 2);
-  auto FalseWeight = B1[1] * B2[1];
-  auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0];
+  uint64_t FalseWeight = B1[1] * B2[1];
+  uint64_t TrueWeight = B1[0] * (B2[0] + B2[1]) + B1[1] * B2[0];
   return {TrueWeight, FalseWeight};
 }
 } // namespace llvm

assert(B2.size() == 2);
auto FalseWeight = B1[1] * B2[1];
auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0];
uint64_t FalseWeight = B1[1] * B2[1];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the motivation for changing this? Previously it would default to the integer type, but now it's fixed.

Shouldn't be a big issue for uint64_t specifically, but maybe a good idea to add a static assert that the bit width of T1/T2 is less than or equal to 64?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to avoid overflow. I missed the fact I should have declared it uint64_t explicitly the first time around.

added static asserts.

@mtrofin mtrofin force-pushed the users/mtrofin/11-11-_nfc_generalize_the_arithmetic_type_for_getdisjunctionweights_ branch from 49b9823 to 967055b Compare November 12, 2025 01:02
Copy link
Member Author

mtrofin commented Nov 12, 2025

Merge activity

  • Nov 12, 1:43 AM UTC: A user started a stack merge that includes this pull request via Graphite.
  • Nov 12, 1:45 AM UTC: @mtrofin merged this pull request with Graphite.

@mtrofin mtrofin merged commit a863fd8 into main Nov 12, 2025
7 of 9 checks passed
@mtrofin mtrofin deleted the users/mtrofin/11-11-_nfc_generalize_the_arithmetic_type_for_getdisjunctionweights_ branch November 12, 2025 01:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants