diff --git a/src/const_choice.rs b/src/const_choice.rs index d45f6f245..e3203fd75 100644 --- a/src/const_choice.rs +++ b/src/const_choice.rs @@ -1,6 +1,6 @@ use subtle::{Choice, CtOption}; -use crate::{modular::BernsteinYangInverter, Limb, NonZero, Odd, Uint, Word}; +use crate::{modular::BernsteinYangInverter, Limb, NonZero, Odd, Uint, WideWord, Word}; /// A boolean value returned by constant-time `const fn`s. // TODO: should be replaced by `subtle::Choice` or `CtOption` @@ -49,6 +49,14 @@ impl ConstChoice { Self(value.wrapping_neg()) } + /// Returns the truthy value if `value == 1`, and the falsy value if `value == 0`. + /// Panics for other values only when debug assertions are enabled (the check is a `debug_assert!`); release builds do not check. + #[inline] + pub(crate) const fn from_wide_word_lsb(value: WideWord) -> Self { + debug_assert!(value == 0 || value == 1); + Self(value.wrapping_neg() as Word) + } + #[inline] pub(crate) const fn from_u32_lsb(value: u32) -> Self { debug_assert!(value == 0 || value == 1); @@ -129,6 +137,14 @@ impl ConstChoice { Self::from_word_lsb(bit) } + /// Returns the truthy value if `x <= y` and the falsy value otherwise. + #[inline] + pub(crate) const fn from_wide_word_le(x: WideWord, y: WideWord) -> Self { + // See "Hacker's Delight" 2nd ed, section 2-12 (Comparison predicates) + let bit = (((!x) | y) & ((x ^ y) | !(y.wrapping_sub(x)))) >> (WideWord::BITS - 1); + Self::from_wide_word_lsb(bit) + } + /// Returns the truthy value if `x <= y` and the falsy value otherwise. #[inline] pub(crate) const fn from_u32_le(x: u32, y: u32) -> Self { @@ -172,6 +188,13 @@ impl ConstChoice { a ^ (self.0 & (a ^ b)) } + /// Return `b` if `self` is truthy, otherwise return `a`. + #[inline] + pub(crate) const fn select_wide_word(&self, a: WideWord, b: WideWord) -> WideWord { + let mask = ((self.0 as WideWord) << Word::BITS) | (self.0 as WideWord); + a ^ (mask & (a ^ b)) + } + /// Return `b` if `self` is truthy, otherwise return `a`. 
#[inline] pub(crate) const fn select_u32(&self, a: u32, b: u32) -> u32 { @@ -423,7 +446,7 @@ impl #[cfg(test)] mod tests { use super::ConstChoice; - use crate::Word; + use crate::{WideWord, Word}; #[test] fn from_u64_lsb() { @@ -445,6 +468,13 @@ mod tests { assert_eq!(ConstChoice::from_word_gt(6, 5), ConstChoice::TRUE); } + #[test] + fn from_wide_word_le() { + assert_eq!(ConstChoice::from_wide_word_le(4, 5), ConstChoice::TRUE); + assert_eq!(ConstChoice::from_wide_word_le(5, 5), ConstChoice::TRUE); + assert_eq!(ConstChoice::from_wide_word_le(6, 5), ConstChoice::FALSE); + } + #[test] fn select_u32() { let a: u32 = 1; @@ -468,4 +498,12 @@ mod tests { assert_eq!(ConstChoice::TRUE.select_word(a, b), b); assert_eq!(ConstChoice::FALSE.select_word(a, b), a); } + + #[test] + fn select_wide_word() { + let a: WideWord = (1 << Word::BITS) + 1; + let b: WideWord = (3 << Word::BITS) + 4; + assert_eq!(ConstChoice::TRUE.select_wide_word(a, b), b); + assert_eq!(ConstChoice::FALSE.select_wide_word(a, b), a); + } } diff --git a/src/uint/div_limb.rs b/src/uint/div_limb.rs index b0bee0463..595518b8d 100644 --- a/src/uint/div_limb.rs +++ b/src/uint/div_limb.rs @@ -156,28 +156,26 @@ pub(crate) const fn div3by2( u0: Word, reciprocal: &Reciprocal, v0: Word, -) -> (Word, Word) { +) -> (Word, WideWord) { // This method corresponds to Algorithm Q: // https://janmr.com/blog/2014/04/basic-multiple-precision-long-division/ - let maxed = ConstChoice::from_word_eq(u2, reciprocal.divisor_normalized); - let (mut quo, mut rem) = div2by1(maxed.select_word(u2, 0), u1, reciprocal); + let q_maxed = ConstChoice::from_word_eq(u2, reciprocal.divisor_normalized); + let (mut quo, rem) = div2by1(q_maxed.select_word(u2, 0), u1, reciprocal); // When the leading dividend word equals the leading divisor word, cap the quotient // at Word::MAX and set the remainder to the sum of the top dividend words. 
- quo = maxed.select_word(quo, Word::MAX); - rem = maxed.select_word(rem, u2.saturating_add(u1)); + quo = q_maxed.select_word(quo, Word::MAX); + let mut rem = q_maxed.select_wide_word(rem as WideWord, (u2 as WideWord) + (u1 as WideWord)); let mut i = 0; while i < 2 { let qy = (quo as WideWord) * (v0 as WideWord); - let rx = ((rem as WideWord) << Word::BITS) | (u0 as WideWord); - // Constant-time check for q*y[-2] < r*x[-1], based on ConstChoice::from_word_lt - let diff = ConstChoice::from_word_lsb( - ((((!rx) & qy) | (((!rx) | qy) & (rx.wrapping_sub(qy)))) >> (WideWord::BITS - 1)) - as Word, - ); - quo = diff.select_word(quo, quo.saturating_sub(1)); - rem = diff.select_word(rem, rem.saturating_add(reciprocal.divisor_normalized)); + let rx = (rem << Word::BITS) | (u0 as WideWord); + // If r < b and q*y[-2] > r*x[-1], then set q = q - 1 and r = r + v1 + let done = ConstChoice::from_word_nonzero((rem >> Word::BITS) as Word) + .or(ConstChoice::from_wide_word_le(qy, rx)); + quo = done.select_word(quo.saturating_sub(1), quo); + rem = done.select_wide_word(rem + (reciprocal.divisor_normalized as WideWord), rem); i += 1; }