1818#include " llvm/ADT/SmallVector.h"
1919#include " llvm/IR/Metadata.h"
2020#include " llvm/Support/Compiler.h"
21+ #include < type_traits>
2122
2223namespace llvm {
2324struct MDProfLabels {
@@ -216,9 +217,12 @@ LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
216217// / branch weights B1 and B2, respectively. In both B1 and B2, the first
217218// / position (index 0) is for the 'true' branch, and the second position (index
218219// / 1) is for the 'false' branch.
220+ template <typename T1, typename T2,
221+ typename = typename std::enable_if<std::is_arithmetic_v<T1> &&
222+ std::is_arithmetic_v<T2>>>
219223inline SmallVector<uint64_t , 2 >
220- getDisjunctionWeights (const SmallVector<uint32_t , 2 > &B1,
221- const SmallVector<uint32_t , 2 > &B2) {
224+ getDisjunctionWeights (const SmallVector<T1 , 2 > &B1,
225+ const SmallVector<T2 , 2 > &B2) {
222226 // For the first conditional branch, the probability the "true" case is taken
223227 // is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is
224228 // p(not b1) = B1[1] / (B1[0] + B1[1]).
@@ -235,8 +239,8 @@ getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
235239 // the product of sums, the subtracted one cancels out).
236240 assert (B1.size () == 2 );
237241 assert (B2.size () == 2 );
238- auto FalseWeight = B1[1 ] * B2[1 ];
239- auto TrueWeight = B1[0 ] * B2[0 ] + B1[ 0 ] * B2[1 ] + B1[1 ] * B2[0 ];
242+ uint64_t FalseWeight = B1[1 ] * B2[1 ];
243+ uint64_t TrueWeight = B1[0 ] * ( B2[0 ] + B2[1 ]) + B1[1 ] * B2[0 ];
240244 return {TrueWeight, FalseWeight};
241245}
242246} // namespace llvm
0 commit comments