diff --git a/benches/boxed_uint.rs b/benches/boxed_uint.rs
index e124afa2d..4bbdb7a00 100644
--- a/benches/boxed_uint.rs
+++ b/benches/boxed_uint.rs
@@ -61,6 +61,19 @@ fn bench_mul(c: &mut Criterion) {
         )
     });

+    group.bench_function("boxed_wrapping_mul", |b| {
+        b.iter_batched(
+            || {
+                (
+                    BoxedUint::random_bits(&mut OsRng, UINT_BITS),
+                    BoxedUint::random_bits(&mut OsRng, UINT_BITS),
+                )
+            },
+            |(x, y)| black_box(x.wrapping_mul(&y)),
+            BatchSize::SmallInput,
+        )
+    });
+
     group.bench_function("boxed_square", |b| {
         b.iter_batched(
             || BoxedUint::random_bits(&mut OsRng, UINT_BITS),
@@ -68,6 +81,14 @@ fn bench_mul(c: &mut Criterion) {
             BatchSize::SmallInput,
         )
     });
+
+    group.bench_function("boxed_wrapping_square", |b| {
+        b.iter_batched(
+            || BoxedUint::random_bits(&mut OsRng, UINT_BITS),
+            |x| black_box(x.wrapping_square()),
+            BatchSize::SmallInput,
+        )
+    });
 }

 fn bench_division(c: &mut Criterion) {
diff --git a/benches/uint.rs b/benches/uint.rs
index f02eb5e98..df2011d76 100644
--- a/benches/uint.rs
+++ b/benches/uint.rs
@@ -4,7 +4,7 @@ use criterion::{
 };
 use crypto_bigint::{
     Gcd, Limb, NonZero, Odd, OddUint, Random, RandomBits, RandomMod, Reciprocal, U128, U256, U512,
-    U1024, U2048, U4096, Uint,
+    U1024, U2048, U4096, U8192, Uint,
 };
 use rand_chacha::ChaCha8Rng;
 use rand_core::{RngCore, SeedableRng};
@@ -166,6 +166,38 @@ fn bench_mul(c: &mut Criterion) {
         )
     });

