Skip to content

Commit bfed3c6

Browse files
Merge pull request #2335 from multiversx/nth-root
Nth root
2 parents 84ef63a + 8b6d1ca commit bfed3c6

File tree

6 files changed

+238
-4
lines changed

6 files changed

+238
-4
lines changed

framework/base/src/err_msg.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ pub const VALUE_EXCEEDS_SLICE: &str = "value exceeds target slice";
4646
pub const CAST_TO_I64_ERROR: &str = "cast to i64 error";
4747
pub const BIG_UINT_EXCEEDS_SLICE: &str = "big uint as_bytes exceed target slice";
4848
pub const BIG_UINT_SUB_NEGATIVE: &str = "cannot subtract because result would be negative";
49+
pub const BIG_UINT_NTH_ROOT_ZERO: &str = "cannot compute 0th root";
4950
pub const UNSIGNED_NEGATIVE: &str = "cannot convert to unsigned, number is negative";
5051
pub const ZERO_VALUE_NOT_ALLOWED: &str = "zero value not allowed";
5152
pub const PROPORTION_OVERFLOW_ERR: &str = "proportion overflow";

framework/base/src/types/managed/basic/big_int.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ impl<M: ManagedTypeApi> Clone for BigInt<M> {
282282
result
283283
}
284284
}
285+
286+
fn clone_from(&mut self, source: &Self) {
287+
BigInt::<M>::clone_to_handle(source.get_handle(), self.get_handle());
288+
}
285289
}
286290

287291
impl<M: ManagedTypeApi> Drop for BigInt<M> {

framework/base/src/types/managed/wrapped/managed_decimal.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ pub use managed_decimal_signed::ManagedDecimalSigned;
1717

1818
use crate::{
1919
abi::{TypeAbi, TypeAbiFrom, TypeName},
20-
api::{ManagedTypeApi, ManagedTypeApiImpl},
20+
api::{ManagedTypeApi, ManagedTypeApiImpl, quick_signal_error},
21+
err_msg,
2122
formatter::{FormatBuffer, FormatByteReceiver, SCDisplay},
2223
typenum::{U4, U8, Unsigned},
2324
types::BigUint,
@@ -139,6 +140,39 @@ impl<M: ManagedTypeApi, DECIMALS: Unsigned> From<ManagedDecimal<M, ConstDecimals
139140
}
140141
}
141142

143+
impl<M: ManagedTypeApi, D: Decimals + Clone> ManagedDecimal<M, D> {
144+
/// Integer part of the k-th root, preserving the decimal scale.
145+
///
146+
/// Internally pre-scales the raw data by `scaling_factor^(k-1)` so that after
147+
/// taking the integer root the decimal point lands in the correct position:
148+
///
149+
/// ```text
150+
/// self.data = v * 10^d
151+
/// → scaled = self.data * (10^d)^(k-1) = v * 10^(d*k)
152+
/// → root = floor(scaled^(1/k)) = floor(v^(1/k) * 10^d)
153+
/// ```
154+
///
155+
/// Returns `0` (with the same scale) when `self` is zero.
156+
///
157+
/// # Panics
158+
/// Panics if `k` is zero.
159+
pub fn nth_root(&self, k: u32) -> Self {
160+
if k == 0 {
161+
quick_signal_error::<M>(err_msg::BIG_UINT_NTH_ROOT_ZERO);
162+
}
163+
164+
if k == 1 {
165+
return self.clone();
166+
}
167+
168+
let sf = self.decimals.scaling_factor::<M>();
169+
// Multiply by sf^(k-1) before rooting so the decimal position is preserved.
170+
// For k==0, the check in BigUint::nth_root handles the error signal.
171+
let scaled = &self.data * &sf.pow(k.saturating_sub(1));
172+
ManagedDecimal::from_raw_units(scaled.nth_root_unchecked(k), self.decimals.clone())
173+
}
174+
}
175+
142176
impl<M: ManagedTypeApi> ManagedVecItem for ManagedDecimal<M, NumDecimals> {
143177
type PAYLOAD = ManagedVecItemPayloadBuffer<U8>; // 4 bigUint + 4 usize
144178

framework/base/src/types/managed/wrapped/num/big_uint.rs

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,15 +332,96 @@ impl<M: ManagedTypeApi> BigUint<M> {
332332
}
333333
}
334334

335+
/// Assigns `self = base^exp`.
336+
pub fn pow_assign(&mut self, base: &BigUint<M>, exp: u32) {
337+
let exp_handle = BigUint::<M>::make_temp(const_handles::BIG_INT_TEMPORARY_1, exp);
338+
M::managed_type_impl().bi_pow(self.get_handle(), base.get_handle(), exp_handle);
339+
}
340+
335341
pub fn pow(&self, exp: u32) -> Self {
336-
let big_int_temp_1 = BigUint::<M>::make_temp(const_handles::BIG_INT_TEMPORARY_1, exp);
337342
unsafe {
338-
let result = BigUint::new_uninit();
339-
M::managed_type_impl().bi_pow(result.get_handle(), self.get_handle(), big_int_temp_1);
343+
let mut result = BigUint::new_uninit();
344+
result.pow_assign(self, exp);
340345
result
341346
}
342347
}
343348

