Skip to content

Commit 967055b

Browse files
committed
[NFC] Generalize the arithmetic type for getDisjunctionWeights
1 parent 17e2641 commit 967055b

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "llvm/ADT/SmallVector.h"
1919
#include "llvm/IR/Metadata.h"
2020
#include "llvm/Support/Compiler.h"
21+
#include <cstddef>
22+
#include <type_traits>
2123

2224
namespace llvm {
2325
struct MDProfLabels {
@@ -216,9 +218,13 @@ LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
216218
/// branch weights B1 and B2, respectively. In both B1 and B2, the first
217219
/// position (index 0) is for the 'true' branch, and the second position (index
218220
/// 1) is for the 'false' branch.
221+
template <typename T1, typename T2,
222+
typename = typename std::enable_if<
223+
std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> &&
224+
sizeof(T1) <= sizeof(uint64_t) && sizeof(T2) <= sizeof(uint64_t)>>
219225
inline SmallVector<uint64_t, 2>
220-
getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
221-
const SmallVector<uint32_t, 2> &B2) {
226+
getDisjunctionWeights(const SmallVector<T1, 2> &B1,
227+
const SmallVector<T2, 2> &B2) {
222228
// For the first conditional branch, the probability the "true" case is taken
223229
// is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is
224230
// p(not b1) = B1[1] / (B1[0] + B1[1]).
@@ -235,8 +241,8 @@ getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
235241
// the product of sums, the subtracted one cancels out).
236242
assert(B1.size() == 2);
237243
assert(B2.size() == 2);
238-
auto FalseWeight = B1[1] * B2[1];
239-
auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0];
244+
uint64_t FalseWeight = B1[1] * B2[1];
245+
uint64_t TrueWeight = B1[0] * (B2[0] + B2[1]) + B1[1] * B2[0];
240246
return {TrueWeight, FalseWeight};
241247
}
242248
} // namespace llvm

0 commit comments

Comments
 (0)