Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions framework/base/src/err_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub const VALUE_EXCEEDS_SLICE: &str = "value exceeds target slice";
pub const CAST_TO_I64_ERROR: &str = "cast to i64 error";
pub const BIG_UINT_EXCEEDS_SLICE: &str = "big uint as_bytes exceed target slice";
pub const BIG_UINT_SUB_NEGATIVE: &str = "cannot subtract because result would be negative";
pub const BIG_UINT_NTH_ROOT_ZERO: &str = "cannot compute 0th root";
pub const UNSIGNED_NEGATIVE: &str = "cannot convert to unsigned, number is negative";
pub const ZERO_VALUE_NOT_ALLOWED: &str = "zero value not allowed";
pub const PROPORTION_OVERFLOW_ERR: &str = "proportion overflow";
Expand Down
4 changes: 4 additions & 0 deletions framework/base/src/types/managed/basic/big_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,10 @@ impl<M: ManagedTypeApi> Clone for BigInt<M> {
result
}
}

fn clone_from(&mut self, source: &Self) {
BigInt::<M>::clone_to_handle(source.get_handle(), self.get_handle());
}
}

impl<M: ManagedTypeApi> Drop for BigInt<M> {
Expand Down
36 changes: 35 additions & 1 deletion framework/base/src/types/managed/wrapped/managed_decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ pub use managed_decimal_signed::ManagedDecimalSigned;

use crate::{
abi::{TypeAbi, TypeAbiFrom, TypeName},
api::{ManagedTypeApi, ManagedTypeApiImpl},
api::{ManagedTypeApi, ManagedTypeApiImpl, quick_signal_error},
err_msg,
formatter::{FormatBuffer, FormatByteReceiver, SCDisplay},
typenum::{U4, U8, Unsigned},
types::BigUint,
Expand Down Expand Up @@ -139,6 +140,39 @@ impl<M: ManagedTypeApi, DECIMALS: Unsigned> From<ManagedDecimal<M, ConstDecimals
}
}

impl<M: ManagedTypeApi, D: Decimals + Clone> ManagedDecimal<M, D> {
/// Integer part of the k-th root, preserving the decimal scale.
///
/// Internally pre-scales the raw data by `scaling_factor^(k-1)` so that after
/// taking the integer root the decimal point lands in the correct position:
///
/// ```text
/// self.data = v * 10^d
/// → scaled = self.data * (10^d)^(k-1) = v * 10^(d*k)
/// → root = floor(scaled^(1/k)) = floor(v^(1/k) * 10^d)
/// ```
///
/// Returns `0` (with the same scale) when `self` is zero.
///
/// # Panics
/// Panics if `k` is zero.
pub fn nth_root(&self, k: u32) -> Self {
if k == 0 {
quick_signal_error::<M>(err_msg::BIG_UINT_NTH_ROOT_ZERO);
}

if k == 1 {
return self.clone();
}

let sf = self.decimals.scaling_factor::<M>();
// Multiply by sf^(k-1) before rooting so the decimal position is preserved.
// For k==0, the check in BigUint::nth_root handles the error signal.
let scaled = &self.data * &sf.pow(k.saturating_sub(1));
ManagedDecimal::from_raw_units(scaled.nth_root_unchecked(k), self.decimals.clone())
}
}

impl<M: ManagedTypeApi> ManagedVecItem for ManagedDecimal<M, NumDecimals> {
type PAYLOAD = ManagedVecItemPayloadBuffer<U8>; // 4 bigUint + 4 usize

Expand Down
91 changes: 88 additions & 3 deletions framework/base/src/types/managed/wrapped/num/big_uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,15 +332,96 @@ impl<M: ManagedTypeApi> BigUint<M> {
}
}

/// Assigns `self = base^exp`.
pub fn pow_assign(&mut self, base: &BigUint<M>, exp: u32) {
let exp_handle = BigUint::<M>::make_temp(const_handles::BIG_INT_TEMPORARY_1, exp);
M::managed_type_impl().bi_pow(self.get_handle(), base.get_handle(), exp_handle);
}

