Skip to content

Commit 0df3c96

Browse files
nth root - ManagedDecimal
1 parent 6cb7cac commit 0df3c96

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

framework/base/src/types/managed/wrapped/managed_decimal.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,35 @@ impl<M: ManagedTypeApi, DECIMALS: Unsigned> From<ManagedDecimal<M, ConstDecimals
139139
}
140140
}
141141

142+
impl<M: ManagedTypeApi, D: Decimals + Clone> ManagedDecimal<M, D> {
143+
/// Integer part of the k-th root, preserving the decimal scale.
144+
///
145+
/// Internally pre-scales the raw data by `scaling_factor^(k-1)` so that after
146+
/// taking the integer root the decimal point lands in the correct position:
147+
///
148+
/// ```text
149+
/// self.data = v * 10^d
150+
/// → scaled = self.data * (10^d)^(k-1) = v * 10^(d*k)
151+
/// → root = floor(scaled^(1/k)) = floor(v^(1/k) * 10^d)
152+
/// ```
153+
///
154+
/// Returns `0` (with the same scale) when `self` is zero.
155+
///
156+
/// # Panics
157+
/// Panics if `k` is zero.
158+
pub fn nth_root(&self, k: u32) -> Self {
159+
if k == 1 {
160+
return self.clone();
161+
}
162+
163+
let sf = self.decimals.scaling_factor::<M>();
164+
// Multiply by sf^(k-1) before rooting so the decimal position is preserved.
165+
// For k==0, the check in BigUint::nth_root handles the error signal.
166+
let scaled = &self.data * &sf.pow(k.saturating_sub(1));
167+
ManagedDecimal::from_raw_units(scaled.nth_root(k), self.decimals.clone())
168+
}
169+
}
170+
142171
impl<M: ManagedTypeApi> ManagedVecItem for ManagedDecimal<M, NumDecimals> {
143172
type PAYLOAD = ManagedVecItemPayloadBuffer<U8>; // 4 bigUint + 4 usize
144173

framework/base/src/types/managed/wrapped/num/big_uint.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,11 @@ impl<M: ManagedTypeApi> BigUint<M> {
396396
);
397397

398398
// new_x = (k-1)*x + self/x^(k-1)
399-
api.bi_mul(new_x.get_handle(), k_minus_1_big.get_handle(), x.get_handle());
399+
api.bi_mul(
400+
new_x.get_handle(),
401+
k_minus_1_big.get_handle(),
402+
x.get_handle(),
403+
);
400404
new_x += &x_pow_k_minus_1;
401405

402406
// new_x /= k
@@ -405,7 +409,7 @@ impl<M: ManagedTypeApi> BigUint<M> {
405409
if new_x >= x {
406410
break;
407411
}
408-
412+
409413
// Swap handles instead of cloning: zero API calls, no allocation.
410414
core::mem::swap(&mut x, &mut new_x);
411415
}

framework/scenario/tests/managed_decimal_test.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,3 +660,54 @@ pub fn test_managed_decimal_div_mix_decimals_type_reverse() {
660660

661661
assert_eq!(result, expected)
662662
}
663+
664+
// d=4 ManagedDecimal helper: raw units / 10^4
665+
fn md4(v: u64) -> ManagedDecimal<StaticApi, NumDecimals> {
666+
ManagedDecimal::from_raw_units(BigUint::from(v), 4usize)
667+
}
668+
669+
fn assert_md4_nth_root(raw: u64, k: u32, expected_raw: u64) {
670+
assert_eq!(
671+
md4(raw).nth_root(k).into_raw_units(),
672+
&BigUint::<StaticApi>::from(expected_raw)
673+
);
674+
}
675+
676+
#[test]
677+
fn test_managed_decimal_nth_root() {
678+
// k=1: identity
679+
assert_md4_nth_root(40000, 1, 40000);
680+
681+
// zero: any k≥2 root of 0.0000 is 0.0000
682+
assert_md4_nth_root(0, 2, 0);
683+
684+
// sqrt(4.0000) = 2.0000
685+
// scaled = 40000 * 10000^1 = 400_000_000, sqrt = 20000
686+
assert_md4_nth_root(40000, 2, 20000);
687+
688+
// sqrt(9.0000) = 3.0000
689+
// scaled = 90000 * 10000 = 900_000_000, sqrt = 30000
690+
assert_md4_nth_root(90000, 2, 30000);
691+
692+
// cbrt(8.0000) = 2.0000
693+
// scaled = 80000 * 10000^2 = 8_000_000_000_000, cbrt = 20000
694+
assert_md4_nth_root(80000, 3, 20000);
695+
696+
// cbrt(27.0000) = 3.0000
697+
// scaled = 270000 * 10000^2 = 27_000_000_000_000, cbrt = 30000
698+
assert_md4_nth_root(270000, 3, 30000);
699+
700+
// sqrt(2.0000) ≈ 1.4142 (floor)
701+
// scaled = 20000 * 10000 = 200_000_000, sqrt = 14142
702+
assert_md4_nth_root(20000, 2, 14142);
703+
704+
// cbrt(2.0000) ≈ 1.2599 (floor)
705+
// scaled = 20000 * 10000^2 = 2_000_000_000_000, cbrt = 12599
706+
assert_md4_nth_root(20000, 3, 12599);
707+
}
708+
709+
#[test]
710+
#[should_panic = "StaticApi signal error: cannot compute 0th root"]
711+
fn test_managed_decimal_nth_root_zero_k() {
712+
let _ = md4(40000).nth_root(0);
713+
}

0 commit comments

Comments
 (0)