diff --git a/curve25519-dalek/Cargo.toml b/curve25519-dalek/Cargo.toml index 1fbd72d00..72bff1de7 100644 --- a/curve25519-dalek/Cargo.toml +++ b/curve25519-dalek/Cargo.toml @@ -41,6 +41,7 @@ sha2 = { version = "0.11.0-rc.0", default-features = false } bincode = "1" criterion = { version = "0.5", features = ["html_reports"] } hex = "0.4.2" +proptest = "1" rand = "0.9" rand_core = { version = "0.9", default-features = false, features = ["os_rng"] } diff --git a/curve25519-dalek/src/backend/serial/u32/scalar.rs b/curve25519-dalek/src/backend/serial/u32/scalar.rs index cc21139bc..4c34ee2d4 100644 --- a/curve25519-dalek/src/backend/serial/u32/scalar.rs +++ b/curve25519-dalek/src/backend/serial/u32/scalar.rs @@ -197,15 +197,33 @@ impl Scalar29 { } // conditionally add l if the difference is negative + difference.conditional_add_l(Choice::from((borrow >> 31) as u8)); + difference + } + + pub(crate) fn conditional_add_l(&mut self, condition: Choice) -> u32 { let mut carry: u32 = 0; + let mask = (1u32 << 29) - 1; + for i in 0..9 { - let underflow = Choice::from((borrow >> 31) as u8); - let addend = u32::conditional_select(&0, &constants::L[i], underflow); - carry = (carry >> 29) + difference[i] + addend; - difference[i] = carry & mask; + let addend = u32::conditional_select(&0, &constants::L[i], condition); + carry = (carry >> 29) + self[i] + addend; + self[i] = carry & mask; } + carry + } - difference + /// Compute a raw in-place carrying right shift over the limbs. + #[inline(always)] + pub(crate) fn shr1_assign(&mut self) -> u32 { + let mut carry: u32 = 0; + for i in (0..9).rev() { + let limb = self[i]; + let next_carry = limb & 1; + self[i] = (limb >> 1) | (carry << 28); + carry = next_carry; + } + carry } /// Compute `a * b`. 
diff --git a/curve25519-dalek/src/backend/serial/u64/scalar.rs b/curve25519-dalek/src/backend/serial/u64/scalar.rs index 5bcbb72c3..f4eb4a2af 100644 --- a/curve25519-dalek/src/backend/serial/u64/scalar.rs +++ b/curve25519-dalek/src/backend/serial/u64/scalar.rs @@ -186,15 +186,34 @@ impl Scalar52 { } // conditionally add l if the difference is negative + difference.conditional_add_l(Choice::from((borrow >> 63) as u8)); + difference + } + + pub(crate) fn conditional_add_l(&mut self, condition: Choice) -> u64 { let mut carry: u64 = 0; + let mask = (1u64 << 52) - 1; + for i in 0..5 { - let underflow = Choice::from((borrow >> 63) as u8); - let addend = u64::conditional_select(&0, &constants::L[i], underflow); - carry = (carry >> 52) + difference[i] + addend; - difference[i] = carry & mask; + let addend = u64::conditional_select(&0, &constants::L[i], condition); + carry = (carry >> 52) + self[i] + addend; + self[i] = carry & mask; } - difference + carry + } + + /// Compute a raw in-place carrying right shift over the limbs. + #[inline(always)] + pub(crate) fn shr1_assign(&mut self) -> u64 { + let mut carry: u64 = 0; + for i in (0..5).rev() { + let limb = self[i]; + let next_carry = limb & 1; + self[i] = (limb >> 1) | (carry << 51); + carry = next_carry; + } + carry } /// Compute `a * b` diff --git a/curve25519-dalek/src/ristretto.rs b/curve25519-dalek/src/ristretto.rs index ca5fa88ec..8b867930d 100644 --- a/curve25519-dalek/src/ristretto.rs +++ b/curve25519-dalek/src/ristretto.rs @@ -1321,6 +1321,8 @@ impl Zeroize for RistrettoPoint { mod test { use super::*; use crate::edwards::CompressedEdwardsY; + #[cfg(feature = "group")] + use proptest::prelude::*; use rand_core::{OsRng, TryRngCore}; @@ -1867,6 +1869,39 @@ mod test { } } + #[cfg(feature = "group")] + proptest! 
{ + #[test] + fn multiply_double_and_compress_random_points( + p1 in any::<[u8; 64]>(), + p2 in any::<[u8; 64]>(), + s1 in any::<[u8; 32]>(), + s2 in any::<[u8; 32]>(), + ) { + use group::Group; + + let scalars = [ + Scalar::from_bytes_mod_order(s1), + Scalar::ZERO, + Scalar::from_bytes_mod_order(s2), + ]; + + let points = [ + RistrettoPoint::from_uniform_bytes(&p1), + <RistrettoPoint as Group>::identity(), + RistrettoPoint::from_uniform_bytes(&p2), + ]; + + let multiplied_points: [_; 3] = + core::array::from_fn(|i| scalars[i].div_by_2() * points[i]); + let compressed = RistrettoPoint::double_and_compress_batch(&multiplied_points); + + for ((s, P), P2_compressed) in scalars.iter().zip(points).zip(compressed) { + prop_assert_eq!(P2_compressed, (s * P).compress()); + } + } + } + + #[test] + #[cfg(feature = "alloc")] + fn vartime_precomputed_vs_nonprecomputed_multiscalar() { diff --git a/curve25519-dalek/src/scalar.rs b/curve25519-dalek/src/scalar.rs index fc8939fb4..7fc6e2783 100644 --- a/curve25519-dalek/src/scalar.rs +++ b/curve25519-dalek/src/scalar.rs @@ -831,6 +831,22 @@ impl Scalar { ret } + /// Compute `b` such that `b + b = a mod modulus`. + pub fn div_by_2(&self) -> Self { + // We are looking for such `b` that `b + b = a mod modulus`. + // Two possibilities: + // - if `a` is even, we can just divide by 2; + // - if `a` is odd, we divide `(a + modulus)` by 2.
+ let is_odd = Choice::from(self.as_bytes()[0] & 1); + let mut scalar = self.unpack(); + scalar.conditional_add_l(is_odd); + + let carry = scalar.shr1_assign(); + debug_assert_eq!(carry, 0); + + scalar.pack() + } + /// Get the bits of the scalar, in little-endian order pub(crate) fn bits_le(&self) -> impl DoubleEndedIterator<Item = bool> + '_ { (0..256).map(|i| { @@ -1677,6 +1693,23 @@ pub(crate) mod test { } } + #[test] + fn div_by_2() { + // test a range of small scalars + for i in 0u64..32 { + let scalar = Scalar::from(i); + let dividend = scalar.div_by_2(); + assert_eq!(scalar, dividend + dividend); + } + + // test a range of scalars near the modulus + for i in 0u64..32 { + let scalar = Scalar::ZERO - Scalar::from(i); + let dividend = scalar.div_by_2(); + assert_eq!(scalar, dividend + dividend); + } + } + #[test] fn reduce() { let biggest = Scalar::from_bytes_mod_order([0xff; 32]);