pub fn pow(&self, exp: u32) -> Self {
let big_int_temp_1 = BigUint::<M>::make_temp(const_handles::BIG_INT_TEMPORARY_1, exp);
unsafe {
let result = BigUint::new_uninit();
M::managed_type_impl().bi_pow(result.get_handle(), self.get_handle(), big_int_temp_1);
let mut result = BigUint::new_uninit();
result.pow_assign(self, exp);
result
}
}

/// The integer part of the k-th root, computed via Newton's method.
///
/// The initial guess is derived from the number of significant bits (`log2_floor`):
/// `x0 = 2^(floor(log2(self) / k) + 1)`, which is always an overestimate.
///
/// Returns `0` when `self` is zero.
///
/// # Panics
/// Panics if `k` is zero.
pub fn nth_root(&self, k: u32) -> Self {
if k == 0 {
quick_signal_error::<M>(err_msg::BIG_UINT_NTH_ROOT_ZERO);
}

if k == 1 {
return self.clone();
}

self.nth_root_unchecked(k)
}

// Expects k > 1. Does not check this precondition, so it is the caller's responsibility to ensure it.
pub(crate) fn nth_root_unchecked(&self, k: u32) -> Self {
// log2 is None for the number zero,
// but in this case we can return early with the correct result of zero without doing any computation
let Some(log2) = self.log2_floor() else {
return BigUint::zero();
};

// Initial overestimate: 2^(floor(log2 / k) + 1)
let mut x = BigUint::from(1u64) << ((log2 / k + 1) as usize);

// Newton's iteration: x = ((k-1)*x + self / x^(k-1)) / k
// Converges from above; stop when the estimate stops decreasing.
let k_big = BigUint::<M>::from(k as u64);
let k_minus_1_big = BigUint::<M>::from((k - 1) as u64);

// Pre-allocate buffers reused across iterations to avoid per-iteration allocations.
// SAFETY: both are fully written before being read in every iteration.
let mut x_pow_k_minus_1 = unsafe { BigUint::new_uninit() };
let mut new_x = unsafe { BigUint::new_uninit() };
let api = M::managed_type_impl();
loop {
// x_pow_k_minus_1 = x^(k-1)
x_pow_k_minus_1.pow_assign(&x, k - 1);

// Reuse x_pow_k_minus_1's handle for self / x^(k-1).
// The VM reads both operands before writing, so dest == divisor is safe.
api.bi_t_div(
x_pow_k_minus_1.get_handle(),
self.get_handle(),
x_pow_k_minus_1.get_handle(),
);

// new_x = (k-1)*x + self/x^(k-1)
api.bi_mul(
new_x.get_handle(),
k_minus_1_big.get_handle(),
x.get_handle(),
);
new_x += &x_pow_k_minus_1;

// new_x /= k
new_x /= &k_big;

if new_x >= x {
break;
}

// Swap handles instead of cloning: zero API calls, no allocation.
core::mem::swap(&mut x, &mut new_x);
}

x
}

/// The whole part of the base-2 logarithm.
///
/// Obtained by counting the significant bits.
Expand Down Expand Up @@ -447,6 +528,10 @@ impl<M: ManagedTypeApi> Clone for BigUint<M> {
fn clone(&self) -> Self {
unsafe { self.as_big_int().clone().into_big_uint_unchecked() }
}

fn clone_from(&mut self, source: &Self) {
self.value.clone_from(&source.value);
}
}

