diff --git a/contracts/feature-tests/basic-features/scenarios/math_features.scen.json b/contracts/feature-tests/basic-features/scenarios/math_features.scen.json new file mode 100644 index 0000000000..f34a307252 --- /dev/null +++ b/contracts/feature-tests/basic-features/scenarios/math_features.scen.json @@ -0,0 +1,192 @@ +{ + "name": "math features", + "steps": [ + { + "step": "setState", + "accounts": { + "sc:basic-features": { + "nonce": "0", + "balance": "0", + "code": "mxsc:../output/basic-features.mxsc.json" + }, + "address:an_account": { + "nonce": "0", + "balance": "0" + } + } + }, + { + "step": "scCall", + "id": "weighted_average_equal_weights", + "comment": "(10*1 + 20*1) / (1+1) = 15", + "tx": { + "from": "address:an_account", + "to": "sc:basic-features", + "function": "math_weighted_average", + "arguments": ["10", "1", "20", "1"], + "gasLimit": "50,000,000", + "gasPrice": "0" + }, + "expect": { + "out": ["15"], + "status": "0" + } + }, + { + "step": "scCall", + "id": "weighted_average_unequal_weights", + "comment": "(0*3 + 30*7) / (3+7) = 21", + "tx": { + "from": "address:an_account", + "to": "sc:basic-features", + "function": "math_weighted_average", + "arguments": ["0", "3", "30", "7"], + "gasLimit": "50,000,000", + "gasPrice": "0" + }, + "expect": { + "out": ["21"], + "status": "0" + } + }, + { + "step": "scCall", + "id": "weighted_average_truncates", + "comment": "(0*1 + 10*3) / (1+3) = 30/4 = 7 (truncated)", + "tx": { + "from": "address:an_account", + "to": "sc:basic-features", + "function": "math_weighted_average", + "arguments": ["0", "1", "10", "3"], + "gasLimit": "50,000,000", + "gasPrice": "0" + }, + "expect": { + "out": ["7"], + "status": "0" + } + }, + { + "step": "scCall", + "id": "weighted_average_round_up_equal_weights", + "comment": "(10*1 + 20*1) / (1+1) = 15 (exact, no rounding)", + "tx": { + "from": "address:an_account", + "to": "sc:basic-features", + "function": "math_weighted_average_round_up", + "arguments": ["10", "1", "20", "1"], + "gasLimit": "50,000,000", + "gasPrice": "0" + }, + "expect": { + "out": ["15"], + "status": "0" + } + }, + { + "step": "scCall", + "id": "weighted_average_round_up_rounds_up", + "comment": "(0*1 + 10*3 + 4 - 1) / (1+3) = 33/4 = 8 (rounded up from 7.5)", + "tx": { + "from": "address:an_account", + "to": "sc:basic-features", + "function": "math_weighted_average_round_up", + "arguments": ["0", "1", "10", "3"], + "gasLimit": "50,000,000", + "gasPrice": "0" + }, + "expect": { + "out": ["8"], + "status": "0" + } + }, + { + "step": "scCall", + "id": "weighted_average_round_up_exact", + "comment": "(1*1 + 4*2) / (1+2) = 9/3 = 3", + "tx": { + "from": "address:an_account", + "to": "sc:basic-features", + "function": "math_weighted_average_round_up", + "arguments": ["1", "1", "4", "2"], + "gasLimit": "50,000,000", + "gasPrice": "0" + }, + "expect": { + "out": ["3"], + "status": "0" + } + }, + { + "step": "scCall", + "id": "linear_interpolation_midpoint", + "comment": "(0*(100-50) + 200*(50-0)) / (100-0) = 10000/100 = 100", + "tx": { + "from": "address:an_account", + "to": "sc:basic-features", + "function": "math_linear_interpolation", + "arguments": ["0", "100", "50", "0", "200"], + "gasLimit": "50,000,000", + "gasPrice": "0" + }, + "expect": { + "out": ["100"], + "status": "0" + } + }, + { + "step": "scCall", + "id": "linear_interpolation_at_min", + "comment": "current_in == min_in => returns min_out = 5", + "tx": { + "from": "address:an_account", + "to": "sc:basic-features", + "function": "math_linear_interpolation", + "arguments": ["0", "10", "0", "5", "15"], + "gasLimit": "50,000,000", + "gasPrice": "0" + }, + "expect": { + "out": ["5"], + "status": "0" + } + }, + { + "step": "scCall", + "id": "linear_interpolation_at_max", + "comment": "current_in == max_in => returns max_out = 15", + "tx": { + "from": "address:an_account", + "to": "sc:basic-features", + "function": "math_linear_interpolation", + "arguments": ["0", "10", "10", "5", "15"], + "gasLimit": "50,000,000", + "gasPrice": "0" + }, + "expect": { + "out": ["15"], + "status": "0" + } + }, + { + "step": "scCall", + "id": "linear_interpolation_out_of_range", + "comment": "current_in > max_in => sc_panic", + "tx": { + "from": "address:an_account", + "to": "sc:basic-features", + "function": "math_linear_interpolation", + "arguments": ["0", "10", "11", "5", "15"], + "gasLimit": "50,000,000", + "gasPrice": "0" + }, + "expect": { + "out": [], + "status": "4", + "message": "str:current_in out of [min_in, max_in] range", + "gas": "*", + "refund": "*" + } + } + ] +} diff --git a/contracts/feature-tests/basic-features/src/basic_features_main.rs b/contracts/feature-tests/basic-features/src/basic_features_main.rs index 8a4f6b4276..f491b8b91b 100644 --- a/contracts/feature-tests/basic-features/src/basic_features_main.rs +++ b/contracts/feature-tests/basic-features/src/basic_features_main.rs @@ -20,6 +20,7 @@ pub mod managed_buffer_features; pub mod managed_decimal_features; pub mod managed_map_features; pub mod managed_vec_features; +pub mod math_features; pub mod non_zero_features; pub mod small_num_overflow_test_ops; pub mod special_roles_from_system_account; @@ -90,6 +91,7 @@ pub trait BasicFeatures: + storage_mapper_get_at_address::StorageMapperGetAtAddress + managed_decimal_features::ManagedDecimalFeatures + managed_map_features::ManagedMapFeatures + + math_features::MathFeatures { #[init] fn init(&self) {} diff --git a/contracts/feature-tests/basic-features/src/math_features.rs b/contracts/feature-tests/basic-features/src/math_features.rs new file mode 100644 index 0000000000..1e4a366eb2 --- /dev/null +++ b/contracts/feature-tests/basic-features/src/math_features.rs @@ -0,0 +1,40 @@ +use multiversx_sc::imports::*; +use multiversx_sc::math; + +#[multiversx_sc::module] +pub trait MathFeatures { + #[endpoint] + fn math_weighted_average( + &self, + first_value: BigUint, + first_weight: BigUint, + second_value: BigUint, + second_weight: BigUint, + ) -> BigUint { + math::weighted_average(first_value, first_weight, second_value, second_weight) + } + + #[endpoint] + fn math_weighted_average_round_up( + &self, + first_value: BigUint, + first_weight: BigUint, + second_value: BigUint, + second_weight: BigUint, + ) -> BigUint { + math::weighted_average_round_up(first_value, first_weight, second_value, second_weight) + } + + #[endpoint] + fn math_linear_interpolation( + &self, + min_in: BigUint, + max_in: BigUint, + current_in: BigUint, + min_out: BigUint, + max_out: BigUint, + ) -> BigUint { + math::linear_interpolation(min_in, max_in, current_in, min_out, max_out) + .unwrap_or_else(|_| sc_panic!("current_in out of [min_in, max_in] range")) + } +} diff --git a/contracts/feature-tests/basic-features/tests/basic_features_scenario_go_test.rs b/contracts/feature-tests/basic-features/tests/basic_features_scenario_go_test.rs index 163f2010c1..e1a3b5e97f 100644 --- a/contracts/feature-tests/basic-features/tests/basic_features_scenario_go_test.rs +++ b/contracts/feature-tests/basic-features/tests/basic_features_scenario_go_test.rs @@ -341,6 +341,11 @@ fn managed_vec_biguint_push_go() { world().run("scenarios/managed_vec_biguint_push.scen.json"); } +#[test] +fn math_features_go() { + world().run("scenarios/math_features.scen.json"); +} + #[test] fn mmap_get_go() { world().run("scenarios/mmap_get.scen.json"); diff --git a/contracts/feature-tests/basic-features/tests/basic_features_scenario_rs_test.rs b/contracts/feature-tests/basic-features/tests/basic_features_scenario_rs_test.rs index 1be0dcc852..8ca1c07341 100644 --- a/contracts/feature-tests/basic-features/tests/basic_features_scenario_rs_test.rs +++ b/contracts/feature-tests/basic-features/tests/basic_features_scenario_rs_test.rs @@ -363,6 +363,11 @@ fn managed_vec_biguint_push_rs() { world().run("scenarios/managed_vec_biguint_push.scen.json"); } +#[test] +fn math_features_rs() { + world().run("scenarios/math_features.scen.json"); +} + #[test] fn mmap_get_rs() { world().run("scenarios/mmap_get.scen.json"); diff --git a/framework/base/src/lib.rs b/framework/base/src/lib.rs index 963f56b8f2..bd63d4dacd 100644 --- a/framework/base/src/lib.rs +++ b/framework/base/src/lib.rs @@ -27,6 +27,7 @@ pub mod hex_call_data; pub mod io; pub mod log_util; mod macros; +pub mod math; pub mod non_zero_util; pub mod storage; pub mod tuple_util; diff --git a/framework/base/src/math.rs b/framework/base/src/math.rs new file mode 100644 index 0000000000..391f88d24c --- /dev/null +++ b/framework/base/src/math.rs @@ -0,0 +1,7 @@ +/// Only used internally for computing logarithms for ManagedDecimal and BigUint. +pub(crate) mod internal_logarithm_i64; +mod linear_interpolation; +mod weighted_average; + +pub use linear_interpolation::{LinearInterpolationInvalidValuesError, linear_interpolation}; +pub use weighted_average::{weighted_average, weighted_average_round_up}; diff --git a/framework/base/src/types/math_util/logarithm_i64.rs b/framework/base/src/math/internal_logarithm_i64.rs similarity index 100% rename from framework/base/src/types/math_util/logarithm_i64.rs rename to framework/base/src/math/internal_logarithm_i64.rs diff --git a/framework/base/src/math/linear_interpolation.rs b/framework/base/src/math/linear_interpolation.rs new file mode 100644 index 0000000000..638c4fa150 --- /dev/null +++ b/framework/base/src/math/linear_interpolation.rs @@ -0,0 +1,40 @@ +use core::ops::{Add, Div, Mul, Sub}; + +/// Error returned when `current_in` is outside the `[min_in, max_in]` range. +#[derive(Debug)] +pub struct LinearInterpolationInvalidValuesError; + +/// Computes a linearly interpolated output value for a given input within a known range. +/// +/// Given an input range `[min_in, max_in]` and a corresponding output range `[min_out, max_out]`, +/// maps `current_in` proportionally to its position in the output range. +/// +/// Formula: +/// ```text +/// out = (min_out * (max_in - current_in) + max_out * (current_in - min_in)) / (max_in - min_in) +/// ``` +/// +/// Returns [`Err(LinearInterpolationInvalidValuesError)`] if `current_in` is outside `[min_in, max_in]`. +/// +/// See also: +pub fn linear_interpolation( + min_in: T, + max_in: T, + current_in: T, + min_out: T, + max_out: T, +) -> Result +where + T: Add + Sub + Mul + Div + PartialOrd + Clone, +{ + if min_in > max_in || current_in < min_in || current_in > max_in { + return Err(LinearInterpolationInvalidValuesError); + } + + let min_out_weighted = min_out * (max_in.clone() - current_in.clone()); + let max_out_weighted = max_out * (current_in - min_in.clone()); + let in_diff = max_in - min_in; + + let result = (min_out_weighted + max_out_weighted) / in_diff; + Ok(result) +} diff --git a/framework/base/src/math/weighted_average.rs b/framework/base/src/math/weighted_average.rs new file mode 100644 index 0000000000..87f54612c0 --- /dev/null +++ b/framework/base/src/math/weighted_average.rs @@ -0,0 +1,39 @@ +use core::ops::{Add, Div, Mul, Sub}; + +/// Computes the weighted average of two values. +/// +/// Returns `(first_value * first_weight + second_value * second_weight) / (first_weight + second_weight)`. +/// +/// # Panics +/// +/// Panics on division by zero if both weights are zero. +pub fn weighted_average(first_value: T, first_weight: T, second_value: T, second_weight: T) -> T +where + T: Add + Mul + Div + Clone, +{ + let weight_sum = first_weight.clone() + second_weight.clone(); + let weighted_sum = first_value * first_weight + second_value * second_weight; + weighted_sum / weight_sum +} + +/// Computes the weighted average of two values, rounded up (ceiling division). +/// +/// Equivalent to [`weighted_average`], but rounds the result up instead of truncating: +/// `(weighted_sum + weight_sum - 1) / weight_sum`. +/// +/// # Panics +/// +/// Panics on division by zero if both weights are zero. +pub fn weighted_average_round_up( + first_value: T, + first_weight: T, + second_value: T, + second_weight: T, +) -> T +where + T: Add + Sub + Mul + Div + Clone + From, +{ + let weight_sum = first_weight.clone() + second_weight.clone(); + let weighted_sum = first_value * first_weight + second_value * second_weight; + (weighted_sum + weight_sum.clone() - T::from(1u32)) / weight_sum +} diff --git a/framework/base/src/types.rs b/framework/base/src/types.rs index 05cd457c58..1548873988 100644 --- a/framework/base/src/types.rs +++ b/framework/base/src/types.rs @@ -3,7 +3,6 @@ pub mod heap; mod interaction; mod io; mod managed; -pub(crate) mod math_util; mod static_buffer; pub use crypto::*; diff --git a/framework/base/src/types/managed/wrapped/managed_decimal/managed_decimal_logarithm.rs b/framework/base/src/types/managed/wrapped/managed_decimal/managed_decimal_logarithm.rs index a363530af3..8cb56a6171 100644 --- a/framework/base/src/types/managed/wrapped/managed_decimal/managed_decimal_logarithm.rs +++ b/framework/base/src/types/managed/wrapped/managed_decimal/managed_decimal_logarithm.rs @@ -28,12 +28,12 @@ fn compute_ln( .unwrap_or_else(|| ErrorHelper::::signal_error_with_message("ln internal error")) as i64; - let mut result = crate::types::math_util::logarithm_i64::ln_polynomial(x); - crate::types::math_util::logarithm_i64::ln_add_bit_log2(&mut result, log2_floor); + let mut result = crate::math::internal_logarithm_i64::ln_polynomial(x); + crate::math::internal_logarithm_i64::ln_add_bit_log2(&mut result, log2_floor); debug_assert!(result > 0); - crate::types::math_util::logarithm_i64::ln_sub_decimals(&mut result, num_decimals); + crate::math::internal_logarithm_i64::ln_sub_decimals(&mut result, num_decimals); Some(ManagedDecimalSigned::from_raw_units( BigInt::from(result), @@ -60,12 +60,12 @@ fn compute_log2( .unwrap_or_else(|| ErrorHelper::::signal_error_with_message("log2 internal error")) as i64; - let mut result = crate::types::math_util::logarithm_i64::log2_polynomial(x); - crate::types::math_util::logarithm_i64::log2_add_bit_log2(&mut result, log2_floor); + let mut result = crate::math::internal_logarithm_i64::log2_polynomial(x); + crate::math::internal_logarithm_i64::log2_add_bit_log2(&mut result, log2_floor); debug_assert!(result > 0); - crate::types::math_util::logarithm_i64::log2_sub_decimals(&mut result, num_decimals); + crate::math::internal_logarithm_i64::log2_sub_decimals(&mut result, num_decimals); Some(ManagedDecimalSigned::from_raw_units( BigInt::from(result), diff --git a/framework/base/src/types/managed/wrapped/num/big_uint.rs b/framework/base/src/types/managed/wrapped/num/big_uint.rs index 35de36bf60..419e0feb3e 100644 --- a/framework/base/src/types/managed/wrapped/num/big_uint.rs +++ b/framework/base/src/types/managed/wrapped/num/big_uint.rs @@ -431,8 +431,8 @@ impl BigUint { .unwrap_or_else(|| ErrorHelper::::signal_error_with_message("ln internal error")) as i64; - let mut result = crate::types::math_util::logarithm_i64::ln_polynomial(x); - crate::types::math_util::logarithm_i64::ln_add_bit_log2(&mut result, log2_floor); + let mut result = crate::math::internal_logarithm_i64::ln_polynomial(x); + crate::math::internal_logarithm_i64::ln_add_bit_log2(&mut result, log2_floor); debug_assert!(result > 0); diff --git a/framework/base/src/types/math_util.rs b/framework/base/src/types/math_util.rs deleted file mode 100644 index 9322ee93ac..0000000000 --- a/framework/base/src/types/math_util.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod logarithm_i64; diff --git a/framework/base/tests/math_test.rs b/framework/base/tests/math_test.rs new file mode 100644 index 0000000000..6279aa2bb8 --- /dev/null +++ b/framework/base/tests/math_test.rs @@ -0,0 +1,117 @@ +use multiversx_sc::math::{ + LinearInterpolationInvalidValuesError, linear_interpolation, weighted_average, + weighted_average_round_up, +}; + +// ---- linear_interpolation ---- + +#[test] +fn linear_interpolation_at_min_input() { + // current_in == min_in => output == min_out + let result = linear_interpolation(0u32, 100u32, 0u32, 200u32, 400u32).unwrap(); + assert_eq!(result, 200u32); +} + +#[test] +fn linear_interpolation_at_max_input() { + // current_in == max_in => output == max_out + let result = linear_interpolation(0u32, 100u32, 100u32, 200u32, 400u32).unwrap(); + assert_eq!(result, 400u32); +} + +#[test] +fn linear_interpolation_at_midpoint() { + // current_in at the midpoint => output at midpoint of output range + let result = linear_interpolation(0u32, 100u32, 50u32, 0u32, 1000u32).unwrap(); + assert_eq!(result, 500u32); +} + +#[test] +fn linear_interpolation_at_one_quarter() { + // current_in at 25% => output at 25% of output range + let result = linear_interpolation(0u32, 100u32, 25u32, 0u32, 1000u32).unwrap(); + assert_eq!(result, 250u32); +} + +#[test] +fn linear_interpolation_non_zero_based_ranges() { + // Input range [10, 50], output range [100, 200], current_in = 30 (50% through input) + let result = linear_interpolation(10u32, 50u32, 30u32, 100u32, 200u32).unwrap(); + assert_eq!(result, 150u32); +} + +#[test] +fn linear_interpolation_reversed_output_range() { + // min_out > max_out is valid: output decreases as input increases + // Input range [0, 100], output range [1000, 0], current_in = 25 => output = 750 + let result = linear_interpolation(0u32, 100u32, 25u32, 1000u32, 0u32).unwrap(); + assert_eq!(result, 750u32); +} + +#[test] +fn linear_interpolation_below_range_returns_error() { + let result = linear_interpolation(10u32, 100u32, 5u32, 0u32, 1000u32); + assert!(matches!(result, Err(LinearInterpolationInvalidValuesError))); +} + +#[test] +fn linear_interpolation_above_range_returns_error() { + let result = linear_interpolation(0u32, 100u32, 110u32, 0u32, 1000u32); + assert!(matches!(result, Err(LinearInterpolationInvalidValuesError))); +} + +// ---- weighted_average ---- + +#[test] +fn weighted_average_equal_weights() { + // (10 * 1 + 20 * 1) / (1 + 1) = 15 + let result = weighted_average(10u64, 1u64, 20u64, 1u64); + assert_eq!(result, 15); +} + +#[test] +fn weighted_average_all_weight_on_first() { + // second_weight = 0 => result == first_value + let result = weighted_average(10u64, 5u64, 99u64, 0u64); + assert_eq!(result, 10); +} + +#[test] +fn weighted_average_all_weight_on_second() { + // first_weight = 0 => result == second_value + let result = weighted_average(99u64, 0u64, 20u64, 5u64); + assert_eq!(result, 20); +} + +#[test] +fn weighted_average_three_to_one() { + // (0 * 1 + 60 * 3) / (1 + 3) = 180 / 4 = 45 + let result = weighted_average(0u64, 1u64, 60u64, 3u64); + assert_eq!(result, 45); +} + +// ---- weighted_average_round_up ---- + +#[test] +fn weighted_average_round_up_exact_division() { + // (10 * 1 + 20 * 1) / (1 + 1) = 15, no rounding needed + let result = weighted_average_round_up(10u64, 1u64, 20u64, 1u64); + assert_eq!(result, 15); +} + +#[test] +fn weighted_average_round_up_truncates_vs_rounds() { + // floor: (0 * 1 + 10 * 3) / (1 + 3) = 30 / 4 = 7 + // ceil: (30 + 4 - 1) / 4 = 33 / 4 = 8 + let floor_result = weighted_average(0u64, 1u64, 10u64, 3u64); + let ceil_result = weighted_average_round_up(0u64, 1u64, 10u64, 3u64); + assert_eq!(floor_result, 7); + assert_eq!(ceil_result, 8); +} + +#[test] +fn weighted_average_round_up_no_change_when_exact() { + // (0 * 1 + 20 * 3) / (1 + 3) = 60 / 4 = 15 exactly + let result = weighted_average_round_up(0u64, 1u64, 20u64, 3u64); + assert_eq!(result, 15); +} diff --git a/framework/scenario/tests/math_managed_test.rs b/framework/scenario/tests/math_managed_test.rs new file mode 100644 index 0000000000..ef0aa7b449 --- /dev/null +++ b/framework/scenario/tests/math_managed_test.rs @@ -0,0 +1,119 @@ +use multiversx_sc::math::{ + LinearInterpolationInvalidValuesError, linear_interpolation, weighted_average, + weighted_average_round_up, +}; +use multiversx_sc::types::{BigUint, ManagedDecimal, NumDecimals}; +use multiversx_sc_scenario::api::StaticApi; + +fn md(v: u64) -> ManagedDecimal { + ManagedDecimal::from_raw_units(BigUint::from(v), 4usize) +} + +fn bu(v: u64) -> BigUint { + BigUint::from(v) +} + +// ---- linear_interpolation ---- + +#[test] +fn linear_interpolation_at_min_input() { + // current_in == min_in => output == min_out + let result = linear_interpolation(md(0), md(100), md(0), md(200), md(400)).unwrap(); + assert_eq!(result, md(200)); +} + +#[test] +fn linear_interpolation_at_max_input() { + // current_in == max_in => output == max_out + let result = linear_interpolation(md(0), md(100), md(100), md(200), md(400)).unwrap(); + assert_eq!(result, md(400)); +} + +#[test] +fn linear_interpolation_at_midpoint() { + // current_in at the midpoint => output at midpoint of output range + let result = linear_interpolation(md(0), md(100), md(50), md(0), md(1000)).unwrap(); + assert_eq!(result, md(500)); +} + +#[test] +fn linear_interpolation_at_one_quarter() { + // current_in at 25% => output at 25% of output range + let result = linear_interpolation(md(0), md(100), md(25), md(0), md(1000)).unwrap(); + assert_eq!(result, md(250)); +} + +#[test] +fn linear_interpolation_non_zero_based_ranges() { + // Input range [10, 50], output range [100, 200], current_in = 30 (50% through input) + let result = linear_interpolation(md(10), md(50), md(30), md(100), md(200)).unwrap(); + assert_eq!(result, md(150)); +} + +#[test] +fn linear_interpolation_below_range_returns_error() { + let result = linear_interpolation(md(10), md(100), md(5), md(0), md(1000)); + assert!(matches!(result, Err(LinearInterpolationInvalidValuesError))); +} + +#[test] +fn linear_interpolation_above_range_returns_error() { + let result = linear_interpolation(md(0), md(100), md(110), md(0), md(1000)); + assert!(matches!(result, Err(LinearInterpolationInvalidValuesError))); +} + +// ---- weighted_average ---- + +#[test] +fn weighted_average_equal_weights() { + // (10 * 1 + 20 * 1) / (1 + 1) = 15 + let result = weighted_average(bu(10), bu(1), bu(20), bu(1)); + assert_eq!(result, bu(15)); +} + +#[test] +fn weighted_average_all_weight_on_first() { + // second_weight = 0 => result == first_value + let result = weighted_average(bu(10), bu(5), bu(99), bu(0)); + assert_eq!(result, bu(10)); +} + +#[test] +fn weighted_average_all_weight_on_second() { + // first_weight = 0 => result == second_value + let result = weighted_average(bu(99), bu(0), bu(20), bu(5)); + assert_eq!(result, bu(20)); +} + +#[test] +fn weighted_average_three_to_one() { + // (0 * 1 + 60 * 3) / (1 + 3) = 180 / 4 = 45 + let result = weighted_average(bu(0), bu(1), bu(60), bu(3)); + assert_eq!(result, bu(45)); +} + +// ---- weighted_average_round_up ---- + +#[test] +fn weighted_average_round_up_exact_division() { + // (10 * 1 + 20 * 1) / (1 + 1) = 15, no rounding needed + let result = weighted_average_round_up(bu(10), bu(1), bu(20), bu(1)); + assert_eq!(result, bu(15)); +} + +#[test] +fn weighted_average_round_up_truncates_vs_rounds() { + // floor: (0 * 1 + 10 * 3) / (1 + 3) = 30 / 4 = 7 + // ceil: (30 + 4 - 1) / 4 = 33 / 4 = 8 + let floor_result = weighted_average(bu(0), bu(1), bu(10), bu(3)); + let ceil_result = weighted_average_round_up(bu(0), bu(1), bu(10), bu(3)); + assert_eq!(floor_result, bu(7)); + assert_eq!(ceil_result, bu(8)); +} + +#[test] +fn weighted_average_round_up_no_change_when_exact() { + // (0 * 1 + 20 * 3) / (1 + 3) = 60 / 4 = 15 exactly + let result = weighted_average_round_up(bu(0), bu(1), bu(20), bu(3)); + assert_eq!(result, bu(15)); +}