349+
/// The integer part of the k-th root, computed via Newton's method.
350+
///
351+
/// The initial guess is derived from the number of significant bits (`log2_floor`):
352+
/// `x0 = 2^(floor(log2(self) / k) + 1)`, which is always an overestimate.
353+
///
354+
/// Returns `0` when `self` is zero.
355+
///
356+
/// # Panics
357+
/// Panics if `k` is zero.
358+
pub fn nth_root(&self, k: u32) -> Self {
359+
if k == 0 {
360+
quick_signal_error::<M>(err_msg::BIG_UINT_NTH_ROOT_ZERO);
361+
}
362+
363+
if k == 1 {
364+
return self.clone();
365+
}
366+
367+
self.nth_root_unchecked(k)
368+
}
369+
370+
// Expects k > 1. Does not check this precondition, so it is the caller's responsibility to ensure it.
371+
pub(crate) fn nth_root_unchecked(&self, k: u32) -> Self {
372+
// log2 is None for the number zero,
373+
// but in this case we can return early with the correct result of zero without doing any computation
374+
let Some(log2) = self.log2_floor() else {
375+
return BigUint::zero();
376+
};
377+
378+
// Initial overestimate: 2^(floor(log2 / k) + 1)
379+
let mut x = BigUint::from(1u64) << ((log2 / k + 1) as usize);
380+
381+
// Newton's iteration: x = ((k-1)*x + self / x^(k-1)) / k
382+
// Converges from above; stop when the estimate stops decreasing.
383+
let k_big = BigUint::<M>::from(k as u64);
384+
let k_minus_1_big = BigUint::<M>::from((k - 1) as u64);
385+
386+
// Pre-allocate buffers reused across iterations to avoid per-iteration allocations.
387+
// SAFETY: both are fully written before being read in every iteration.
388+
let mut x_pow_k_minus_1 = unsafe { BigUint::new_uninit() };
389+
let mut new_x = unsafe { BigUint::new_uninit() };
390+
let api = M::managed_type_impl();
391+
loop {
392+
// x_pow_k_minus_1 = x^(k-1)
393+
x_pow_k_minus_1.pow_assign(&x, k - 1);
394+
395+
// Reuse x_pow_k_minus_1's handle for self / x^(k-1).
396+
// The VM reads both operands before writing, so dest == divisor is safe.
397+
api.bi_t_div(
398+
x_pow_k_minus_1.get_handle(),
399+
self.get_handle(),
400+
x_pow_k_minus_1.get_handle(),
401+
);
402+
403+
// new_x = (k-1)*x + self/x^(k-1)
404+
api.bi_mul(
405+
new_x.get_handle(),
406+
k_minus_1_big.get_handle(),
407+
x.get_handle(),
408+
);
409+
new_x += &x_pow_k_minus_1;
410+
411+
// new_x /= k
412+
new_x /= &k_big;
413+
414+
if new_x >= x {
415+
break;
416+
}
417+
418+
// Swap handles instead of cloning: zero API calls, no allocation.
419+
core::mem::swap(&mut x, &mut new_x);
420+
}
421+
422+
x
423+
}
424+
344425
/// The whole part of the base-2 logarithm.
345426
///
346427
/// Obtained by counting the significant bits.
@@ -447,6 +528,10 @@ impl<M: ManagedTypeApi> Clone for BigUint<M> {
447528
fn clone(&self) -> Self {
448529
unsafe { self.as_big_int().clone().into_big_uint_unchecked() }
449530
}
531+
532+
fn clone_from(&mut self, source: &Self) {
533+
self.value.clone_from(&source.value);
534+
}
450535
}
451536

452537
impl<M: ManagedTypeApi> TryStaticCast for BigUint<M> {}

