diff --git a/benches/uint.rs b/benches/uint.rs index 1d154f718..93a6ffe5e 100644 --- a/benches/uint.rs +++ b/benches/uint.rs @@ -160,8 +160,8 @@ fn bench_inv_mod(c: &mut Criterion) { let m = U256::random(&mut OsRng) | U256::ONE; loop { let x = U256::random(&mut OsRng); - let (_, is_some) = x.inv_odd_mod(&m); - if is_some.into() { + let inv_x = x.inv_odd_mod(&m); + if inv_x.is_some().into() { break (x, m); } } @@ -177,8 +177,8 @@ fn bench_inv_mod(c: &mut Criterion) { let m = U256::random(&mut OsRng) | U256::ONE; loop { let x = U256::random(&mut OsRng); - let (_, is_some) = x.inv_odd_mod(&m); - if is_some.into() { + let inv_x = x.inv_odd_mod(&m); + if inv_x.is_some().into() { break (x, m); } } @@ -194,8 +194,8 @@ fn bench_inv_mod(c: &mut Criterion) { let m = U256::random(&mut OsRng); loop { let x = U256::random(&mut OsRng); - let (_, is_some) = black_box(x.inv_mod(&m)); - if is_some.into() { + let inv_x = x.inv_mod(&m); + if inv_x.is_some().into() { break (x, m); } } @@ -208,6 +208,25 @@ fn bench_inv_mod(c: &mut Criterion) { group.finish(); } -criterion_group!(benches, bench_shl, bench_shr, bench_division, bench_inv_mod); +fn bench_sqrt(c: &mut Criterion) { + let mut group = c.benchmark_group("sqrt"); + + group.bench_function("sqrt, U256", |b| { + b.iter_batched( + || U256::random(&mut OsRng), + |x| x.sqrt(), + BatchSize::SmallInput, + ) + }); +} + +criterion_group!( + benches, + bench_shl, + bench_shr, + bench_division, + bench_inv_mod, + bench_sqrt +); criterion_main!(benches); diff --git a/src/ct_choice.rs b/src/const_choice.rs similarity index 53% rename from src/ct_choice.rs rename to src/const_choice.rs index 40c49dee7..8fc778964 100644 --- a/src/ct_choice.rs +++ b/src/const_choice.rs @@ -1,6 +1,6 @@ -use subtle::Choice; +use subtle::{Choice, CtOption}; -use crate::Word; +use crate::{NonZero, Uint, Word}; /// A boolean value returned by constant-time `const fn`s. // TODO: should be replaced by `subtle::Choice` or `CtOption` @@ -71,10 +71,19 @@ impl ConstChoice { /// Returns the truthy value if `x < y`, and the falsy value otherwise. #[inline] pub(crate) const fn from_word_lt(x: Word, y: Word) -> Self { + // See "Hacker's Delight" 2nd ed, section 2-12 (Comparison predicates) let bit = (((!x) & y) | (((!x) | y) & (x.wrapping_sub(y)))) >> (Word::BITS - 1); Self::from_word_lsb(bit) } + /// Returns the truthy value if `x > y`, and the falsy value otherwise. + #[inline] + pub(crate) const fn from_word_gt(x: Word, y: Word) -> Self { + // See "Hacker's Delight" 2nd ed, section 2-12 (Comparison predicates) + let bit = (((!y) & x) | (((!y) | x) & (y.wrapping_sub(x)))) >> (Word::BITS - 1); + Self::from_word_lsb(bit) + } + /// Returns the truthy value if `x < y`, and the falsy value otherwise. #[inline] pub(crate) const fn from_u32_lt(x: u32, y: u32) -> Self { @@ -147,6 +156,7 @@ impl ConstChoice { } impl From for Choice { + #[inline] fn from(choice: ConstChoice) -> Self { Choice::from(choice.to_u8()) } @@ -164,11 +174,144 @@ impl PartialEq for ConstChoice { } } +/// An equivalent of `subtle::CtOption` usable in a `const fn` context. +#[derive(Debug, Clone)] +pub struct ConstCtOption { + value: T, + is_some: ConstChoice, +} + +impl ConstCtOption { + #[inline] + pub(crate) const fn new(value: T, is_some: ConstChoice) -> Self { + Self { value, is_some } + } + + #[inline] + pub(crate) const fn some(value: T) -> Self { + Self { + value, + is_some: ConstChoice::TRUE, + } + } + + #[inline] + pub(crate) const fn none(dummy_value: T) -> Self { + Self { + value: dummy_value, + is_some: ConstChoice::FALSE, + } + } + + /// Returns a reference to the contents of this structure. + /// + /// **Note:** if the second element is `None`, the first value may take any value. + #[inline] + pub(crate) const fn components_ref(&self) -> (&T, ConstChoice) { + // Since Rust is not smart enough to tell that we would be moving the value, + // and hence no destructors will be called, we have to return a reference instead. + // See https://github.com/rust-lang/rust/issues/66753 + (&self.value, self.is_some) + } + + /// Returns a true [`ConstChoice`] if this value is `Some`. + #[inline] + pub const fn is_some(&self) -> ConstChoice { + self.is_some + } + + /// Returns a true [`ConstChoice`] if this value is `None`. + #[inline] + pub const fn is_none(&self) -> ConstChoice { + self.is_some.not() + } + + /// This returns the underlying value but panics if it is not `Some`. + #[inline] + pub fn unwrap(self) -> T { + assert!(self.is_some.is_true_vartime()); + self.value + } +} + +impl From> for CtOption { + #[inline] + fn from(value: ConstCtOption) -> Self { + CtOption::new(value.value, value.is_some.into()) + } +} + +// Need specific implementations to work around the +// "destructors cannot be evaluated at compile-time" error +// See https://github.com/rust-lang/rust/issues/66753 + +impl ConstCtOption> { + /// This returns the underlying value if it is `Some` or the provided value otherwise. + #[inline] + pub const fn unwrap_or(self, def: Uint) -> Uint { + Uint::select(&def, &self.value, self.is_some) + } + + /// Returns the contained value, consuming the `self` value. + /// + /// # Panics + /// + /// Panics if the value is none with a custom panic message provided by + /// `msg`. + #[inline] + pub const fn expect(self, msg: &str) -> Uint { + assert!(self.is_some.is_true_vartime(), "{}", msg); + self.value + } +} + +impl ConstCtOption<(Uint, Uint)> { + /// Returns the contained value, consuming the `self` value. + /// + /// # Panics + /// + /// Panics if the value is none with a custom panic message provided by + /// `msg`. + #[inline] + pub const fn expect(self, msg: &str) -> (Uint, Uint) { + assert!(self.is_some.is_true_vartime(), "{}", msg); + self.value + } +} + +impl ConstCtOption>> { + /// Returns the contained value, consuming the `self` value. + /// + /// # Panics + /// + /// Panics if the value is none with a custom panic message provided by + /// `msg`. + #[inline] + pub const fn expect(self, msg: &str) -> NonZero> { + assert!(self.is_some.is_true_vartime(), "{}", msg); + self.value + } +} + #[cfg(test)] mod tests { use super::ConstChoice; use crate::Word; + #[test] + fn from_word_lt() { + assert_eq!(ConstChoice::from_word_lt(4, 5), ConstChoice::TRUE); + assert_eq!(ConstChoice::from_word_lt(5, 5), ConstChoice::FALSE); + assert_eq!(ConstChoice::from_word_lt(6, 5), ConstChoice::FALSE); + } + + #[test] + fn from_word_gt() { + assert_eq!(ConstChoice::from_word_gt(4, 5), ConstChoice::FALSE); + assert_eq!(ConstChoice::from_word_gt(5, 5), ConstChoice::FALSE); + assert_eq!(ConstChoice::from_word_gt(6, 5), ConstChoice::TRUE); + } + #[test] fn select() { let a: Word = 1; diff --git a/src/lib.rs b/src/lib.rs index 5b806322a..eae9d9ad6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -167,7 +167,7 @@ pub mod modular; #[cfg(feature = "generic-array")] mod array; mod checked; -mod ct_choice; +mod const_choice; mod limb; mod non_zero; mod traits; @@ -176,7 +176,7 @@ mod wrapping; pub use crate::{ checked::Checked, - ct_choice::ConstChoice, + const_choice::{ConstChoice, ConstCtOption}, limb::{Limb, WideWord, Word}, non_zero::NonZero, traits::*, diff --git a/src/limb/cmp.rs b/src/limb/cmp.rs index 834270c5e..e1295898c 100644 --- a/src/limb/cmp.rs +++ b/src/limb/cmp.rs @@ -54,16 +54,14 @@ impl ConstantTimeEq for Limb { impl ConstantTimeGreater for Limb { #[inline] fn ct_gt(&self, other: &Self) -> Choice { - let borrow = other.sbb(*self, Limb::ZERO).1; - Choice::from(borrow.0 as u8 & 1) + ConstChoice::from_word_gt(self.0, other.0).into() } } impl ConstantTimeLess for Limb { #[inline] fn ct_lt(&self, other: &Self) -> Choice { - let borrow = self.sbb(*other, Limb::ZERO).1; - Choice::from(borrow.0 as u8 & 1) + ConstChoice::from_word_lt(self.0, other.0).into() } } diff --git a/src/modular/boxed_residue/mul.rs b/src/modular/boxed_residue/mul.rs index 8b47c7ec1..ccb83c7e7 100644 --- a/src/modular/boxed_residue/mul.rs +++ b/src/modular/boxed_residue/mul.rs @@ -301,7 +301,7 @@ fn sub_vv(z: &mut [Limb], x: &[Limb], y: &[Limb]) -> Limb { for (i, (&xi, &yi)) in x.iter().zip(y.iter()).enumerate().take(z.len()) { let zi = xi.wrapping_sub(yi).wrapping_sub(c); z[i] = zi; - // see "Hacker's Delight", section 2-12 (overflow detection) + // See "Hacker's Delight" 2nd ed, section 2-13 (Overflow detection) c = ((yi & !xi) | ((yi | !xi) & zi)) >> (Word::BITS - 1) } diff --git a/src/modular/dyn_residue.rs b/src/modular/dyn_residue.rs index 7fb8ef7d3..51e65158b 100644 --- a/src/modular/dyn_residue.rs +++ b/src/modular/dyn_residue.rs @@ -50,8 +50,9 @@ impl DynResidueParams { let r = Uint::MAX.rem(&nz_modulus).wrapping_add(&Uint::ONE); let r2 = Uint::rem_wide(r.square_wide(), &nz_modulus); + let maybe_inverse = modulus.inv_mod2k_vartime(Word::BITS); // If the inverse exists, it means the modulus is odd. - let (inv_mod_limb, modulus_is_odd) = modulus.inv_mod2k_vartime(Word::BITS); + let (inv_mod_limb, modulus_is_odd) = maybe_inverse.components_ref(); let mod_neg_inv = Limb(Word::MIN.wrapping_sub(inv_mod_limb.limbs[0].0)); let r3 = montgomery_reduction(&r2.square_wide(), modulus, mod_neg_inv); diff --git a/src/modular/dyn_residue/inv.rs b/src/modular/dyn_residue/inv.rs index 0d5d6e4f7..9c4ba8dfd 100644 --- a/src/modular/dyn_residue/inv.rs +++ b/src/modular/dyn_residue/inv.rs @@ -4,7 +4,7 @@ use super::{DynResidue, DynResidueParams}; use crate::{ modular::{inv::inv_montgomery_form, BernsteinYangInverter}, traits::Invert, - ConstChoice, Inverter, PrecomputeInverter, PrecomputeInverterWithAdjuster, Uint, + ConstCtOption, Inverter, PrecomputeInverter, PrecomputeInverterWithAdjuster, Uint, }; use core::fmt; use subtle::CtOption; @@ -14,28 +14,28 @@ impl DynResidue { /// I.e. `self * self^-1 = 1`. /// If the number was invertible, the second element of the tuple is the truthy value, /// otherwise it is the falsy value (in which case the first element's value is unspecified). - pub const fn invert(&self) -> (Self, ConstChoice) { - let (montgomery_form, is_some) = inv_montgomery_form( + pub const fn invert(&self) -> ConstCtOption { + let maybe_inverse = inv_montgomery_form( &self.montgomery_form, &self.residue_params.modulus, &self.residue_params.r3, self.residue_params.mod_neg_inv, ); + let (montgomery_form, is_some) = maybe_inverse.components_ref(); let value = Self { - montgomery_form, + montgomery_form: *montgomery_form, residue_params: self.residue_params, }; - (value, is_some) + ConstCtOption::new(value, is_some) } } impl Invert for DynResidue { type Output = CtOption; fn invert(&self) -> Self::Output { - let (value, is_some) = self.invert(); - CtOption::new(value, is_some.into()) + self.invert().into() } } @@ -114,7 +114,7 @@ mod tests { U256::from_be_hex("77117F1273373C26C700D076B3F780074D03339F56DD0EFB60E7F58441FD3685"); let x_mod = DynResidue::new(&x, params); - let (inv, _is_some) = x_mod.invert(); + let inv = x_mod.invert().unwrap(); let res = x_mod * inv; assert_eq!(res.retrieve(), U256::ONE); diff --git a/src/modular/inv.rs b/src/modular/inv.rs index c57438f9f..7f89ae4fc 100644 --- a/src/modular/inv.rs +++ b/src/modular/inv.rs @@ -1,13 +1,14 @@ -use crate::{modular::reduction::montgomery_reduction, ConstChoice, Limb, Uint}; +use crate::{modular::reduction::montgomery_reduction, ConstCtOption, Limb, Uint}; pub const fn inv_montgomery_form( x: &Uint, modulus: &Uint, r3: &Uint, mod_neg_inv: Limb, -) -> (Uint, ConstChoice) { - let (inverse, is_some) = x.inv_odd_mod(modulus); - ( +) -> ConstCtOption> { + let maybe_inverse = x.inv_odd_mod(modulus); + let (inverse, is_some) = maybe_inverse.components_ref(); + ConstCtOption::new( montgomery_reduction(&inverse.split_mul(r3), modulus, mod_neg_inv), is_some, ) diff --git a/src/modular/residue/inv.rs b/src/modular/residue/inv.rs index 2eee1fe9f..7d9bf7e0f 100644 --- a/src/modular/residue/inv.rs +++ b/src/modular/residue/inv.rs @@ -3,7 +3,7 @@ use super::{Residue, ResidueParams}; use crate::{ modular::{inv::inv_montgomery_form, BernsteinYangInverter}, - ConstChoice, Invert, Inverter, NonZero, PrecomputeInverter, Uint, + ConstChoice, ConstCtOption, Invert, Inverter, NonZero, PrecomputeInverter, Uint, }; use core::{fmt, marker::PhantomData}; use subtle::CtOption; @@ -13,28 +13,28 @@ impl, const LIMBS: usize> Residue { /// I.e. `self * self^-1 = 1`. /// If the number was invertible, the second element of the tuple is the truthy value, /// otherwise it is the falsy value (in which case the first element's value is unspecified). - pub const fn invert(&self) -> (Self, ConstChoice) { - let (montgomery_form, is_some) = inv_montgomery_form( + pub const fn invert(&self) -> ConstCtOption { + let maybe_inverse = inv_montgomery_form( &self.montgomery_form, &MOD::MODULUS.0, &MOD::R3, MOD::MOD_NEG_INV, ); + let (montgomery_form, is_some) = maybe_inverse.components_ref(); let value = Self { - montgomery_form, + montgomery_form: *montgomery_form, phantom: PhantomData, }; - (value, is_some) + ConstCtOption::new(value, is_some) } } impl, const LIMBS: usize> Invert for Residue { type Output = CtOption; fn invert(&self) -> Self::Output { - let (value, is_some) = self.invert(); - CtOption::new(value, is_some.into()) + self.invert().into() } } @@ -42,7 +42,8 @@ impl, const LIMBS: usize> Invert for NonZero Self::Output { // Always succeeds for a non-zero argument - let (value, _is_some) = self.as_ref().invert(); + let value = self.as_ref().invert().unwrap(); + // An inverse is necessarily non-zero NonZero::new(value).unwrap() } } @@ -136,7 +137,7 @@ mod tests { U256::from_be_hex("77117F1273373C26C700D076B3F780074D03339F56DD0EFB60E7F58441FD3685"); let x_mod = const_residue!(x, Modulus); - let (inv, _is_some) = x_mod.invert(); + let inv = x_mod.invert().unwrap(); let res = x_mod * inv; assert_eq!(res.retrieve(), U256::ONE); diff --git a/src/modular/residue/macros.rs b/src/modular/residue/macros.rs index a25f4aa01..4be1e30ee 100644 --- a/src/modular/residue/macros.rs +++ b/src/modular/residue/macros.rs @@ -29,7 +29,7 @@ macro_rules! impl_modulus { } // Can unwrap `NonZero::const_new()` here since `res` was asserted to be odd. - $crate::NonZero::<$uint_type>::const_new(res).0 + $crate::NonZero::<$uint_type>::const_new(res).expect("modulus ensured non-zero") }; const R: $uint_type = $crate::Uint::MAX @@ -41,7 +41,7 @@ macro_rules! impl_modulus { Self::MODULUS .as_ref() .inv_mod2k_vartime($crate::Word::BITS) - .0 + .expect("modulus ensured odd") .as_limbs()[0] .0, ), diff --git a/src/non_zero.rs b/src/non_zero.rs index 3c00148ee..e0691bc92 100644 --- a/src/non_zero.rs +++ b/src/non_zero.rs @@ -1,6 +1,6 @@ //! Wrapper type for non-zero integers. -use crate::{Bounded, ConstChoice, Constants, Encoding, Limb, Uint, Zero}; +use crate::{Bounded, ConstCtOption, Constants, Encoding, Limb, Uint, Zero}; use core::{ fmt, num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8}, @@ -27,16 +27,16 @@ pub struct NonZero(pub(crate) T); impl NonZero { /// Creates a new non-zero limb in a const context. /// The second return value is `FALSE` if `n` is zero, `TRUE` otherwise. - pub const fn const_new(n: Limb) -> (Self, ConstChoice) { - (Self(n), n.is_nonzero()) + pub const fn const_new(n: Limb) -> ConstCtOption { + ConstCtOption::new(Self(n), n.is_nonzero()) } } impl NonZero> { /// Creates a new non-zero integer in a const context. /// The second return value is `FALSE` if `n` is zero, `TRUE` otherwise. - pub const fn const_new(n: Uint) -> (Self, ConstChoice) { - (Self(n), n.is_nonzero()) + pub const fn const_new(n: Uint) -> ConstCtOption { + ConstCtOption::new(Self(n), n.is_nonzero()) } } diff --git a/src/uint/div.rs b/src/uint/div.rs index 5440914b2..074d5b6d7 100644 --- a/src/uint/div.rs +++ b/src/uint/div.rs @@ -123,7 +123,8 @@ impl Uint { let (mut lower, mut upper) = lower_upper; // Factor of the modulus, split into two halves - let (mut c, _overflow) = Self::overflowing_shl_vartime_wide((rhs.0, Uint::ZERO), bd); + let mut c = Self::overflowing_shl_vartime_wide((rhs.0, Uint::ZERO), bd) + .expect("shift within range"); loop { let (lower_sub, borrow) = lower.sbb(&c.0, Limb::ZERO); @@ -135,8 +136,7 @@ impl Uint { break; } bd -= 1; - let (new_c, _overflow) = Self::overflowing_shr_vartime_wide(c, 1); - c = new_c; + c = Self::overflowing_shr_vartime_wide(c, 1).expect("shift within range"); } lower @@ -201,8 +201,7 @@ impl Uint { /// /// Panics if `rhs == 0`. pub const fn wrapping_rem(&self, rhs: &Self) -> Self { - let (nz_rhs, c) = NonZero::::const_new(*rhs); - assert!(c.is_true_vartime(), "modulo zero"); + let nz_rhs = NonZero::::const_new(*rhs).expect("non-zero divisor"); self.rem(&nz_rhs) } @@ -624,8 +623,9 @@ mod tests { fn div() { let mut rng = ChaChaRng::from_seed([7u8; 32]); for _ in 0..25 { - let (num, _) = U256::random(&mut rng).overflowing_shr_vartime(128); - let den = NonZero::new(U256::random(&mut rng).overflowing_shr_vartime(128).0).unwrap(); + let num = U256::random(&mut rng).overflowing_shr_vartime(128).unwrap(); + let den = + NonZero::new(U256::random(&mut rng).overflowing_shr_vartime(128).unwrap()).unwrap(); let n = num.checked_mul(den.as_ref()); if n.is_some().into() { let (q, _) = n.unwrap().div_rem(&den); @@ -714,7 +714,7 @@ mod tests { for _ in 0..25 { let num = U256::random(&mut rng); let k = rng.next_u32() % 256; - let (den, _) = U256::ONE.overflowing_shl_vartime(k); + let den = U256::ONE.overflowing_shl_vartime(k).unwrap(); let a = num.rem2k(k); let e = num.wrapping_rem(&den); diff --git a/src/uint/inv_mod.rs b/src/uint/inv_mod.rs index c0ba8948c..1f5ff5fa8 100644 --- a/src/uint/inv_mod.rs +++ b/src/uint/inv_mod.rs @@ -1,5 +1,5 @@ use super::Uint; -use crate::ConstChoice; +use crate::{ConstChoice, ConstCtOption}; impl Uint { /// Computes 1/`self` mod `2^k`. @@ -8,7 +8,7 @@ impl Uint { /// If the inverse does not exist (`k > 0` and `self` is even), /// returns `ConstChoice::FALSE` as the second element of the tuple, /// otherwise returns `ConstChoice::TRUE`. - pub const fn inv_mod2k_vartime(&self, k: u32) -> (Self, ConstChoice) { + pub const fn inv_mod2k_vartime(&self, k: u32) -> ConstCtOption { // Using the Algorithm 3 from "A Secure Algorithm for Inversion Modulo 2k" // by Sadiel de la Fe and Carles Ferrer. // See . @@ -30,13 +30,15 @@ impl Uint { // b_{i+1} = (b_i - a * X_i) / 2 b = Self::select(&b, &b.wrapping_sub(self), x_i_choice).shr1(); // Store the X_i bit in the result (x = x | (1 << X_i)) - let (shifted, _overflow) = Uint::from_word(x_i).overflowing_shl_vartime(i); + let shifted = Uint::from_word(x_i) + .overflowing_shl_vartime(i) + .expect("shift within range"); x = x.bitor(&shifted); i += 1; } - (x, is_some) + ConstCtOption::new(x, is_some) } /// Computes 1/`self` mod `2^k`. @@ -44,7 +46,7 @@ impl Uint { /// If the inverse does not exist (`k > 0` and `self` is even), /// returns `ConstChoice::FALSE` as the second element of the tuple, /// otherwise returns `ConstChoice::TRUE`. - pub const fn inv_mod2k(&self, k: u32) -> (Self, ConstChoice) { + pub const fn inv_mod2k(&self, k: u32) -> ConstCtOption { // This is the same algorithm as in `inv_mod2k_vartime()`, // but made constant-time w.r.t `k` as well. @@ -74,7 +76,7 @@ impl Uint { i += 1; } - (x, is_some) + ConstCtOption::new(x, is_some) } /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd. @@ -93,7 +95,7 @@ impl Uint { modulus: &Self, bits: u32, modulus_bits: u32, - ) -> (Self, ConstChoice) { + ) -> ConstCtOption { let mut a = *self; let mut u = Uint::ONE; @@ -146,52 +148,58 @@ impl Uint { .or(a.is_nonzero().not()) .is_true_vartime()); - (v, Uint::eq(&b, &Uint::ONE).and(modulus_is_odd)) + ConstCtOption::new(v, Uint::eq(&b, &Uint::ONE).and(modulus_is_odd)) } /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd. /// Returns `(inverse, ConstChoice::TRUE)` if an inverse exists, /// otherwise `(undefined, ConstChoice::FALSE)`. - pub const fn inv_odd_mod(&self, modulus: &Self) -> (Self, ConstChoice) { + pub const fn inv_odd_mod(&self, modulus: &Self) -> ConstCtOption { self.inv_odd_mod_bounded(modulus, Uint::::BITS, Uint::::BITS) } /// Computes the multiplicative inverse of `self` mod `modulus`. /// Returns `(inverse, ConstChoice::TRUE)` if an inverse exists, /// otherwise `(undefined, ConstChoice::FALSE)`. - pub const fn inv_mod(&self, modulus: &Self) -> (Self, ConstChoice) { + pub const fn inv_mod(&self, modulus: &Self) -> ConstCtOption { // Decompose `modulus = s * 2^k` where `s` is odd let k = modulus.trailing_zeros(); - let (s, _overflow) = modulus.overflowing_shr(k); + let s = modulus.overflowing_shr(k).unwrap_or(Self::ZERO); // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses. // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1` - let (a, a_is_some) = self.inv_odd_mod(&s); - let (b, b_is_some) = self.inv_mod2k(k); + let maybe_a = self.inv_odd_mod(&s); + let maybe_b = self.inv_mod2k(k); + let is_some = maybe_a.is_some().and(maybe_b.is_some()); + + // Unwrap to avoid mapping through ConstCtOptions. + // if `a` or `b` don't exist, the returned ConstCtOption will be None anyway. + let a = maybe_a.unwrap_or(Uint::ZERO); + let b = maybe_b.unwrap_or(Uint::ZERO); // Restore from RNS: // self^{-1} = a mod s = b mod 2^k // => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k) // (essentially one step of the Garner's algorithm for recovery from RNS). - let (m_odd_inv, _is_some) = s.inv_mod2k(k); // `s` is odd, so this always exists + // `s` is odd, so this always exists + let m_odd_inv = s.inv_mod2k(k).expect("inverse mod 2^k exists"); // This part is mod 2^k - // Will not overflow since `modulus` is nonzero, and therefore `k < BITS`. - let (shifted, _overflow) = Uint::ONE.overflowing_shl(k); + let shifted = Uint::ONE.overflowing_shl(k).unwrap_or(Self::ZERO); let mask = shifted.wrapping_sub(&Uint::ONE); let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)).bitand(&mask); // Will not overflow since `a <= s - 1`, `t <= 2^k - 1`, // so `a + s * t <= s * 2^k - 1 == modulus - 1`. let result = a.wrapping_add(&s.wrapping_mul(&t)); - (result, a_is_some.and(b_is_some)) + ConstCtOption::new(result, is_some) } } #[cfg(test)] mod tests { - use crate::{ConstChoice, U1024, U256, U64}; + use crate::{U1024, U256, U64}; #[test] fn inv_mod2k() { @@ -199,25 +207,21 @@ mod tests { U256::from_be_hex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f"); let e = U256::from_be_hex("3642e6faeaac7c6663b93d3d6a0d489e434ddc0123db5fa627c7f6e22ddacacf"); - let (a, is_some) = v.inv_mod2k(256); + let a = v.inv_mod2k(256).unwrap(); assert_eq!(e, a); - assert_eq!(is_some, ConstChoice::TRUE); - let (a, is_some) = v.inv_mod2k_vartime(256); + let a = v.inv_mod2k_vartime(256).unwrap(); assert_eq!(e, a); - assert_eq!(is_some, ConstChoice::TRUE); let v = U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"); let e = U256::from_be_hex("261776f29b6b106c7680cf3ed83054a1af5ae537cb4613dbb4f20099aa774ec1"); - let (a, is_some) = v.inv_mod2k(256); + let a = v.inv_mod2k(256).unwrap(); assert_eq!(e, a); - assert_eq!(is_some, ConstChoice::TRUE); - let (a, is_some) = v.inv_mod2k_vartime(256); + let a = v.inv_mod2k_vartime(256).unwrap(); assert_eq!(e, a); - assert_eq!(is_some, ConstChoice::TRUE); // Check that even if the number is >= 2^k, the inverse is still correct. @@ -225,27 +229,24 @@ mod tests { U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"); let e = U256::from_be_hex("0000000000000000000000000000000000000000034613dbb4f20099aa774ec1"); - let (a, is_some) = v.inv_mod2k(90); + let a = v.inv_mod2k(90).unwrap(); assert_eq!(e, a); - assert_eq!(is_some, ConstChoice::TRUE); - let (a, is_some) = v.inv_mod2k_vartime(90); + let a = v.inv_mod2k_vartime(90).unwrap(); assert_eq!(e, a); - assert_eq!(is_some, ConstChoice::TRUE); // An inverse of an even number does not exist. - let (_a, is_some) = U256::from(10u64).inv_mod2k(4); - assert_eq!(is_some, ConstChoice::FALSE); + let a = U256::from(10u64).inv_mod2k(4); + assert!(a.is_none().is_true_vartime()); - let (_a, is_some) = U256::from(10u64).inv_mod2k_vartime(4); - assert_eq!(is_some, ConstChoice::FALSE); + let a = U256::from(10u64).inv_mod2k_vartime(4); + assert!(a.is_none().is_true_vartime()); // A degenerate case. An inverse mod 2^0 == 1 always exists even for even numbers. - let (a, is_some) = U256::from(10u64).inv_mod2k_vartime(0); + let a = U256::from(10u64).inv_mod2k_vartime(0).unwrap(); assert_eq!(a, U256::ZERO); - assert_eq!(is_some, ConstChoice::TRUE); } #[test] @@ -269,17 +270,15 @@ mod tests { "3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336" ]); - let (res, is_some) = a.inv_odd_mod(&m); - assert!(is_some.is_true_vartime()); + let res = a.inv_odd_mod(&m).unwrap(); assert_eq!(res, expected); - // Check that trying to pass an even modulus causes `is_some` to be falsy - let (_res, is_some) = a.inv_odd_mod(&(m.wrapping_add(&U1024::ONE))); - assert!(!is_some.is_true_vartime()); + // Check that trying to pass an even modulus results in `None` + let res = a.inv_odd_mod(&(m.wrapping_add(&U1024::ONE))); + assert!(res.is_none().is_true_vartime()); // Even though it is less efficient, it still works - let (res, is_some) = a.inv_mod(&m); - assert!(is_some.is_true_vartime()); + let res = a.inv_mod(&m).unwrap(); assert_eq!(res, expected); } @@ -295,8 +294,8 @@ mod tests { let m = p1.wrapping_mul(&p2); // `m` is a multiple of `p1`, so no inverse exists - let (_res, is_some) = p1.inv_odd_mod(&m); - assert!(!is_some.is_true_vartime()); + let res = p1.inv_odd_mod(&m); + assert!(res.is_none().is_true_vartime()); } #[test] @@ -320,8 +319,7 @@ mod tests { "5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D", ]); - let (res, is_some) = a.inv_mod(&m); - assert!(is_some.is_true_vartime()); + let res = a.inv_mod(&m).unwrap(); assert_eq!(res, expected); } @@ -340,7 +338,7 @@ mod tests { "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767" ]); - let (res, is_some) = a.inv_odd_mod_bounded(&m, 768, 512); + let res = a.inv_odd_mod_bounded(&m, 768, 512).unwrap(); let expected = U1024::from_be_hex(concat![ "0000000000000000000000000000000000000000000000000000000000000000", @@ -348,7 +346,6 @@ mod tests { "0DCC94E2FE509E6EBBA0825645A38E73EF85D5927C79C1AD8FFE7C8DF9A822FA", "09EB396A21B1EF05CBE51E1A8EF284EF01EBDD36A9A4EA17039D8EEFDD934768" ]); - assert!(is_some.is_true_vartime()); assert_eq!(res, expected); } @@ -357,9 +354,7 @@ mod tests { let a = U64::from(3u64); let m = U64::from(13u64); - let (res, is_some) = a.inv_odd_mod(&m); - - assert!(is_some.is_true_vartime()); + let res = a.inv_odd_mod(&m).unwrap(); assert_eq!(U64::from(9u64), res); } @@ -368,8 +363,7 @@ mod tests { let a = U64::from(14u64); let m = U64::from(49u64); - let (_res, is_some) = a.inv_odd_mod(&m); - - assert!(!is_some.is_true_vartime()); + let res = a.inv_odd_mod(&m); + assert!(res.is_none().is_true_vartime()); } } diff --git a/src/uint/mul.rs b/src/uint/mul.rs index 3a46e42cd..84f745174 100644 --- a/src/uint/mul.rs +++ b/src/uint/mul.rs @@ -137,7 +137,7 @@ impl Uint { // Double the current result, this accounts for the other half of the multiplication grid. // TODO: The top word is empty so we can also use a special purpose shl. - (lo, hi) = Self::overflowing_shl_vartime_wide((lo, hi), 1).0; + (lo, hi) = Self::overflowing_shl_vartime_wide((lo, hi), 1).expect("shift within range"); // Handle the diagonal of the multiplication grid, which finishes the multiplication grid. let mut carry = Limb::ZERO; diff --git a/src/uint/shl.rs b/src/uint/shl.rs index ed972f181..9d911fe3d 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -1,6 +1,6 @@ //! [`Uint`] bitwise left shift operations. -use crate::{ConstChoice, Limb, Uint, Word, WrappingShl}; +use crate::{ConstChoice, ConstCtOption, Limb, Uint, Word, WrappingShl}; use core::ops::{Shl, ShlAssign}; impl Uint { @@ -8,19 +8,15 @@ impl Uint { /// /// Panics if `shift >= Self::BITS`. pub const fn shl(&self, shift: u32) -> Self { - let (result, overflow) = self.overflowing_shl(shift); - assert!( - !overflow.is_true_vartime(), - "attempt to shift left with overflow" - ); - result + self.overflowing_shl(shift) + .expect("`shift` within the bit size of the integer") } /// Computes `self << shift`. /// /// If `shift >= Self::BITS`, returns zero as the first tuple element, /// and `ConstChoice::TRUE` as the second element. - pub const fn overflowing_shl(&self, shift: u32) -> (Self, ConstChoice) { + pub const fn overflowing_shl(&self, shift: u32) -> ConstCtOption { // `floor(log2(BITS - 1))` is the number of bits in the representation of `shift` // (which lies in range `0 <= shift < BITS`). let shift_bits = u32::BITS - (Self::BITS - 1).leading_zeros(); @@ -30,11 +26,17 @@ impl Uint { let mut i = 0; while i < shift_bits { let bit = ConstChoice::from_u32_lsb((shift >> i) & 1); - result = Uint::select(&result, &result.overflowing_shl_vartime(1 << i).0, bit); + result = Uint::select( + &result, + &result + .overflowing_shl_vartime(1 << i) + .expect("shift within range"), + bit, + ); i += 1; } - (Uint::select(&result, &Self::ZERO, overflow), overflow) + ConstCtOption::new(Uint::select(&result, &Self::ZERO, overflow), overflow.not()) } /// Computes `self << shift`. @@ -47,11 +49,11 @@ impl Uint { /// When used with a fixed `shift`, this function is constant-time with respect /// to `self`. #[inline(always)] - pub const fn overflowing_shl_vartime(&self, shift: u32) -> (Self, ConstChoice) { + pub const fn overflowing_shl_vartime(&self, shift: u32) -> ConstCtOption { let mut limbs = [Limb::ZERO; LIMBS]; if shift >= Self::BITS { - return (Self::ZERO, ConstChoice::TRUE); + return ConstCtOption::none(Self::ZERO); } let shift_num = (shift / Limb::BITS) as usize; @@ -64,7 +66,7 @@ impl Uint { } if rem == 0 { - return (Self { limbs }, ConstChoice::FALSE); + return ConstCtOption::some(Self { limbs }); } let mut carry = Limb::ZERO; @@ -78,7 +80,7 @@ impl Uint { i += 1; } - (Self { limbs }, ConstChoice::FALSE) + ConstCtOption::some(Self { limbs }) } /// Computes a left shift on a wide input as `(lo, hi)`. @@ -94,31 +96,41 @@ impl Uint { pub const fn overflowing_shl_vartime_wide( lower_upper: (Self, Self), shift: u32, - ) -> ((Self, Self), ConstChoice) { + ) -> ConstCtOption<(Self, Self)> { let (lower, upper) = lower_upper; if shift >= 2 * Self::BITS { - ((Self::ZERO, Self::ZERO), ConstChoice::TRUE) + ConstCtOption::none((Self::ZERO, Self::ZERO)) } else if shift >= Self::BITS { - let (upper, _) = lower.overflowing_shl_vartime(shift - Self::BITS); - ((Self::ZERO, upper), ConstChoice::FALSE) + let upper = lower + .overflowing_shl_vartime(shift - Self::BITS) + .expect("shift within range"); + ConstCtOption::some((Self::ZERO, upper)) } else { - let (new_lower, _) = lower.overflowing_shl_vartime(shift); - let (upper_lo, _) = lower.overflowing_shr_vartime(Self::BITS - shift); - let (upper_hi, _) = upper.overflowing_shl_vartime(shift); - ((new_lower, upper_lo.bitor(&upper_hi)), ConstChoice::FALSE) + let new_lower = lower + .overflowing_shl_vartime(shift) + .expect("shift within range"); + let upper_lo = lower + .overflowing_shr_vartime(Self::BITS - shift) + .expect("shift within range"); + let upper_hi = upper + .overflowing_shl_vartime(shift) + .expect("shift within range"); + ConstCtOption::some((new_lower, upper_lo.bitor(&upper_hi))) } } /// Computes `self << shift` in a panic-free manner, masking off bits of `shift` which would cause the shift to /// exceed the type's width. pub const fn wrapping_shl(&self, shift: u32) -> Self { - self.overflowing_shl(shift).0 + self.overflowing_shl(shift % Self::BITS) + .expect("shift within range") } /// Computes `self << shift` in variable-time in a panic-free manner, masking off bits of `shift` which would cause /// the shift to exceed the type's width. pub const fn wrapping_shl_vartime(&self, shift: u32) -> Self { - self.overflowing_shl_vartime(shift).0 + self.overflowing_shl_vartime(shift % Self::BITS) + .expect("shift within range") } /// Computes `self << shift` where `0 <= shift < Limb::BITS`, @@ -208,7 +220,7 @@ impl WrappingShl for Uint { #[cfg(test)] mod tests { - use crate::{ConstChoice, Limb, Uint, U128, U256}; + use crate::{Limb, Uint, U128, U256}; const N: U256 = U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"); @@ -259,15 +271,12 @@ mod tests { #[test] fn shl256_const() { - assert_eq!(N.overflowing_shl(256), (U256::ZERO, ConstChoice::TRUE)); - assert_eq!( - N.overflowing_shl_vartime(256), - (U256::ZERO, ConstChoice::TRUE) - ); + assert!(N.overflowing_shl(256).is_none().is_true_vartime()); + assert!(N.overflowing_shl_vartime(256).is_none().is_true_vartime()); } #[test] - #[should_panic(expected = "attempt to shift left with overflow")] + #[should_panic(expected = "`shift` within the bit size of the integer")] fn shl256() { let _ = N << 256; } @@ -280,31 +289,29 @@ mod tests { #[test] fn shl_wide_1_1_128() { assert_eq!( - Uint::overflowing_shl_vartime_wide((U128::ONE, U128::ONE), 128), - ((U128::ZERO, U128::ONE), ConstChoice::FALSE) + Uint::overflowing_shl_vartime_wide((U128::ONE, U128::ONE), 128).unwrap(), + (U128::ZERO, U128::ONE) ); assert_eq!( - Uint::overflowing_shl_vartime_wide((U128::ONE, U128::ONE), 128), - ((U128::ZERO, U128::ONE), ConstChoice::FALSE) + Uint::overflowing_shl_vartime_wide((U128::ONE, U128::ONE), 128).unwrap(), + (U128::ZERO, U128::ONE) ); } #[test] fn shl_wide_max_0_1() { assert_eq!( - Uint::overflowing_shl_vartime_wide((U128::MAX, U128::ZERO), 1), - ( - (U128::MAX.sbb(&U128::ONE, Limb::ZERO).0, U128::ONE), - ConstChoice::FALSE - ) + Uint::overflowing_shl_vartime_wide((U128::MAX, U128::ZERO), 1).unwrap(), + (U128::MAX.sbb(&U128::ONE, Limb::ZERO).0, U128::ONE) ); } #[test] fn shl_wide_max_max_256() { - assert_eq!( - Uint::overflowing_shl_vartime_wide((U128::MAX, U128::MAX), 256), - ((U128::ZERO, U128::ZERO), ConstChoice::TRUE) + assert!( + Uint::overflowing_shl_vartime_wide((U128::MAX, U128::MAX), 256) + .is_none() + .is_true_vartime(), ); } } diff --git a/src/uint/shr.rs b/src/uint/shr.rs index 54ba0fd17..5b5305ced 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -1,6 +1,6 @@ //! [`Uint`] bitwise right shift operations. -use crate::{ConstChoice, Limb, Uint, WrappingShr}; +use crate::{ConstChoice, ConstCtOption, Limb, Uint, WrappingShr}; use core::ops::{Shr, ShrAssign}; impl Uint { @@ -8,19 +8,15 @@ impl Uint { /// /// Panics if `shift >= Self::BITS`. pub const fn shr(&self, shift: u32) -> Self { - let (result, overflow) = self.overflowing_shr(shift); - assert!( - !overflow.is_true_vartime(), - "attempt to shift right with overflow" - ); - result + self.overflowing_shr(shift) + .expect("`shift` within the bit size of the integer") } /// Computes `self >> shift`. /// /// If `shift >= Self::BITS`, returns zero as the first tuple element, /// and `ConstChoice::TRUE` as the second element. - pub const fn overflowing_shr(&self, shift: u32) -> (Self, ConstChoice) { + pub const fn overflowing_shr(&self, shift: u32) -> ConstCtOption { // `floor(log2(BITS - 1))` is the number of bits in the representation of `shift` // (which lies in range `0 <= shift < BITS`). let shift_bits = u32::BITS - (Self::BITS - 1).leading_zeros(); @@ -30,11 +26,17 @@ impl Uint { let mut i = 0; while i < shift_bits { let bit = ConstChoice::from_u32_lsb((shift >> i) & 1); - result = Uint::select(&result, &result.overflowing_shr_vartime(1 << i).0, bit); + result = Uint::select( + &result, + &result + .overflowing_shr_vartime(1 << i) + .expect("shift within range"), + bit, + ); i += 1; } - (Uint::select(&result, &Self::ZERO, overflow), overflow) + ConstCtOption::new(Uint::select(&result, &Self::ZERO, overflow), overflow.not()) } /// Computes `self >> shift`. @@ -47,11 +49,11 @@ impl Uint { /// When used with a fixed `shift`, this function is constant-time with respect /// to `self`. #[inline(always)] - pub const fn overflowing_shr_vartime(&self, shift: u32) -> (Self, ConstChoice) { + pub const fn overflowing_shr_vartime(&self, shift: u32) -> ConstCtOption { let mut limbs = [Limb::ZERO; LIMBS]; if shift >= Self::BITS { - return (Self::ZERO, ConstChoice::TRUE); + return ConstCtOption::none(Self::ZERO); } let shift_num = (shift / Limb::BITS) as usize; @@ -64,7 +66,7 @@ impl Uint { } if rem == 0 { - return (Self { limbs }, ConstChoice::FALSE); + return ConstCtOption::some(Self { limbs }); } let mut carry = Limb::ZERO; @@ -77,7 +79,7 @@ impl Uint { carry = new_carry; } - (Self { limbs }, ConstChoice::FALSE) + ConstCtOption::some(Self { limbs }) } /// Computes a right shift on a wide input as `(lo, hi)`. @@ -93,31 +95,41 @@ impl Uint { pub const fn overflowing_shr_vartime_wide( lower_upper: (Self, Self), shift: u32, - ) -> ((Self, Self), ConstChoice) { + ) -> ConstCtOption<(Self, Self)> { let (lower, upper) = lower_upper; if shift >= 2 * Self::BITS { - ((Self::ZERO, Self::ZERO), ConstChoice::TRUE) + ConstCtOption::none((Self::ZERO, Self::ZERO)) } else if shift >= Self::BITS { - let (lower, _) = upper.overflowing_shr_vartime(shift - Self::BITS); - ((lower, Self::ZERO), ConstChoice::FALSE) + let lower = upper + .overflowing_shr_vartime(shift - Self::BITS) + .expect("shift within range"); + ConstCtOption::some((lower, Self::ZERO)) } else { - let (new_upper, _) = upper.overflowing_shr_vartime(shift); - let (lower_hi, _) = upper.overflowing_shl_vartime(Self::BITS - shift); - let (lower_lo, _) = lower.overflowing_shr_vartime(shift); - ((lower_lo.bitor(&lower_hi), new_upper), ConstChoice::FALSE) + let new_upper = upper + .overflowing_shr_vartime(shift) + .expect("shift within range"); + let lower_hi = upper + .overflowing_shl_vartime(Self::BITS - shift) + .expect("shift within range"); + let lower_lo = lower + .overflowing_shr_vartime(shift) + .expect("shift within range"); + ConstCtOption::some((lower_lo.bitor(&lower_hi), new_upper)) } } /// Computes `self >> shift` in a panic-free manner, masking off bits of `shift` which would cause the shift to /// exceed the type's width. pub const fn wrapping_shr(&self, shift: u32) -> Self { - self.overflowing_shr(shift).0 + self.overflowing_shr(shift % Self::BITS) + .expect("shift within range") } /// Computes `self >> shift` in variable-time in a panic-free manner, masking off bits of `shift` which would cause /// the shift to exceed the type's width. pub const fn wrapping_shr_vartime(&self, shift: u32) -> Self { - self.overflowing_shr_vartime(shift).0 + self.overflowing_shr_vartime(shift % Self::BITS) + .expect("shift within range") } /// Computes `self >> 1` in constant-time. @@ -183,7 +195,7 @@ impl WrappingShr for Uint { #[cfg(test)] mod tests { - use crate::{ConstChoice, Uint, U128, U256}; + use crate::{Uint, U128, U256}; const N: U256 = U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"); @@ -199,15 +211,12 @@ mod tests { #[test] fn shr256_const() { - assert_eq!(N.overflowing_shr(256), (U256::ZERO, ConstChoice::TRUE)); - assert_eq!( - N.overflowing_shr_vartime(256), - (U256::ZERO, ConstChoice::TRUE) - ); + assert!(N.overflowing_shr(256).is_none().is_true_vartime()); + assert!(N.overflowing_shr_vartime(256).is_none().is_true_vartime()); } #[test] - #[should_panic(expected = "attempt to shift right with overflow")] + #[should_panic(expected = "`shift` within the bit size of the integer")] fn shr256() { let _ = N >> 256; } @@ -215,24 +224,25 @@ mod tests { #[test] fn shr_wide_1_1_128() { assert_eq!( - Uint::overflowing_shr_vartime_wide((U128::ONE, U128::ONE), 128), - ((U128::ONE, U128::ZERO), ConstChoice::FALSE) + Uint::overflowing_shr_vartime_wide((U128::ONE, U128::ONE), 128).unwrap(), + (U128::ONE, U128::ZERO) ); } #[test] fn shr_wide_0_max_1() { assert_eq!( - Uint::overflowing_shr_vartime_wide((U128::ZERO, U128::MAX), 1), - ((U128::ONE << 127, U128::MAX >> 1), ConstChoice::FALSE) + Uint::overflowing_shr_vartime_wide((U128::ZERO, U128::MAX), 1).unwrap(), + (U128::ONE << 127, U128::MAX >> 1) ); } #[test] fn shr_wide_max_max_256() { - assert_eq!( - Uint::overflowing_shr_vartime_wide((U128::MAX, U128::MAX), 256), - ((U128::ZERO, U128::ZERO), ConstChoice::TRUE) + assert!( + Uint::overflowing_shr_vartime_wide((U128::MAX, U128::MAX), 256) + .is_none() + .is_true_vartime() ); } } diff --git a/src/uint/sqrt.rs b/src/uint/sqrt.rs index e402bd01f..2b1cdb4d8 100644 --- a/src/uint/sqrt.rs +++ b/src/uint/sqrt.rs @@ -17,7 +17,9 @@ impl Uint { // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. // Will not overflow since `b <= BITS`. - let (mut x, _overflow) = Self::ONE.overflowing_shl((self.bits() + 1) >> 1); // ≥ √(`self`) + let mut x = Self::ONE + .overflowing_shl((self.bits() + 1) >> 1) + .expect("shift within range"); // ≥ √(`self`) // Repeat enough times to guarantee result has stabilized. let mut i = 0; @@ -29,8 +31,9 @@ impl Uint { // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)` - let (nz_x, is_some) = NonZero::::const_new(x); - let (q, _) = self.div_rem(&nz_x); + let maybe_nz_x = NonZero::::const_new(x); + let (nz_x, is_some) = maybe_nz_x.components_ref(); + let (q, _) = self.div_rem(nz_x); // A protection in case `self == 0`, which will make `x == 0` let q = Self::select(&Self::ZERO, &q, is_some); @@ -53,12 +56,15 @@ impl Uint { // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. // Will not overflow since `b <= BITS`. - let (mut x, _overflow) = Self::ONE.overflowing_shl((self.bits() + 1) >> 1); // ≥ √(`self`) + let mut x = Self::ONE + .overflowing_shl((self.bits() + 1) >> 1) + .expect("shift within range"); // ≥ √(`self`) // Stop right away if `x` is zero to avoid divizion by zero. while !x.cmp_vartime(&Self::ZERO).is_eq() { // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)` - let q = self.wrapping_div_vartime(&NonZero::::const_new(x).0); + let q = self + .wrapping_div_vartime(&NonZero::::const_new(x).expect("ensured non-zero")); let t = x.wrapping_add(&q); let next_x = t.shr1(); diff --git a/tests/uint_proptests.rs b/tests/uint_proptests.rs index c09e406cd..4151cd66f 100644 --- a/tests/uint_proptests.rs +++ b/tests/uint_proptests.rs @@ -2,7 +2,7 @@ use crypto_bigint::{ modular::{DynResidue, DynResidueParams}, - ConstChoice, Encoding, Limb, NonZero, Word, U256, + Encoding, Limb, NonZero, Word, U256, }; use num_bigint::BigUint; use num_integer::Integer; @@ -63,13 +63,14 @@ proptest! { // Add a 50% probability of overflow. let shift = u32::from(shift) % (U256::BITS * 2); - let expected = to_uint((a_bi << shift as usize) & ((BigUint::one() << U256::BITS as usize) - BigUint::one())); - let (actual, overflow) = a.overflowing_shl_vartime(shift.into()); + let expected = to_uint(a_bi << shift as usize); + let actual = a.overflowing_shl_vartime(shift.into()); - assert_eq!(expected, actual); if shift >= U256::BITS { - assert_eq!(actual, U256::ZERO); - assert_eq!(overflow, ConstChoice::TRUE); + assert!(bool::from(actual.is_none())); + } + else { + assert_eq!(expected, actual.unwrap()); } } @@ -80,13 +81,14 @@ proptest! { // Add a 50% probability of overflow. let shift = u32::from(shift) % (U256::BITS * 2); - let expected = to_uint((a_bi << shift as usize) & ((BigUint::one() << U256::BITS as usize) - BigUint::one())); - let (actual, overflow) = a.overflowing_shl(shift); + let expected = to_uint(a_bi << shift as usize); + let actual = a.overflowing_shl(shift); - assert_eq!(expected, actual); if shift >= U256::BITS { - assert_eq!(actual, U256::ZERO); - assert_eq!(overflow, ConstChoice::TRUE); + assert!(bool::from(actual.is_none())); + } + else { + assert_eq!(expected, actual.unwrap()); } } @@ -98,12 +100,13 @@ proptest! { let shift = u32::from(shift) % (U256::BITS * 2); let expected = to_uint(a_bi >> shift as usize); - let (actual, overflow) = a.overflowing_shr_vartime(shift); + let actual = a.overflowing_shr_vartime(shift); - assert_eq!(expected, actual); if shift >= U256::BITS { - assert_eq!(actual, U256::ZERO); - assert_eq!(overflow, ConstChoice::TRUE); + assert!(bool::from(actual.is_none())); + } + else { + assert_eq!(expected, actual.unwrap()); } } @@ -115,12 +118,13 @@ proptest! { let shift = u32::from(shift) % (U256::BITS * 2); let expected = to_uint(a_bi >> shift as usize); - let (actual, overflow) = a.overflowing_shr(shift); + let actual = a.overflowing_shr(shift); - assert_eq!(expected, actual); if shift >= U256::BITS { - assert_eq!(actual, U256::ZERO); - assert_eq!(overflow, ConstChoice::TRUE); + assert!(bool::from(actual.is_none())); + } + else { + assert_eq!(expected, actual.unwrap()); } } @@ -278,11 +282,9 @@ proptest! { let a_bi = to_biguint(&a); let m_bi = BigUint::one() << k as usize; - let (actual, is_some) = a.inv_mod2k(k); - let (actual_vartime, is_some_vartime) = a.inv_mod2k_vartime(k); + let actual = a.inv_mod2k(k).unwrap(); + let actual_vartime = a.inv_mod2k_vartime(k).unwrap(); assert_eq!(actual, actual_vartime); - assert_eq!(is_some, ConstChoice::TRUE); - assert_eq!(is_some_vartime, ConstChoice::TRUE); if k == 0 { assert_eq!(actual, U256::ZERO); @@ -299,12 +301,14 @@ proptest! { let a_bi = to_biguint(&a); let b_bi = to_biguint(&b); - let expected_is_some = if a_bi.gcd(&b_bi) == BigUint::one() { ConstChoice::TRUE } else { ConstChoice::FALSE }; - let (actual, actual_is_some) = a.inv_mod(&b); + let expected_is_some = a_bi.gcd(&b_bi) == BigUint::one(); + let actual = a.inv_mod(&b); + let actual_is_some = bool::from(actual.is_some()); - assert_eq!(bool::from(expected_is_some), bool::from(actual_is_some)); + assert_eq!(expected_is_some, actual_is_some); - if actual_is_some.into() { + if actual_is_some { + let actual = actual.unwrap(); let inv_bi = to_biguint(&actual); let res = (inv_bi * a_bi) % b_bi; assert_eq!(res, BigUint::one());