|
1 | 1 | #include "openvic-simulation/types/fixed_point/FixedPoint.hpp" |
2 | 2 | #include "openvic-simulation/core/random/WeightedSampling.hpp" |
3 | 3 |
|
4 | | -#include "core/random/ExtendedMath.hpp" |
5 | | - |
6 | 4 | #include <cstdint> |
7 | 5 | #include <limits> |
| 6 | +#include <boost/int128.hpp> |
8 | 7 |
|
9 | 8 | #include <snitch/snitch_macros_check.hpp> |
10 | 9 | #include <snitch/snitch_macros_test_case.hpp> |
11 | 10 |
|
12 | 11 | using namespace OpenVic; |
13 | | -using namespace OpenVic::testing; |
14 | 12 |
|
15 | 13 | constexpr uint32_t max_random_value = std::numeric_limits<uint32_t>().max(); |
16 | 14 |
|
@@ -46,16 +44,12 @@ TEST_CASE("WeightedSampling weights", "[WeightedSampling]") { |
46 | 44 | fixed_point_t cumulative_weight = 0; |
47 | 45 | for (size_t i = 0; i < weights.size(); ++i) { |
48 | 46 | cumulative_weight += weights[i]; |
49 | | - const Int96DivisionResult random_value = portable_int96_div_int64( |
50 | | - portable_int64_mult_uint32_96bit( |
51 | | - cumulative_weight.get_raw_value(), |
52 | | - max_random_value |
53 | | - ), |
54 | | - weights_sum.get_raw_value() |
55 | | - ); |
56 | | - assert(!random_value.quotient_overflow); |
| 47 | + const boost::int128::int128_t cumulative_weight_128 = cumulative_weight.get_raw_value(); |
| 48 | + const boost::int128::int128_t max_random_value_128 = max_random_value; |
| 49 | + const boost::int128::int128_t weights_sum_128 = weights_sum.get_raw_value(); |
| 50 | + const boost::int128::int128_t random_value = cumulative_weight_128 * max_random_value_128 / weights_sum_128; |
57 | 51 | CHECK(sample_weighted_index( |
58 | | - static_cast<uint32_t>(random_value.quotient), |
| 52 | + static_cast<uint32_t>(random_value), |
59 | 53 | weights, |
60 | 54 | weights_sum |
61 | 55 | ) == i); |
|
0 commit comments