+    group.bench_function("widening_mul, U8192xU4096", |b| {
+        b.iter_batched(
+            || (U8192::random(&mut rng), U4096::random(&mut rng)),
+            |(x, y)| black_box(x.widening_mul(&y)),
+            BatchSize::SmallInput,
+        )
+    });
+
+    group.bench_function("wrapping_mul, U256xU256", |b| {
+        b.iter_batched(
+            || (U256::random(&mut rng), U256::random(&mut rng)),
+            |(x, y)| black_box(x.wrapping_mul(&y)),
+            BatchSize::SmallInput,
+        )
+    });
+
+    group.bench_function("wrapping_mul, U4096xU4096", |b| {
+        b.iter_batched(
+            || (U4096::random(&mut rng), U4096::random(&mut rng)),
+            |(x, y)| black_box(x.wrapping_mul(&y)),
+            BatchSize::SmallInput,
+        )
+    });
+
+    group.bench_function("wrapping_mul, U8192xU4096", |b| {
+        b.iter_batched(
+            || (U8192::random(&mut rng), U4096::random(&mut rng)),
+            |(x, y)| black_box(x.wrapping_mul(&y)),
+            BatchSize::SmallInput,
+        )
+    });
+
     group.bench_function("square_wide, U256", |b| {
         b.iter_batched(
             || U256::random(&mut rng),
@@ -182,6 +214,22 @@ fn bench_mul(c: &mut Criterion) {
         )
     });

+    group.bench_function("wrapping_square, U256", |b| {
+        b.iter_batched(
+            || U256::random(&mut rng),
+            |x| black_box(x.wrapping_square()),
+            BatchSize::SmallInput,
+        )
+    });
+
+    group.bench_function("wrapping_square, U4096", |b| {
+        b.iter_batched(
+            || U4096::random(&mut rng),
+            |x| black_box(x.wrapping_square()),
+            BatchSize::SmallInput,
+        )
+    });
+
     group.bench_function("mul_mod, U256", |b| {
         b.iter_batched(
             || {
diff --git a/src/uint/boxed/mul.rs b/src/uint/boxed/mul.rs
index c5b0e75e6..8b235320b 100644
--- a/src/uint/boxed/mul.rs
+++ b/src/uint/boxed/mul.rs
@@ -1,10 +1,10 @@
 //! [`BoxedUint`] multiplication operations.

 use crate::{
-    BoxedUint, CheckedMul, ConcatenatingMul, Limb, Resize, Uint, Wrapping, WrappingMul, Zero,
+    BoxedUint, CheckedMul, ConcatenatingMul, Limb, Uint, Wrapping, WrappingMul, Zero,
     uint::mul::{
         karatsuba::{KARATSUBA_MIN_STARTING_LIMBS, karatsuba_mul_limbs, karatsuba_square_limbs},
-        mul_limbs, square_limbs,
+        mul_limbs, schoolbook, square_limbs,
     },
 };
 use core::ops::{Mul, MulAssign};
@@ -45,7 +45,30 @@ impl BoxedUint {

     /// Perform wrapping multiplication, wrapping to the width of `self`.
     pub fn wrapping_mul(&self, rhs: &Self) -> Self {
-        self.mul(rhs).resize_unchecked(self.bits_precision())
+        self.wrapping_mul_limbs(rhs.as_limbs())
+    }
+
+    /// Multiply `self` by the limbs in `rhs`, keeping only the low `self.nlimbs()` limbs.
+    ///
+    /// `rhs` may have a different number of limbs than `self`.
+    #[inline(always)]
+    fn wrapping_mul_limbs(&self, rhs: &[Limb]) -> Self {
+        // Perform a widening Karatsuba multiplication and truncate
+        // for very large numbers, where the performance is better.
+        let overlap = self.nlimbs().min(rhs.len());
+        if overlap > 200 {
+            let size = self.nlimbs() + rhs.len();
+            // `size` limbs for the full product followed by `2 * overlap` limbs of scratch space.
+            let mut limbs = vec![Limb::ZERO; size + overlap * 2];
+            let (out, scratch) = limbs.as_mut_slice().split_at_mut(size);
+            karatsuba_mul_limbs(&self.limbs, rhs, out, scratch);
+            limbs.truncate(self.nlimbs());
+            return limbs.into();
+        }
+
+        let mut limbs = vec![Limb::ZERO; self.nlimbs()];
+        schoolbook::wrapping_mul(&self.limbs, rhs, &mut limbs);
+        limbs.into()
     }

     /// Multiply `self` by itself.
@@ -64,6 +87,13 @@ impl BoxedUint {
         square_limbs(&self.limbs, &mut limbs);
         limbs.into()
     }
+
+    /// Multiply `self` by itself, wrapping to the width of `self`.
+    pub fn wrapping_square(&self) -> Self {
+        let mut limbs = vec![Limb::ZERO; self.nlimbs()];
+        schoolbook::wrapping_square(&self.limbs, &mut limbs);
+        limbs.into()
+    }
 }

 impl CheckedMul for BoxedUint {
@@ -196,19 +226,31 @@ mod tests {
     #[cfg(feature = "rand_core")]
     #[test]
     fn mul_cmp() {
-        use crate::RandomBits;
+        use crate::{RandomBits, Resize};
         use rand_core::SeedableRng;
         let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1);

         for _ in 0..50 {
             let a = BoxedUint::random_bits(&mut rng, 4096);
             assert_eq!(a.mul(&a), a.square(), "a = {a}");
+            assert_eq!(a.wrapping_mul(&a), a.wrapping_square(), "a = {a}");
         }

         for _ in 0..50 {
             let a = BoxedUint::random_bits(&mut rng, 4096);
             let b = BoxedUint::random_bits(&mut rng, 5000);
-            assert_eq!(a.mul(&b), b.mul(&a), "a={a}, b={b}");
+            let expect = a.mul(&b);
+            assert_eq!(b.mul(&a), expect, "a={a}, b={b}");
+            assert_eq!(
+                a.wrapping_mul(&b),
+                expect.clone().resize_unchecked(a.bits_precision()),
+                "a={a}, b={b}"
+            );
+            assert_eq!(
+                b.wrapping_mul(&a),
+                expect.clone().resize_unchecked(b.bits_precision()),
+                "a={a}, b={b}"
+            );
         }
     }
 }
diff --git a/src/uint/mul.rs b/src/uint/mul.rs
index 666956b1c..34b67cd35 100644
--- a/src/uint/mul.rs
+++ b/src/uint/mul.rs
@@ -12,124 +12,7 @@ use crate::{
 use self::karatsuba::UintKaratsubaMul;

 pub(crate) mod karatsuba;
-
-///
Schoolbook multiplication a.k.a. long multiplication, i.e. the traditional method taught in -/// schools. -/// -/// The most efficient method for small numbers. -#[inline(always)] -const fn schoolbook_multiplication(lhs: &[Limb], rhs: &[Limb], lo: &mut [Limb], hi: &mut [Limb]) { - if lhs.len() != lo.len() || rhs.len() != hi.len() { - panic!("schoolbook multiplication length mismatch"); - } - - let mut i = 0; - while i < lhs.len() { - let mut j = 0; - let mut carry = Limb::ZERO; - let xi = lhs[i]; - - while j < rhs.len() { - let k = i + j; - - if k >= lhs.len() { - (hi[k - lhs.len()], carry) = xi.carrying_mul_add(rhs[j], hi[k - lhs.len()], carry); - } else { - (lo[k], carry) = xi.carrying_mul_add(rhs[j], lo[k], carry); - } - - j += 1; - } - - if i + j >= lhs.len() { - hi[i + j - lhs.len()] = carry; - } else { - lo[i + j] = carry; - } - i += 1; - } -} - -/// Schoolbook method of squaring. -/// -/// Like schoolbook multiplication, but only considering half of the multiplication grid. -#[inline(always)] -pub(crate) const fn schoolbook_squaring(limbs: &[Limb], lo: &mut [Limb], hi: &mut [Limb]) { - // Translated from https://github.com/ucbrise/jedi-pairing/blob/c4bf151/include/core/bigint.hpp#L410 - // - // Permission to relicense the resulting translation as Apache 2.0 + MIT was given - // by the original author Sam Kumar: https://github.com/RustCrypto/crypto-bigint/pull/133#discussion_r1056870411 - - if limbs.len() != lo.len() || lo.len() != hi.len() { - panic!("schoolbook squaring length mismatch"); - } - - let mut i = 1; - while i < limbs.len() { - let mut j = 0; - let mut carry = Limb::ZERO; - let xi = limbs[i]; - - while j < i { - let k = i + j; - - if k >= limbs.len() { - (hi[k - limbs.len()], carry) = - xi.carrying_mul_add(limbs[j], hi[k - limbs.len()], carry); - } else { - (lo[k], carry) = xi.carrying_mul_add(limbs[j], lo[k], carry); - } - - j += 1; - } - - if (2 * i) < limbs.len() { - lo[2 * i] = carry; - } else { - hi[2 * i - limbs.len()] = carry; - } - - i += 
1;
-    }
-
-    // Double the current result, this accounts for the other half of the multiplication grid.
-    // The top word is empty, so we use a special purpose shl.
-    let mut carry = Limb::ZERO;
-    let mut i = 0;
-    while i < limbs.len() {
-        (lo[i].0, carry) = ((lo[i].0 << 1) | carry.0, lo[i].shr(Limb::BITS - 1));
-        i += 1;
-    }
-
-    let mut i = 0;
-    while i < limbs.len() - 1 {
-        (hi[i].0, carry) = ((hi[i].0 << 1) | carry.0, hi[i].shr(Limb::BITS - 1));
-        i += 1;
-    }
-    hi[limbs.len() - 1] = carry;
-
-    // Handle the diagonal of the multiplication grid, which finishes the multiplication grid.
-    let mut carry = Limb::ZERO;
-    let mut i = 0;
-    while i < limbs.len() {
-        let xi = limbs[i];
-        if (i * 2) < limbs.len() {
-            (lo[i * 2], carry) = xi.carrying_mul_add(xi, lo[i * 2], carry);
-        } else {
-            (hi[i * 2 - limbs.len()], carry) =
-                xi.carrying_mul_add(xi, hi[i * 2 - limbs.len()], carry);
-        }
-
-        if (i * 2 + 1) < limbs.len() {
-            (lo[i * 2 + 1], carry) = lo[i * 2 + 1].overflowing_add(carry);
-        } else {
-            (hi[i * 2 + 1 - limbs.len()], carry) =
-                hi[i * 2 + 1 - limbs.len()].overflowing_add(carry);
-        }
-
-        i += 1;
-    }
-}
+pub(crate) mod schoolbook;

 impl Uint {
     /// Multiply `self` by `rhs`, returning a concatenated "wide" result.
@@ -184,8 +67,15 @@ impl Uint {
     }

     /// Perform wrapping multiplication, discarding overflow.
-    pub const fn wrapping_mul(&self, rhs: &Uint) -> Self {
-        self.widening_mul(rhs).0
+    pub const fn wrapping_mul(&self, rhs: &Uint) -> Self {
+        // 128x128 limbs: the widening product is faster for now, so take its low half.
+        if LIMBS == RHS_LIMBS && LIMBS == 128 {
+            return self.widening_mul(rhs).0;
+        }
+
+        let mut lo = Uint::ZERO;
+        schoolbook::wrapping_mul(&self.limbs, &rhs.limbs, &mut lo.limbs);
+        lo
     }

     /// Perform saturating multiplication, returning `MAX` on overflow.
@@ -229,7 +119,9 @@ impl Uint {

     /// Perform wrapping square, discarding overflow.
     pub const fn wrapping_square(&self) -> Uint {
-        self.square_wide().0
+        let mut lo = Uint::ZERO;
+        schoolbook::wrapping_square(&self.limbs, &mut lo.limbs);
+        lo
     }

     /// Perform saturating squaring, returning `MAX` on overflow.
@@ -366,9 +258,9 @@ pub(crate) const fn uint_mul_limbs(
     rhs: &[Limb],
 ) -> (Uint, Uint) {
     debug_assert!(lhs.len() == LIMBS && rhs.len() == RHS_LIMBS);
-    let mut lo: Uint = Uint::::ZERO;
+    let mut lo = Uint::::ZERO;
     let mut hi = Uint::::ZERO;
-    schoolbook_multiplication(lhs, rhs, &mut lo.limbs, &mut hi.limbs);
+    schoolbook::mul_wide(lhs, rhs, &mut lo.limbs, &mut hi.limbs);
     (lo, hi)
 }

@@ -379,7 +271,7 @@ pub(crate) const fn uint_square_limbs(
 ) -> (Uint, Uint) {
     let mut lo = Uint::::ZERO;
     let mut hi = Uint::::ZERO;
-    schoolbook_squaring(limbs, &mut lo.limbs, &mut hi.limbs);
+    schoolbook::square_wide(limbs, &mut lo.limbs, &mut hi.limbs);
     (lo, hi)
 }

@@ -388,7 +280,7 @@ pub(crate) const fn uint_square_limbs(
 pub(crate) fn mul_limbs(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) {
     debug_assert_eq!(lhs.len() + rhs.len(), out.len());
     let (lo, hi) = out.split_at_mut(lhs.len());
-    schoolbook_multiplication(lhs, rhs, lo, hi);
+    schoolbook::mul_wide(lhs, rhs, lo, hi);
 }

 /// Wrapper function used by `BoxedUint`
@@ -396,7 +288,7 @@ pub(crate) fn mul_limbs(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) {
 pub(crate) fn square_limbs(limbs: &[Limb], out: &mut [Limb]) {
     debug_assert_eq!(limbs.len() * 2, out.len());
     let (lo, hi) = out.split_at_mut(limbs.len());
-    schoolbook_squaring(limbs, lo, hi);
+    schoolbook::square_wide(limbs, lo, hi);
 }

 #[cfg(test)]
@@ -421,6 +313,7 @@ mod tests {
                 let expected = U64::from_u64(a_int as u64 * b_int as u64);
                 assert_eq!(lo, expected);
                 assert!(bool::from(hi.is_zero()));
+                assert_eq!(lo, U64::from_u32(a_int).wrapping_mul(&U64::from_u32(b_int)));
             }
         }
     }
@@ -443,8 +336,26 @@ mod tests {
     fn mul_concat_mixed() {
         let a = U64::from_u64(0x0011223344556677);
         let b = U128::from_u128(0x8899aabbccddeeff_8899aabbccddeeff);
-        assert_eq!(a.concatenating_mul(&b), U192::from(&a).saturating_mul(&b));
-        assert_eq!(b.concatenating_mul(&a), U192::from(&b).saturating_mul(&a));
+        let expected = U192::from(&b).saturating_mul(&a);
+        assert_eq!(a.concatenating_mul(&b), expected);
+        assert_eq!(b.concatenating_mul(&a), expected);
+    }
+
+    #[test]
+    fn wrapping_mul_even() {
+        assert_eq!(U64::ZERO.wrapping_mul(&U64::MAX), U64::ZERO);
+        assert_eq!(U64::MAX.wrapping_mul(&U64::ZERO), U64::ZERO);
+        assert_eq!(U64::MAX.wrapping_mul(&U64::MAX), U64::ONE);
+        assert_eq!(U64::ONE.wrapping_mul(&U64::MAX), U64::MAX);
+    }
+
+    #[test]
+    fn wrapping_mul_mixed() {
+        let a = U64::from_u64(0x0011223344556677);
+        let b = U128::from_u128(0x8899aabbccddeeff_8899aabbccddeeff);
+        let expected = U192::from(&b).saturating_mul(&a);
+        assert_eq!(b.wrapping_mul(&a), expected.resize());
+        assert_eq!(a.wrapping_mul(&b), expected.resize());
     }

     #[test]
diff --git a/src/uint/mul/karatsuba.rs b/src/uint/mul/karatsuba.rs
index 1a4ff7ccd..a103143bf 100644
--- a/src/uint/mul/karatsuba.rs
+++ b/src/uint/mul/karatsuba.rs
@@ -22,7 +22,7 @@ use super::{uint_mul_limbs, uint_square_limbs};
 use crate::{ConstChoice, Limb, Uint};

 #[cfg(feature = "alloc")]
-use super::square_limbs;
+use super::{schoolbook, square_limbs};
 #[cfg(feature = "alloc")]
 use crate::{WideWord, Word};

@@ -189,7 +189,7 @@ pub(crate) fn karatsuba_mul_limbs(
     };
     if size <= KARATSUBA_MAX_REDUCE_LIMBS {
         out.fill(Limb::ZERO);
-        carrying_add_mul_limbs(lhs, rhs, out);
+        schoolbook::carrying_add_mul(lhs, rhs, out);
         return;
     }
     if lhs.len() + rhs.len() != out.len() || scratch.len() < 2 * size {
@@ -278,11 +278,11 @@ pub(crate) fn karatsuba_mul_limbs(

     // Handle trailing limbs
     if !xt.is_empty() {
-        carrying_add_mul_limbs(xt, rhs, &mut out[size..]);
+        schoolbook::carrying_add_mul(xt, rhs, &mut out[size..]);
     }
     if !yt.is_empty() {
         let end_pos = 2 * size + yt.len();
-        carry = carrying_add_mul_limbs(yt, x, &mut out[size..end_pos]);
+        carry = schoolbook::carrying_add_mul(yt, x, &mut
out[size..end_pos]);
         i = end_pos;
         while i < out.len() {
             (out[i], carry) = out[i].carrying_add(Limb::ZERO, carry);
@@ -387,33 +387,5 @@ fn conditional_wrapping_neg_assign(limbs: &mut [Limb], choice: ConstChoice) {
     }
 }

-/// Add the schoolbook product of two limb slices to a limb slice, returning the carry.
-#[cfg(feature = "alloc")]
-fn carrying_add_mul_limbs(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) -> Limb {
-    if lhs.len() + rhs.len() != out.len() {
-        panic!("carrying_add_mul_limbs length mismatch");
-    }
-
-    let mut carry = Limb::ZERO;
-    let mut i = 0;
-    while i < lhs.len() {
-        let mut j = 0;
-        let mut carry2 = Limb::ZERO;
-        let xi = lhs[i];
-
-        while j < rhs.len() {
-            let k = i + j;
-            (out[k], carry2) = xi.carrying_mul_add(rhs[j], out[k], carry2);
-            j += 1;
-        }
-
-        carry = carry.wrapping_add(carry2);
-        (out[i + j], carry) = out[i + j].carrying_add(Limb::ZERO, carry);
-        i += 1;
-    }
-
-    carry
-}
-
 impl_uint_karatsuba_multiplication!(128, 64, 32, 16, 8);
 impl_uint_karatsuba_squaring!(128, 64, 32);
diff --git a/src/uint/mul/schoolbook.rs b/src/uint/mul/schoolbook.rs
new file mode 100644
index 000000000..c6e81c41e
--- /dev/null
+++ b/src/uint/mul/schoolbook.rs
@@ -0,0 +1,263 @@
+use crate::Limb;
+
+/// Schoolbook multiplication a.k.a. long multiplication, i.e. the traditional method taught in
+/// schools.
+///
+/// The most efficient method for small numbers.
+///
+/// The low `lhs.len()` limbs of the product go to `lo` and the high `rhs.len()` limbs to `hi`.
+/// Partial products are added onto the output limbs, so both buffers should start zeroed.
+///
+/// # Panics
+///
+/// Panics if `lo.len() != lhs.len()` or `hi.len() != rhs.len()`.
+#[inline(always)]
+pub const fn mul_wide(lhs: &[Limb], rhs: &[Limb], lo: &mut [Limb], hi: &mut [Limb]) {
+    assert!(
+        lhs.len() == lo.len() && rhs.len() == hi.len(),
+        "schoolbook multiplication length mismatch"
+    );
+
+    let mut i = 0;
+
+    while i < lhs.len() {
+        let mut carry = Limb::ZERO;
+        let xi = lhs[i];
+        let mut j = 0;
+
+        while j < rhs.len() {
+            let k = i + j;
+
+            if k >= lhs.len() {
+                (hi[k - lhs.len()], carry) = xi.carrying_mul_add(rhs[j], hi[k - lhs.len()], carry);
+            } else {
+                (lo[k], carry) = xi.carrying_mul_add(rhs[j], lo[k], carry);
+            }
+
+            j += 1;
+        }
+
+        if i + j >= lhs.len() {
+            hi[i + j - lhs.len()] = carry;
+        } else {
+            lo[i + j] = carry;
+        }
+        i += 1;
+    }
+}
+
+/// Add the schoolbook product of two limb slices to a limb slice, returning the carry.
+///
+/// # Panics
+///
+/// Panics if `out.len() != lhs.len() + rhs.len()`.
+#[cfg(feature = "alloc")]
+#[inline]
+pub const fn carrying_add_mul(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) -> Limb {
+    assert!(
+        lhs.len() + rhs.len() == out.len(),
+        "carrying_add_mul length mismatch"
+    );
+
+    let mut carry = Limb::ZERO;
+    let mut i = 0;
+
+    while i < lhs.len() {
+        let mut carry2 = Limb::ZERO;
+        let xi = lhs[i];
+        let mut j = 0;
+
+        while j < rhs.len() {
+            let k = i + j;
+            (out[k], carry2) = xi.carrying_mul_add(rhs[j], out[k], carry2);
+            j += 1;
+        }
+
+        carry = carry.wrapping_add(carry2);
+        (out[i + j], carry) = out[i + j].carrying_add(Limb::ZERO, carry);
+        i += 1;
+    }
+
+    carry
+}
+
+/// Schoolbook multiplication which only calculates the lower limbs of the product.
+///
+/// Computes `lhs * rhs` truncated to `lhs.len()` limbs; `rhs` may have any length. Partial
+/// products are added onto `out`, so it should start zeroed.
+///
+/// # Panics
+///
+/// Panics if `out.len() != lhs.len()`.
+#[inline(always)]
+pub const fn wrapping_mul(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) {
+    assert!(
+        lhs.len() == out.len(),
+        "wrapping schoolbook multiplication length mismatch"
+    );
+
+    let mut i = 0;
+
+    while i < lhs.len() {
+        let mut carry = Limb::ZERO;
+        let xi = lhs[i];
+        let mut k = i;
+
+        while k < lhs.len() {
+            let j = k - i;
+            // `rhs` is exhausted: store this row's final carry and move on to the next row.
+            if j == rhs.len() {
+                out[k] = carry;
+                break;
+            }
+            (out[k], carry) = xi.carrying_mul_add(rhs[j], out[k], carry);
+            k += 1;
+        }
+        i += 1;
+    }
+}
+
+/// Schoolbook method of squaring.
+///
+/// Like schoolbook multiplication, but only considering half of the multiplication grid.
+///
+/// The low half of the square goes to `lo` and the high half to `hi`; both should start zeroed.
+/// Empty input is accepted and leaves the (empty) outputs untouched.
+///
+/// # Panics
+///
+/// Panics if `limbs`, `lo` and `hi` do not all have the same length.
+#[inline(always)]
+pub const fn square_wide(limbs: &[Limb], lo: &mut [Limb], hi: &mut [Limb]) {
+    // Translated from https://github.com/ucbrise/jedi-pairing/blob/c4bf151/include/core/bigint.hpp#L410
+    //
+    // Permission to relicense the resulting translation as Apache 2.0 + MIT was given
+    // by the original author Sam Kumar: https://github.com/RustCrypto/crypto-bigint/pull/133#discussion_r1056870411
+
+    assert!(
+        limbs.len() == lo.len() && lo.len() == hi.len(),
+        "schoolbook squaring length mismatch"
+    );
+
+    // Nothing to square; also keeps `limbs.len() - 1` below from underflowing.
+    if limbs.is_empty() {
+        return;
+    }
+
+    let mut i = 1;
+    while i < limbs.len() {
+        let mut j = 0;
+        let mut carry = Limb::ZERO;
+        let xi = limbs[i];
+
+        while j < i {
+            let k = i + j;
+
+            if k >= limbs.len() {
+                (hi[k - limbs.len()], carry) =
+                    xi.carrying_mul_add(limbs[j], hi[k - limbs.len()], carry);
+            } else {
+                (lo[k], carry) = xi.carrying_mul_add(limbs[j], lo[k], carry);
+            }
+
+            j += 1;
+        }
+
+        if (2 * i) < limbs.len() {
+            lo[2 * i] = carry;
+        } else {
+            hi[2 * i - limbs.len()] = carry;
+        }
+
+        i += 1;
+    }
+
+    // Double the current result, this accounts for the other half of the multiplication grid.
+    // The top word is empty, so we use a special purpose shl.
+    let mut carry = Limb::ZERO;
+    let mut i = 0;
+    while i < limbs.len() {
+        (lo[i].0, carry) = ((lo[i].0 << 1) | carry.0, lo[i].shr(Limb::BITS - 1));
+        i += 1;
+    }
+
+    let mut i = 0;
+    while i < limbs.len() - 1 {
+        (hi[i].0, carry) = ((hi[i].0 << 1) | carry.0, hi[i].shr(Limb::BITS - 1));
+        i += 1;
+    }
+    hi[limbs.len() - 1] = carry;
+
+    // Handle the diagonal of the multiplication grid, which finishes the multiplication grid.
+    let mut carry = Limb::ZERO;
+    let mut i = 0;
+    while i < limbs.len() {
+        let xi = limbs[i];
+        if (i * 2) < limbs.len() {
+            (lo[i * 2], carry) = xi.carrying_mul_add(xi, lo[i * 2], carry);
+        } else {
+            (hi[i * 2 - limbs.len()], carry) =
+                xi.carrying_mul_add(xi, hi[i * 2 - limbs.len()], carry);
+        }
+
+        if (i * 2 + 1) < limbs.len() {
+            (lo[i * 2 + 1], carry) = lo[i * 2 + 1].overflowing_add(carry);
+        } else {
+            (hi[i * 2 + 1 - limbs.len()], carry) =
+                hi[i * 2 + 1 - limbs.len()].overflowing_add(carry);
+        }
+
+        i += 1;
+    }
+}
+
+/// Schoolbook squaring which only calculates the lower limbs of the product.
+///
+/// Computes `limbs * limbs` truncated to `limbs.len()` limbs. Partial products are added onto
+/// `out`, so it should start zeroed.
+///
+/// # Panics
+///
+/// Panics if `out.len() != limbs.len()`.
+#[inline(always)]
+pub const fn wrapping_square(limbs: &[Limb], out: &mut [Limb]) {
+    assert!(
+        limbs.len() == out.len(),
+        "schoolbook wrapping squaring length mismatch"
+    );
+
+    let mut i = 1;
+
+    // Off-diagonal products `limbs[i] * limbs[j]` with `j < i`, each counted once for now.
+    while i < limbs.len() {
+        let mut carry = Limb::ZERO;
+        let xi = limbs[i];
+        let mut k = i;
+
+        while k < 2 * i && k < limbs.len() {
+            (out[k], carry) = xi.carrying_mul_add(limbs[k - i], out[k], carry);
+            k += 1;
+        }
+
+        if k < limbs.len() {
+            out[k] = carry;
+        }
+        i += 1;
+    }
+
+    // Double the current result and fill in the diagonal terms.
+    let mut carry = Limb::ZERO;
+    let mut limb;
+    let mut hi_bit = Limb::ZERO;
+    i = 0;
+    while i < limbs.len() {
+        (limb, hi_bit) = (out[i].shl(1).bitor(hi_bit), out[i].shr(Limb::HI_BIT));
+        // Even limbs also take the diagonal term `limbs[i / 2]^2`; odd limbs only take the carry.
+        (out[i], carry) = if i & 1 == 0 {
+            limbs[i / 2].carrying_mul_add(limbs[i / 2], limb, carry)
+        } else {
+            limb.overflowing_add(carry)
+        };
+        i += 1;
+    }
+}
diff --git a/tests/boxed_uint.rs b/tests/boxed_uint.rs
index b46556e6b..48fcd4d80 100644
--- a/tests/boxed_uint.rs
+++ b/tests/boxed_uint.rs
@@ -207,6 +207,39 @@ proptest! {
         prop_assert_eq!(expected, to_biguint(&actual));
     }

+    #[test]
+    fn widening_square(a in uint()) {
+        let a_bi = to_biguint(&a);
+
+        let expected = a_bi.pow(2);
+        let actual = a.square();
+
+        prop_assert_eq!(expected, to_biguint(&actual));
+    }
+
+    #[test]
+    fn wrapping_mul(a in uint(), b in uint()) {
+        let a_bi = to_biguint(&a);
+        let b_bi = to_biguint(&b);
+
+        let cap = BigUint::from(2u32).pow(a.bits_precision());
+        let expected = (a_bi * b_bi) % cap;
+        let actual = a.wrapping_mul(&b);
+
+        prop_assert_eq!(expected, to_biguint(&actual));
+    }
+
+    #[test]
+    fn wrapping_square(a in uint()) {
+        let a_bi = to_biguint(&a);
+
+        let cap = BigUint::from(2u32).pow(a.bits_precision());
+        let expected = a_bi.pow(2) % cap;
+        let actual = a.wrapping_square();
+
+        prop_assert_eq!(expected, to_biguint(&actual));
+    }
+
     #[test]
     fn rem((a, b) in uint_pair()) {
         if bool::from(!b.is_zero()) {
diff --git a/tests/uint.rs b/tests/uint.rs
index cf94dc3fe..1a71131a1 100644
--- a/tests/uint.rs
+++ b/tests/uint.rs
@@ -244,14 +244,27 @@ proptest!
{
     }

     #[test]
-    fn wrapping_mul(a in uint(), b in uint()) {
+    fn wrapping_mul(a in uint(), b in uint_large()) {
         let a_bi = to_biguint(&a);
         let b_bi = to_biguint(&b);

-        let expected = to_uint(a_bi * b_bi);
-        let actual = a.wrapping_mul(&b);
+        let expected_a = to_uint(&a_bi * &b_bi);
+        let expected_b = to_uint_large(b_bi * a_bi);
+        let actual_a = a.wrapping_mul(&b);
+        let actual_b = b.wrapping_mul(&a);

-        prop_assert_eq!(expected, actual);
+        prop_assert_eq!(expected_a, actual_a);
+        prop_assert_eq!(expected_b, actual_b);
+    }
+
+    #[test]
+    fn wrapping_square(a in uint()) {
+        let a_bi = to_biguint(&a);
+
+        let expected = to_uint(&a_bi * &a_bi);
+        let actual = a.wrapping_square();
+
+        prop_assert_eq!(expected, actual);
     }

     #[test]
@@ -293,7 +306,6 @@ proptest! {
         assert_eq!(expected, actual);
     }

-
     #[test]
     fn square_large(a in uint_large()) {
         let a_bi = to_biguint(&a);