framework/scenario/tests/big_uint_test.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,62 @@ fn test_big_uint_saturating_sub_assign() {
116116
fn test_big_uint_proportion_overflow() {
117117
let _ = BigUint::<StaticApi>::from(100u64).proportion(100, i64::MAX as u64 + 1);
118118
}
119+
120+
fn assert_nth_root(x: u64, k: u32, expected: u64) {
121+
let big = BigUint::<StaticApi>::from(x);
122+
let result = big.nth_root(k);
123+
let expected_big = BigUint::<StaticApi>::from(expected);
124+
assert_eq!(
125+
result, expected_big,
126+
"nth_root({x}, {k}) expected {expected}"
127+
);
128+
}
129+
130+
#[test]
131+
fn test_big_uint_nth_root() {
132+
// k = 1: identity
133+
assert_nth_root(0, 1, 0);
134+
assert_nth_root(1, 1, 1);
135+
assert_nth_root(42, 1, 42);
136+
137+
// zero base: always 0
138+
assert_nth_root(0, 2, 0);
139+
assert_nth_root(0, 3, 0);
140+
assert_nth_root(0, 100, 0);
141+
142+
// perfect squares (agreeing with sqrt)
143+
assert_nth_root(1, 2, 1);
144+
assert_nth_root(4, 2, 2);
145+
assert_nth_root(9, 2, 3);
146+
assert_nth_root(100, 2, 10);
147+
assert_nth_root(10000, 2, 100);
148+
149+
// perfect cubes
150+
assert_nth_root(1, 3, 1);
151+
assert_nth_root(8, 3, 2);
152+
assert_nth_root(27, 3, 3);
153+
assert_nth_root(125, 3, 5);
154+
assert_nth_root(1000, 3, 10);
155+
156+
// floor (not an exact power)
157+
assert_nth_root(2, 2, 1); // sqrt(2) ~ 1.41
158+
assert_nth_root(10, 3, 2); // cbrt(10) ~ 2.154
159+
assert_nth_root(100, 3, 4); // cbrt(100) ~ 4.641
160+
assert_nth_root(255, 2, 15); // sqrt(255) ~ 15.96
161+
assert_nth_root(1023, 10, 1); // 1023^(1/10) ~ 1.995
162+
163+
// higher roots
164+
assert_nth_root(16, 4, 2); // 2^4 = 16
165+
assert_nth_root(32, 5, 2); // 2^5 = 32
166+
assert_nth_root(1024, 10, 2); // 2^10 = 1024
167+
assert_nth_root(2_u64.pow(20), 20, 2); // 2^20
168+
169+
// large number
170+
assert_nth_root(1_000_000_000, 3, 1000); // 1000^3 = 10^9
171+
}
172+
173+
#[test]
174+
#[should_panic = "StaticApi signal error: cannot compute 0th root"]
175+
fn test_big_uint_nth_root_zero_k() {
176+
let _ = BigUint::<StaticApi>::from(10u64).nth_root(0);
177+
}

framework/scenario/tests/managed_decimal_test.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,3 +660,54 @@ pub fn test_managed_decimal_div_mix_decimals_type_reverse() {
660660

661661
assert_eq!(result, expected)
662662
}
663+
664+
// d=4 ManagedDecimal helper: raw units / 10^4
665+
fn md4(v: u64) -> ManagedDecimal<StaticApi, NumDecimals> {
666+
ManagedDecimal::from_raw_units(BigUint::from(v), 4usize)
667+
}
668+
669+
fn assert_md4_nth_root(raw: u64, k: u32, expected_raw: u64) {
670+
assert_eq!(
671+
md4(raw).nth_root(k).into_raw_units(),
672+
&BigUint::<StaticApi>::from(expected_raw)
673+
);
674+
}
675+
676+
#[test]
677+
fn test_managed_decimal_nth_root() {
678+
// k=1: identity
679+
assert_md4_nth_root(40000, 1, 40000);
680+
681+
// zero: any k≥2 root of 0.0000 is 0.0000
682+
assert_md4_nth_root(0, 2, 0);
683+
684+
// sqrt(4.0000) = 2.0000
685+
// scaled = 40000 * 10000^1 = 400_000_000, sqrt = 20000
686+
assert_md4_nth_root(40000, 2, 20000);
687+
688+
// sqrt(9.0000) = 3.0000
689+
// scaled = 90000 * 10000 = 900_000_000, sqrt = 30000
690+
assert_md4_nth_root(90000, 2, 30000);
691+
692+
// cbrt(8.0000) = 2.0000
693+
// scaled = 80000 * 10000^2 = 8_000_000_000_000, cbrt = 20000
694+
assert_md4_nth_root(80000, 3, 20000);
695+
696+
// cbrt(27.0000) = 3.0000
697+
// scaled = 270000 * 10000^2 = 27_000_000_000_000, cbrt = 30000
698+
assert_md4_nth_root(270000, 3, 30000);
699+
700+
// sqrt(2.0000) ≈ 1.4142 (floor)
701+
// scaled = 20000 * 10000 = 200_000_000, sqrt = 14142
702+
assert_md4_nth_root(20000, 2, 14142);
703+
704+
// cbrt(2.0000) ≈ 1.2599 (floor)
705+
// scaled = 20000 * 10000^2 = 2_000_000_000_000, cbrt = 12599
706+
assert_md4_nth_root(20000, 3, 12599);
707+
}
708+
709+
#[test]
710+
#[should_panic = "StaticApi signal error: cannot compute 0th root"]
711+
fn test_managed_decimal_nth_root_zero_k() {
712+
let _ = md4(40000).nth_root(0);
713+
}

0 commit comments

Comments
 (0)