impl<M: ManagedTypeApi> TryStaticCast for BigUint<M> {}
Expand Down
59 changes: 59 additions & 0 deletions framework/scenario/tests/big_uint_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,62 @@ fn test_big_uint_saturating_sub_assign() {
fn test_big_uint_proportion_overflow() {
let _ = BigUint::<StaticApi>::from(100u64).proportion(100, i64::MAX as u64 + 1);
}

fn assert_nth_root(x: u64, k: u32, expected: u64) {
let big = BigUint::<StaticApi>::from(x);
let result = big.nth_root(k);
let expected_big = BigUint::<StaticApi>::from(expected);
assert_eq!(
result, expected_big,
"nth_root({x}, {k}) expected {expected}"
);
}

#[test]
fn test_big_uint_nth_root() {
// k = 1: identity
assert_nth_root(0, 1, 0);
assert_nth_root(1, 1, 1);
assert_nth_root(42, 1, 42);

// zero base: always 0
assert_nth_root(0, 2, 0);
assert_nth_root(0, 3, 0);
assert_nth_root(0, 100, 0);

// perfect squares (agreeing with sqrt)
assert_nth_root(1, 2, 1);
assert_nth_root(4, 2, 2);
assert_nth_root(9, 2, 3);
assert_nth_root(100, 2, 10);
assert_nth_root(10000, 2, 100);

// perfect cubes
assert_nth_root(1, 3, 1);
assert_nth_root(8, 3, 2);
assert_nth_root(27, 3, 3);
assert_nth_root(125, 3, 5);
assert_nth_root(1000, 3, 10);

// floor (not an exact power)
assert_nth_root(2, 2, 1); // sqrt(2) ~ 1.41
assert_nth_root(10, 3, 2); // cbrt(10) ~ 2.154
assert_nth_root(100, 3, 4); // cbrt(100) ~ 4.641
assert_nth_root(255, 2, 15); // sqrt(255) ~ 15.96
assert_nth_root(1023, 10, 1); // 1023^(1/10) ~ 1.995

// higher roots
assert_nth_root(16, 4, 2); // 2^4 = 16
assert_nth_root(32, 5, 2); // 2^5 = 32
assert_nth_root(1024, 10, 2); // 2^10 = 1024
assert_nth_root(2_u64.pow(20), 20, 2); // 2^20

// large number
assert_nth_root(1_000_000_000, 3, 1000); // 1000^3 = 10^9
}

#[test]
#[should_panic = "StaticApi signal error: cannot compute 0th root"]
fn test_big_uint_nth_root_zero_k() {
let _ = BigUint::<StaticApi>::from(10u64).nth_root(0);
}
51 changes: 51 additions & 0 deletions framework/scenario/tests/managed_decimal_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,3 +660,54 @@ pub fn test_managed_decimal_div_mix_decimals_type_reverse() {

assert_eq!(result, expected)
}

// d=4 ManagedDecimal helper: raw units / 10^4
fn md4(v: u64) -> ManagedDecimal<StaticApi, NumDecimals> {
ManagedDecimal::from_raw_units(BigUint::from(v), 4usize)
}

fn assert_md4_nth_root(raw: u64, k: u32, expected_raw: u64) {
assert_eq!(
md4(raw).nth_root(k).into_raw_units(),
&BigUint::<StaticApi>::from(expected_raw)
);
}

#[test]
fn test_managed_decimal_nth_root() {
// k=1: identity
assert_md4_nth_root(40000, 1, 40000);

// zero: any k≥2 root of 0.0000 is 0.0000
assert_md4_nth_root(0, 2, 0);

// sqrt(4.0000) = 2.0000
// scaled = 40000 * 10000^1 = 400_000_000, sqrt = 20000
assert_md4_nth_root(40000, 2, 20000);

// sqrt(9.0000) = 3.0000
// scaled = 90000 * 10000 = 900_000_000, sqrt = 30000
assert_md4_nth_root(90000, 2, 30000);

// cbrt(8.0000) = 2.0000
// scaled = 80000 * 10000^2 = 8_000_000_000_000, cbrt = 20000
assert_md4_nth_root(80000, 3, 20000);

// cbrt(27.0000) = 3.0000
// scaled = 270000 * 10000^2 = 27_000_000_000_000, cbrt = 30000
assert_md4_nth_root(270000, 3, 30000);

// sqrt(2.0000) ≈ 1.4142 (floor)
// scaled = 20000 * 10000 = 200_000_000, sqrt = 14142
assert_md4_nth_root(20000, 2, 14142);

// cbrt(2.0000) ≈ 1.2599 (floor)
// scaled = 20000 * 10000^2 = 2_000_000_000_000, cbrt = 12599
assert_md4_nth_root(20000, 3, 12599);
}

#[test]
#[should_panic = "StaticApi signal error: cannot compute 0th root"]
fn test_managed_decimal_nth_root_zero_k() {
let _ = md4(40000).nth_root(0);
}
Loading