1818#include " llvm/ADT/SmallVector.h"
1919#include " llvm/IR/Metadata.h"
2020#include " llvm/Support/Compiler.h"
21+ #include < cstddef>
22+ #include < type_traits>
2123
2224namespace llvm {
2325struct MDProfLabels {
@@ -216,9 +218,13 @@ LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
216218// / branch weights B1 and B2, respectively. In both B1 and B2, the first
217219// / position (index 0) is for the 'true' branch, and the second position (index
218220// / 1) is for the 'false' branch.
221+ template <typename T1, typename T2,
222+ typename = typename std::enable_if<
223+ std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> &&
224+ sizeof (T1) <= sizeof (uint64_t ) && sizeof (T2) <= sizeof (uint64_t )>>
219225inline SmallVector<uint64_t , 2 >
220- getDisjunctionWeights (const SmallVector<uint32_t , 2 > &B1,
221- const SmallVector<uint32_t , 2 > &B2) {
226+ getDisjunctionWeights (const SmallVector<T1 , 2 > &B1,
227+ const SmallVector<T2 , 2 > &B2) {
222228 // For the first conditional branch, the probability the "true" case is taken
223229 // is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is
224230 // p(not b1) = B1[1] / (B1[0] + B1[1]).
@@ -235,8 +241,8 @@ getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
235241 // the product of sums, the subtracted one cancels out).
236242 assert (B1.size () == 2 );
237243 assert (B2.size () == 2 );
238- auto FalseWeight = B1[1 ] * B2[1 ];
239- auto TrueWeight = B1[0 ] * B2[0 ] + B1[ 0 ] * B2[1 ] + B1[1 ] * B2[0 ];
244+ uint64_t FalseWeight = B1[1 ] * B2[1 ];
245+ uint64_t TrueWeight = B1[0 ] * ( B2[0 ] + B2[1 ]) + B1[1 ] * B2[0 ];
240246 return {TrueWeight, FalseWeight};
241247}
242248} // namespace llvm
0 commit comments