Skip to content

Commit 49b9823

Browse files
committed
[NFC] Generalize the arithmetic type for getDisjunctionWeights
1 parent 17e2641 commit 49b9823

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/ADT/SmallVector.h"
1919
#include "llvm/IR/Metadata.h"
2020
#include "llvm/Support/Compiler.h"
21+
#include <type_traits>
2122

2223
namespace llvm {
2324
struct MDProfLabels {
@@ -216,9 +217,12 @@ LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
216217
/// branch weights B1 and B2, respectively. In both B1 and B2, the first
217218
/// position (index 0) is for the 'true' branch, and the second position (index
218219
/// 1) is for the 'false' branch.
220+
template <typename T1, typename T2,
221+
typename = typename std::enable_if<std::is_arithmetic_v<T1> &&
222+
std::is_arithmetic_v<T2>>>
219223
inline SmallVector<uint64_t, 2>
220-
getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
221-
const SmallVector<uint32_t, 2> &B2) {
224+
getDisjunctionWeights(const SmallVector<T1, 2> &B1,
225+
const SmallVector<T2, 2> &B2) {
222226
// For the first conditional branch, the probability the "true" case is taken
223227
// is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is
224228
// p(not b1) = B1[1] / (B1[0] + B1[1]).
@@ -235,8 +239,8 @@ getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
235239
// the product of sums, the subtracted one cancels out).
236240
assert(B1.size() == 2);
237241
assert(B2.size() == 2);
238-
auto FalseWeight = B1[1] * B2[1];
239-
auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0];
242+
uint64_t FalseWeight = B1[1] * B2[1];
243+
uint64_t TrueWeight = B1[0] * (B2[0] + B2[1]) + B1[1] * B2[0];
240244
return {TrueWeight, FalseWeight};
241245
}
242246
} // namespace llvm

0 commit comments

Comments
 (0)