diff --git a/src/const_choice.rs b/src/const_choice.rs index bd070f4bb..5152ff09d 100644 --- a/src/const_choice.rs +++ b/src/const_choice.rs @@ -1,6 +1,6 @@ use subtle::{Choice, CtOption}; -use crate::{modular::BernsteinYangInverter, NonZero, Odd, Uint, Word}; +use crate::{modular::BernsteinYangInverter, Limb, NonZero, Odd, Uint, Word}; /// A boolean value returned by constant-time `const fn`s. // TODO: should be replaced by `subtle::Choice` or `CtOption` @@ -319,6 +319,20 @@ impl ConstCtOption>> { } } +impl ConstCtOption> { + /// Returns the contained value, consuming the `self` value. + /// + /// # Panics + /// + /// Panics if the value is none with a custom panic message provided by + /// `msg`. + #[inline] + pub const fn expect(self, msg: &str) -> NonZero { + assert!(self.is_some.is_true_vartime(), "{}", msg); + self.value + } +} + impl ConstCtOption> { diff --git a/src/modular/boxed_monty_form.rs b/src/modular/boxed_monty_form.rs index c3395f24b..fec97acf1 100644 --- a/src/modular/boxed_monty_form.rs +++ b/src/modular/boxed_monty_form.rs @@ -1,4 +1,3 @@ -//! Implements `BoxedMontyForm`s, supporting modular arithmetic with a modulus whose size and value //! is chosen at runtime. mod add; @@ -9,6 +8,7 @@ mod pow; mod sub; use super::{ + div_by_2, reduction::{montgomery_reduction_boxed, montgomery_reduction_boxed_mut}, Retrieve, }; @@ -219,6 +219,18 @@ impl BoxedMontyForm { debug_assert!(self.montgomery_form < self.params.modulus); self.montgomery_form.clone() } + + /// Performs the modular division by 2, that is for given `x` returns `y` + /// such that `y * 2 = x mod p`. This means: + /// - if `x` is even, returns `x / 2`, + /// - if `x` is odd, returns `(x + p) / 2` + /// (since the modulus `p` in Montgomery form is always odd, this divides entirely). + pub fn div_by_2(&self) -> Self { + Self { + montgomery_form: div_by_2::div_by_2_boxed(&self.montgomery_form, &self.params.modulus), + params: self.params.clone(), + } + } } impl Retrieve for BoxedMontyForm { @@ -258,3 +270,26 @@ fn convert_to_montgomery(integer: &mut BoxedUint, params: &BoxedMontyParams) { #[cfg(feature = "zeroize")] product.zeroize(); } + +#[cfg(test)] +mod tests { + use super::{BoxedMontyForm, BoxedMontyParams, BoxedUint, Odd}; + + #[test] + fn new_params_with_valid_modulus() { + let modulus = Odd::new(BoxedUint::from(3u8)).unwrap(); + BoxedMontyParams::new(modulus); + } + + #[test] + fn div_by_2() { + let modulus = Odd::new(BoxedUint::from(9u8)).unwrap(); + let params = BoxedMontyParams::new(modulus); + let zero = BoxedMontyForm::zero(params.clone()); + let one = BoxedMontyForm::one(params.clone()); + let two = one.add(&one); + + assert_eq!(zero.div_by_2(), zero); + assert_eq!(one.div_by_2().mul(&two), one); + } +} diff --git a/src/modular/div_by_2.rs b/src/modular/div_by_2.rs index 1ad53b2a6..71c646c59 100644 --- a/src/modular/div_by_2.rs +++ b/src/modular/div_by_2.rs @@ -1,4 +1,6 @@ use crate::Uint; +#[cfg(feature = "alloc")] +use crate::{BoxedUint, ConstantTimeSelect}; pub(crate) fn div_by_2(a: &Uint, modulus: &Uint) -> Uint { // We are looking for such `x` that `x * 2 = y mod modulus`, @@ -28,3 +30,19 @@ pub(crate) fn div_by_2(a: &Uint, modulus: &Uint::select(&if_even, &if_odd, is_odd) } + +#[cfg(feature = "alloc")] +pub(crate) fn div_by_2_boxed(a: &BoxedUint, modulus: &BoxedUint) -> BoxedUint { + debug_assert_eq!(a.bits_precision(), modulus.bits_precision()); + + let (mut half, is_odd) = a.shr1_with_carry(); + let half_modulus = modulus.shr1(); + + let if_odd = half + .wrapping_add(&half_modulus) + .wrapping_add(&BoxedUint::one_with_precision(a.bits_precision())); + + half.ct_assign(&if_odd, is_odd); + + half +} diff --git a/src/odd.rs b/src/odd.rs index 77e3f619c..608a0d807 100644 --- a/src/odd.rs +++ b/src/odd.rs @@ -143,3 +143,27 @@ impl Odd { Odd(ret) } } + +#[cfg(test)] +mod tests { + #[cfg(feature = "alloc")] + use super::BoxedUint; + use super::{Odd, Uint}; + + #[test] + fn not_odd_numbers() { + let zero = Odd::new(Uint::<4>::ZERO); + assert!(bool::from(zero.is_none())); + let two = Odd::new(Uint::<4>::from(2u8)); + assert!(bool::from(two.is_none())); + } + + #[cfg(feature = "alloc")] + #[test] + fn not_odd_numbers_boxed() { + let zero = Odd::new(BoxedUint::zero()); + assert!(bool::from(zero.is_none())); + let two = Odd::new(BoxedUint::from(2u8)); + assert!(bool::from(two.is_none())); + } +} diff --git a/src/uint/boxed.rs b/src/uint/boxed.rs index dc210ba0f..ddcc5fa7b 100644 --- a/src/uint/boxed.rs +++ b/src/uint/boxed.rs @@ -10,6 +10,7 @@ mod bits; mod cmp; mod ct; mod div; +mod div_limb; pub(crate) mod encoding; mod from; mod inv_mod; @@ -19,6 +20,7 @@ mod neg; mod neg_mod; mod shl; mod shr; +mod sqrt; mod sub; mod sub_mod; @@ -92,6 +94,14 @@ impl BoxedUint { .fold(Choice::from(1), |acc, limb| acc & limb.is_zero()) } + /// Is this [`BoxedUint`] not equal to zero? + pub fn is_nonzero(&self) -> Choice { + // TODO: why not just !self.is_zero()? + self.limbs + .iter() + .fold(Choice::from(0), |acc, limb| acc | limb.is_nonzero().into()) + } + /// Is this [`BoxedUint`] equal to one? pub fn is_one(&self) -> Choice { let mut iter = self.limbs.iter(); diff --git a/src/uint/boxed/bits.rs b/src/uint/boxed/bits.rs index e60ebae09..f59d52a8d 100644 --- a/src/uint/boxed/bits.rs +++ b/src/uint/boxed/bits.rs @@ -1,6 +1,6 @@ //! Bit manipulation functions. -use crate::{BoxedUint, Limb, Zero}; +use crate::{BoxedUint, ConstChoice, Limb, Zero}; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; impl BoxedUint { @@ -24,6 +24,11 @@ impl BoxedUint { Limb::BITS * n - leading_zeros } + /// `floor(log2(self.bits_precision()))`. + pub(crate) fn log2_bits(&self) -> u32 { + u32::BITS - self.bits_precision().leading_zeros() - 1 + } + /// Calculate the number of bits needed to represent this number in variable-time with respect /// to `self`. pub fn bits_vartime(&self) -> u32 { @@ -36,6 +41,18 @@ impl BoxedUint { Limb::BITS * (i as u32 + 1) - limb.leading_zeros() } + /// Returns `true` if the bit at position `index` is set, `false` otherwise. + /// + /// # Remarks + /// This operation is variable time with respect to `index` only. + pub fn bit_vartime(&self, index: u32) -> bool { + if index >= self.bits_precision() { + false + } else { + (self.limbs[(index / Limb::BITS) as usize].0 >> (index % Limb::BITS)) & 1 == 1 + } + } + /// Get the precision of this [`BoxedUint`] in bits. pub fn bits_precision(&self) -> u32 { self.limbs.len() as u32 * Limb::BITS @@ -55,6 +72,45 @@ impl BoxedUint { count } + /// Calculate the number of trailing ones in the binary representation of this number. + pub fn trailing_ones(&self) -> u32 { + let limbs = self.as_limbs(); + + let mut count = 0; + let mut i = 0; + let mut nonmax_limb_not_encountered = ConstChoice::TRUE; + while i < limbs.len() { + let l = limbs[i]; + let z = l.trailing_ones(); + count += nonmax_limb_not_encountered.if_true_u32(z); + nonmax_limb_not_encountered = + nonmax_limb_not_encountered.and(ConstChoice::from_word_eq(l.0, Limb::MAX.0)); + i += 1; + } + + count + } + + /// Calculate the number of trailing ones in the binary representation of this number, + /// variable time in `self`. + pub fn trailing_ones_vartime(&self) -> u32 { + let limbs = self.as_limbs(); + + let mut count = 0; + let mut i = 0; + while i < limbs.len() { + let l = limbs[i]; + let z = l.trailing_ones(); + count += z; + if z != Limb::BITS { + break; + } + i += 1; + } + + count + } + /// Sets the bit at `index` to 0 or 1 depending on the value of `bit_value`. pub(crate) fn set_bit(&mut self, index: u32, bit_value: Choice) { let limb_num = (index / Limb::BITS) as usize; @@ -89,6 +145,18 @@ mod tests { result } + #[test] + fn bit_vartime() { + let u = uint_with_bits_at(&[16, 48, 112, 127, 255]); + assert!(!u.bit_vartime(0)); + assert!(!u.bit_vartime(1)); + assert!(u.bit_vartime(16)); + assert!(u.bit_vartime(127)); + assert!(u.bit_vartime(255)); + assert!(!u.bit_vartime(256)); + assert!(!u.bit_vartime(260)); + } + #[test] fn bits() { assert_eq!(0, BoxedUint::zero().bits()); @@ -119,4 +187,40 @@ mod tests { u.set_bit(150, Choice::from(0)); assert_eq!(u, uint_with_bits_at(&[16, 79])); } + + #[test] + fn trailing_ones() { + let u = !uint_with_bits_at(&[16, 79, 150]); + assert_eq!(u.trailing_ones(), 16); + + let u = !uint_with_bits_at(&[79, 150]); + assert_eq!(u.trailing_ones(), 79); + + let u = !uint_with_bits_at(&[150, 207]); + assert_eq!(u.trailing_ones(), 150); + + let u = !uint_with_bits_at(&[0, 150, 207]); + assert_eq!(u.trailing_ones(), 0); + + let u = !BoxedUint::zero_with_precision(256); + assert_eq!(u.trailing_ones(), 256); + } + + #[test] + fn trailing_ones_vartime() { + let u = !uint_with_bits_at(&[16, 79, 150]); + assert_eq!(u.trailing_ones_vartime(), 16); + + let u = !uint_with_bits_at(&[79, 150]); + assert_eq!(u.trailing_ones_vartime(), 79); + + let u = !uint_with_bits_at(&[150, 207]); + assert_eq!(u.trailing_ones_vartime(), 150); + + let u = !uint_with_bits_at(&[0, 150, 207]); + assert_eq!(u.trailing_ones_vartime(), 0); + + let u = !BoxedUint::zero_with_precision(256); + assert_eq!(u.trailing_ones_vartime(), 256); + } } diff --git a/src/uint/boxed/cmp.rs b/src/uint/boxed/cmp.rs index 14bef8d1f..58187a70b 100644 --- a/src/uint/boxed/cmp.rs +++ b/src/uint/boxed/cmp.rs @@ -10,6 +10,30 @@ use subtle::{ Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess, }; +impl BoxedUint { + /// Returns the Ordering between `self` and `rhs` in variable time. + pub fn cmp_vartime(&self, rhs: &Self) -> Ordering { + debug_assert_eq!(self.limbs.len(), rhs.limbs.len()); + let mut i = self.limbs.len() - 1; + loop { + // TODO: investigate if directly comparing limbs is faster than performing a + // subtraction between limbs + let (val, borrow) = self.limbs[i].sbb(rhs.limbs[i], Limb::ZERO); + if val.0 != 0 { + return if borrow.0 != 0 { + Ordering::Less + } else { + Ordering::Greater + }; + } + if i == 0 { + return Ordering::Equal; + } + i -= 1; + } + } +} + impl ConstantTimeEq for BoxedUint { #[inline] fn ct_eq(&self, other: &Self) -> Choice { diff --git a/src/uint/boxed/div.rs b/src/uint/boxed/div.rs index 1cf5ac8a4..d230fc1d3 100644 --- a/src/uint/boxed/div.rs +++ b/src/uint/boxed/div.rs @@ -1,10 +1,23 @@ //! [`BoxedUint`] division operations. -use crate::{BoxedUint, CheckedDiv, ConstantTimeSelect, Limb, NonZero, Wrapping}; +use crate::{ + uint::boxed, BoxedUint, CheckedDiv, ConstantTimeSelect, Limb, NonZero, Reciprocal, Wrapping, +}; use core::ops::{Div, DivAssign, Rem, RemAssign}; use subtle::{Choice, ConstantTimeEq, ConstantTimeLess, CtOption}; impl BoxedUint { + /// Computes `self` / `rhs` using a pre-made reciprocal, + /// returns the quotient (q) and remainder (r). + pub fn div_rem_limb_with_reciprocal(&self, reciprocal: &Reciprocal) -> (Self, Limb) { + boxed::div_limb::div_rem_limb_with_reciprocal(self, reciprocal) + } + + /// Computes `self` / `rhs`, returns the quotient (q) and remainder (r). + pub fn div_rem_limb(&self, rhs: NonZero) -> (Self, Limb) { + boxed::div_limb::div_rem_limb_with_reciprocal(self, &Reciprocal::new(rhs)) + } + /// Computes self / rhs, returns the quotient, remainder. pub fn div_rem(&self, rhs: &NonZero) -> (Self, Self) { // Since `rhs` is nonzero, this should always hold. @@ -61,6 +74,14 @@ impl BoxedUint { self.div_rem(rhs).0 } + /// Wrapped division is just normal division i.e. `self` / `rhs` + /// + /// There’s no way wrapping could ever happen. + /// This function exists, so that all operations are accounted for in the wrapping operations + pub fn wrapping_div_vartime(&self, rhs: &NonZero) -> Self { + self.div_rem_vartime(rhs).0 + } + /// Perform checked division, returning a [`CtOption`] which `is_some` /// only if the rhs != 0 pub fn checked_div(&self, rhs: &Self) -> CtOption { diff --git a/src/uint/boxed/div_limb.rs b/src/uint/boxed/div_limb.rs new file mode 100644 index 000000000..4e4c7eeb4 --- /dev/null +++ b/src/uint/boxed/div_limb.rs @@ -0,0 +1,25 @@ +//! Implementation of constant-time division via reciprocal precomputation, as described in +//! "Improved Division by Invariant Integers" by Niels Möller and Torbjorn Granlund +//! (DOI: 10.1109/TC.2010.143, ). +use crate::{uint::div_limb::div2by1, BoxedUint, Limb, Reciprocal}; + +/// Divides `u` by the divisor encoded in the `reciprocal`, and returns +/// the quotient and the remainder. +#[inline(always)] +pub(crate) fn div_rem_limb_with_reciprocal( + u: &BoxedUint, + reciprocal: &Reciprocal, +) -> (BoxedUint, Limb) { + let (u_shifted, u_hi) = u.shl_limb(reciprocal.shift()); + let mut r = u_hi.0; + let mut q = vec![Limb::ZERO; u.limbs.len()]; + + let mut j = u.limbs.len(); + while j > 0 { + j -= 1; + let (qj, rj) = div2by1(r, u_shifted.as_limbs()[j].0, reciprocal); + q[j] = Limb(qj); + r = rj; + } + (BoxedUint { limbs: q.into() }, Limb(r >> reciprocal.shift())) +} diff --git a/src/uint/boxed/encoding.rs b/src/uint/boxed/encoding.rs index a5bd54788..1baae6adf 100644 --- a/src/uint/boxed/encoding.rs +++ b/src/uint/boxed/encoding.rs @@ -1,9 +1,10 @@ //! Const-friendly decoding operations for [`BoxedUint`]. use super::BoxedUint; -use crate::Limb; +use crate::{uint::encoding, Limb, Word}; use alloc::boxed::Box; use core::fmt; +use subtle::{Choice, CtOption}; /// Decoding errors for [`BoxedUint`]. #[derive(Clone, Copy, Debug, Eq, PartialEq)] @@ -131,6 +132,37 @@ impl BoxedUint { out.into() } + + /// Create a new [`BoxedUint`] from the provided big endian hex string. + pub fn from_be_hex(hex: &str, bits_precision: u32) -> CtOption { + let nlimbs = (bits_precision / Limb::BITS) as usize; + let bytes = hex.as_bytes(); + + assert!( + bytes.len() == Limb::BYTES * nlimbs * 2, + "hex string is not the expected size" + ); + + let mut res = vec![Limb::ZERO; nlimbs]; + let mut buf = [0u8; Limb::BYTES]; + let mut i = 0; + let mut err = 0; + + while i < nlimbs { + let mut j = 0; + while j < Limb::BYTES { + let offset = (i * Limb::BYTES + j) * 2; + let (result, byte_err) = + encoding::decode_hex_byte([bytes[offset], bytes[offset + 1]]); + err |= byte_err; + buf[j] = result; + j += 1; + } + res[nlimbs - i - 1] = Limb(Word::from_be_bytes(buf)); + i += 1; + } + CtOption::new(Self { limbs: res.into() }, Choice::from((err == 0) as u8)) + } } #[cfg(test)] @@ -158,6 +190,17 @@ mod tests { ); } + #[test] + #[cfg(target_pointer_width = "64")] + fn from_be_hex_eq() { + let hex = "00112233445566778899aabbccddeeff"; + let n = BoxedUint::from_be_hex(hex, 128).unwrap(); + assert_eq!( + n.as_limbs(), + &[Limb(0x8899aabbccddeeff), Limb(0x0011223344556677)] + ); + } + #[test] #[cfg(target_pointer_width = "32")] fn from_be_slice_short() { diff --git a/src/uint/boxed/shl.rs b/src/uint/boxed/shl.rs index 6b7ec99b6..44234d4dd 100644 --- a/src/uint/boxed/shl.rs +++ b/src/uint/boxed/shl.rs @@ -1,6 +1,6 @@ //! [`BoxedUint`] bitwise left shift operations. -use crate::{BoxedUint, ConstantTimeSelect, Limb, WrappingShl, Zero}; +use crate::{BoxedUint, ConstChoice, ConstantTimeSelect, Limb, Word, WrappingShl, Zero}; use core::ops::{Shl, ShlAssign}; use subtle::{Choice, ConstantTimeLess}; @@ -142,6 +142,38 @@ impl BoxedUint { carry = new_carry } } + + /// Computes `self << shift` where `0 <= shift < Limb::BITS`, + /// returning the result and the carry. + pub(crate) fn shl_limb(&self, shift: u32) -> (Self, Limb) { + let mut limbs = vec![Limb::ZERO; self.limbs.len()]; + + let nz = ConstChoice::from_u32_nonzero(shift); + let lshift = shift; + let rshift = nz.if_true_u32(Limb::BITS - shift); + let carry = nz.if_true_word( + self.limbs[self.limbs.len() - 1] + .0 + .wrapping_shr(Word::BITS - shift), + ); + + limbs[0] = Limb(self.limbs[0].0 << lshift); + let mut i = 1; + while i < self.limbs.len() { + let mut limb = self.limbs[i].0 << lshift; + let hi = self.limbs[i - 1].0 >> rshift; + limb |= nz.if_true_word(hi); + limbs[i] = Limb(limb); + i += 1 + } + + ( + BoxedUint { + limbs: limbs.into(), + }, + Limb(carry), + ) + } } macro_rules! impl_shl { diff --git a/src/uint/boxed/sqrt.rs b/src/uint/boxed/sqrt.rs new file mode 100644 index 000000000..178d506ec --- /dev/null +++ b/src/uint/boxed/sqrt.rs @@ -0,0 +1,303 @@ +//! [`BoxedUint`] square root operations. + +use subtle::{ConstantTimeEq, ConstantTimeGreater, CtOption}; + +use crate::{BoxedUint, ConstantTimeSelect, NonZero}; + +impl BoxedUint { + /// Computes √(`self`) in constant time. + /// + /// Callers can check if `self` is a square by squaring the result + pub fn sqrt(&self) -> Self { + // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13. + // + // See Hast, "Note on computation of integer square roots" + // for the proof of the sufficiency of the bound on iterations. + // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf + + // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. + // Will not overflow since `b <= BITS`. + let (mut x, _overflow) = + Self::one_with_precision(self.bits_precision()).overflowing_shl((self.bits() + 1) >> 1); // ≥ √(`self`) + + // Repeat enough times to guarantee result has stabilized. + let mut i = 0; + // TODO: avoid this clone + let mut x_prev = x.clone(); // keep the previous iteration in case we need to roll back. + + // TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough. + while i < self.log2_bits() + 2 { + x_prev = x.clone(); // TODO: can we avoid this clone? + + // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)` + + let (nz_x, is_nonzero) = (NonZero(x.clone()), x.is_nonzero()); + let (q, _) = self.div_rem(&nz_x); + + // A protection in case `self == 0`, which will make `x == 0` + let q = Self::ct_select( + &Self::zero_with_precision(self.bits_precision()), + &q, + is_nonzero, + ); + + x = x.wrapping_add(&q).shr1(); + i += 1; + } + + // At this point `x_prev == x_{n}` and `x == x_{n+1}` + // where `n == i - 1 == LOG2_BITS + 1 == floor(log2(BITS)) + 1`. + // Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`. + Self::ct_select(&x_prev, &x, Self::ct_gt(&x_prev, &x)) + } + + /// Computes √(`self`) + /// + /// Callers can check if `self` is a square by squaring the result + pub fn sqrt_vartime(&self) -> Self { + // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 + + // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. + // Will not overflow since `b <= BITS`. + let (mut x, _overflow) = + Self::one_with_precision(self.bits_precision()).overflowing_shl((self.bits() + 1) >> 1); // ≥ √(`self`) + + // Stop right away if `x` is zero to avoid divizion by zero. + while !x + .cmp_vartime(&Self::zero_with_precision(self.bits_precision())) + .is_eq() + { + // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)` + let q = + self.wrapping_div_vartime(&NonZero::::new(x.clone()).expect("Division by 0")); + let t = x.wrapping_add(&q); + let next_x = t.shr1(); + + // If `next_x` is the same as `x` or greater, we reached convergence + // (`x` is guaranteed to either go down or oscillate between + // `sqrt(self)` and `sqrt(self) + 1`) + if !x.cmp_vartime(&next_x).is_gt() { + break; + } + + x = next_x; + } + + if self.is_nonzero().into() { + x + } else { + Self::zero_with_precision(self.bits_precision()) + } + } + + /// Wrapped sqrt is just normal √(`self`) + /// There’s no way wrapping could ever happen. + /// This function exists so that all operations are accounted for in the wrapping operations. + pub fn wrapping_sqrt(&self) -> Self { + self.sqrt() + } + + /// Wrapped sqrt is just normal √(`self`) + /// There’s no way wrapping could ever happen. + /// This function exists so that all operations are accounted for in the wrapping operations. + pub fn wrapping_sqrt_vartime(&self) -> Self { + self.sqrt_vartime() + } + + /// Perform checked sqrt, returning a [`CtOption`] which `is_some` + /// only if the √(`self`)² == self + pub fn checked_sqrt(&self) -> CtOption { + let r = self.sqrt(); + let s = r.wrapping_mul(&r); + CtOption::new(r, ConstantTimeEq::ct_eq(self, &s)) + } + + /// Perform checked sqrt, returning a [`CtOption`] which `is_some` + /// only if the √(`self`)² == self + pub fn checked_sqrt_vartime(&self) -> CtOption { + let r = self.sqrt_vartime(); + let s = r.wrapping_mul(&r); + CtOption::new(r, ConstantTimeEq::ct_eq(self, &s)) + } +} + +#[cfg(test)] +mod tests { + use crate::{BoxedUint, Limb}; + + #[cfg(feature = "rand")] + use { + crate::CheckedMul, + rand_chacha::ChaChaRng, + rand_core::{RngCore, SeedableRng}, + }; + + #[test] + fn edge() { + assert_eq!( + BoxedUint::zero_with_precision(256).sqrt(), + BoxedUint::zero_with_precision(256) + ); + assert_eq!( + BoxedUint::one_with_precision(256).sqrt(), + BoxedUint::one_with_precision(256) + ); + let mut half = BoxedUint::zero_with_precision(256); + for i in 0..half.limbs.len() / 2 { + half.limbs[i] = Limb::MAX; + } + let u256_max = !BoxedUint::zero_with_precision(256); + assert_eq!(u256_max.sqrt(), half); + + // Test edge cases that use up the maximum number of iterations. + + // `x = (r + 1)^2 - 583`, where `r` is the expected square root. + assert_eq!( + BoxedUint::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d", 192) + .unwrap() + .sqrt(), + BoxedUint::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21", 192) + .unwrap(), + ); + assert_eq!( + BoxedUint::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d", 192) + .unwrap() + .sqrt_vartime(), + BoxedUint::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21", 192) + .unwrap() + ); + + // `x = (r + 1)^2 - 205`, where `r` is the expected square root. + assert_eq!( + BoxedUint::from_be_hex( + "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597", + 256 + ) + .unwrap() + .sqrt(), + BoxedUint::from_be_hex( + "000000000000000000000000000000008b3956339e8315cff66eb6107b610075", + 256 + ) + .unwrap() + ); + assert_eq!( + BoxedUint::from_be_hex( + "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597", + 256 + ) + .unwrap() + .sqrt_vartime(), + BoxedUint::from_be_hex( + "000000000000000000000000000000008b3956339e8315cff66eb6107b610075", + 256 + ) + .unwrap() + ); + } + + #[test] + fn edge_vartime() { + assert_eq!( + BoxedUint::zero_with_precision(256).sqrt_vartime(), + BoxedUint::zero_with_precision(256) + ); + assert_eq!( + BoxedUint::one_with_precision(256).sqrt_vartime(), + BoxedUint::one_with_precision(256) + ); + let mut half = BoxedUint::zero_with_precision(256); + for i in 0..half.limbs.len() / 2 { + half.limbs[i] = Limb::MAX; + } + let u256_max = !BoxedUint::zero_with_precision(256); + assert_eq!(u256_max.sqrt_vartime(), half); + } + + #[test] + fn simple() { + let tests = [ + (4u8, 2u8), + (9, 3), + (16, 4), + (25, 5), + (36, 6), + (49, 7), + (64, 8), + (81, 9), + (100, 10), + (121, 11), + (144, 12), + (169, 13), + ]; + for (a, e) in &tests { + let l = BoxedUint::from(*a); + let r = BoxedUint::from(*e); + assert_eq!(l.sqrt(), r); + assert_eq!(l.sqrt_vartime(), r); + assert_eq!(l.checked_sqrt().is_some().unwrap_u8(), 1u8); + assert_eq!(l.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8); + } + } + + #[test] + fn nonsquares() { + assert_eq!(BoxedUint::from(2u8).sqrt(), BoxedUint::from(1u8)); + assert_eq!(BoxedUint::from(2u8).checked_sqrt().is_some().unwrap_u8(), 0); + assert_eq!(BoxedUint::from(3u8).sqrt(), BoxedUint::from(1u8)); + assert_eq!(BoxedUint::from(3u8).checked_sqrt().is_some().unwrap_u8(), 0); + assert_eq!(BoxedUint::from(5u8).sqrt(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(6u8).sqrt(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(7u8).sqrt(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(8u8).sqrt(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(10u8).sqrt(), BoxedUint::from(3u8)); + } + + #[test] + fn nonsquares_vartime() { + assert_eq!(BoxedUint::from(2u8).sqrt_vartime(), BoxedUint::from(1u8)); + assert_eq!( + BoxedUint::from(2u8) + .checked_sqrt_vartime() + .is_some() + .unwrap_u8(), + 0 + ); + assert_eq!(BoxedUint::from(3u8).sqrt_vartime(), BoxedUint::from(1u8)); + assert_eq!( + BoxedUint::from(3u8) + .checked_sqrt_vartime() + .is_some() + .unwrap_u8(), + 0 + ); + assert_eq!(BoxedUint::from(5u8).sqrt_vartime(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(6u8).sqrt_vartime(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(7u8).sqrt_vartime(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(8u8).sqrt_vartime(), BoxedUint::from(2u8)); + assert_eq!(BoxedUint::from(10u8).sqrt_vartime(), BoxedUint::from(3u8)); + } + + #[cfg(feature = "rand")] + #[test] + fn fuzz() { + let mut rng = ChaChaRng::from_seed([7u8; 32]); + for _ in 0..50 { + let t = rng.next_u32() as u64; + let s = BoxedUint::from(t); + let s2 = s.checked_mul(&s).unwrap(); + assert_eq!(s2.sqrt(), s); + assert_eq!(s2.sqrt_vartime(), s); + assert_eq!(s2.checked_sqrt().is_some().unwrap_u8(), 1); + assert_eq!(s2.checked_sqrt_vartime().is_some().unwrap_u8(), 1); + } + + for _ in 0..50 { + let s = BoxedUint::random(&mut rng, 512); + let mut s2 = BoxedUint::zero_with_precision(512); + s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs); + assert_eq!(s.square().sqrt(), s2); + assert_eq!(s.square().sqrt_vartime(), s2); + } + } +} diff --git a/src/uint/div_limb.rs b/src/uint/div_limb.rs index f8d47f8bc..415b5ebe1 100644 --- a/src/uint/div_limb.rs +++ b/src/uint/div_limb.rs @@ -74,6 +74,7 @@ pub const fn reciprocal(d: Word) -> Word { /// Returns `u32::MAX` if `a < b` and `0` otherwise. #[inline] const fn lt(a: u32, b: u32) -> u32 { + // TODO: Move to using ConstChoice::le let bit = (((!a) & b) | (((!a) | b) & (a.wrapping_sub(b)))) >> (u32::BITS - 1); bit.wrapping_neg() } @@ -81,6 +82,7 @@ const fn lt(a: u32, b: u32) -> u32 { /// Returns `a` if `c == 0` and `b` if `c == u32::MAX`. #[inline(always)] const fn select(a: u32, b: u32, c: u32) -> u32 { + // TODO: Move to using ConstChoice::select a ^ (c & (a ^ b)) } @@ -117,7 +119,7 @@ const fn short_div(dividend: u32, dividend_bits: u32, divisor: u32, divisor_bits /// Calculate the quotient and the remainder of the division of a wide word /// (supplied as high and low words) by `d`, with a precalculated reciprocal `v`. #[inline(always)] -const fn div2by1(u1: Word, u0: Word, reciprocal: &Reciprocal) -> (Word, Word) { +pub(crate) const fn div2by1(u1: Word, u0: Word, reciprocal: &Reciprocal) -> (Word, Word) { let d = reciprocal.divisor_normalized; debug_assert!(d >= (1 << (Word::BITS - 1))); @@ -184,6 +186,11 @@ impl Reciprocal { reciprocal: 1, } } + + /// Get the shift value + pub const fn shift(&self) -> u32 { + self.shift + } } impl ConditionallySelectable for Reciprocal { diff --git a/src/uint/encoding.rs b/src/uint/encoding.rs index b51b33c3a..2fcf7001c 100644 --- a/src/uint/encoding.rs +++ b/src/uint/encoding.rs @@ -240,7 +240,7 @@ const fn decode_nibble(src: u8) -> u16 { /// Second element of the tuple is non-zero if the `bytes` values are not in the valid range /// (0-9, a-z, A-Z). #[inline(always)] -const fn decode_hex_byte(bytes: [u8; 2]) -> (u8, u16) { +pub(crate) const fn decode_hex_byte(bytes: [u8; 2]) -> (u8, u16) { let hi = decode_nibble(bytes[0]); let lo = decode_nibble(bytes[1]); let byte = (hi << 4) | lo;