Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion curve25519-dalek/benches/dalek_benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ mod ristretto_benches {
let points: Vec<RistrettoPoint> = (0..size)
.map(|_| RistrettoPoint::try_from_rng(&mut rng).unwrap())
.collect();
b.iter(|| RistrettoPoint::double_and_compress_batch(&points));
b.iter(|| RistrettoPoint::double_and_compress_batch_alloc(&points));
},
);
}
Expand Down
4 changes: 2 additions & 2 deletions curve25519-dalek/src/edwards.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ impl EdwardsPoint {

// Compute the denominators in a batch
let mut denominators = eds.iter().map(|p| &p.Z - &p.Y).collect::<Vec<_>>();
FieldElement::batch_invert(&mut denominators);
FieldElement::invert_batch_alloc(&mut denominators);

// Now compute the Montgomery u coordinate for every point
let mut ret = Vec::with_capacity(eds.len());
Expand All @@ -621,7 +621,7 @@ impl EdwardsPoint {
#[cfg(feature = "alloc")]
pub fn compress_batch(inputs: &[EdwardsPoint]) -> Vec<CompressedEdwardsY> {
let mut zs = inputs.iter().map(|input| input.Z).collect::<Vec<_>>();
FieldElement::batch_invert(&mut zs);
FieldElement::invert_batch_alloc(&mut zs);

inputs
.iter()
Expand Down
33 changes: 24 additions & 9 deletions curve25519-dalek/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,17 +203,32 @@ impl FieldElement {
(t19, t3)
}

/// Given a slice of pub(crate)lic `FieldElements`, replace each with its inverse.
///
/// When an input `FieldElement` is zero, its value is unchanged.
pub(crate) fn invert_batch<const N: usize>(inputs: &mut [FieldElement; N]) {
let mut scratch = [FieldElement::ONE; N];

Self::internal_invert_batch(inputs, &mut scratch);
}

/// Given a slice of pub(crate)lic `FieldElements`, replace each with its inverse.
///
/// When an input `FieldElement` is zero, its value is unchanged.
#[cfg(feature = "alloc")]
pub(crate) fn batch_invert(inputs: &mut [FieldElement]) {
pub(crate) fn invert_batch_alloc(inputs: &mut [FieldElement]) {
let n = inputs.len();
let mut scratch = vec![FieldElement::ONE; n];

Self::internal_invert_batch(inputs, &mut scratch);
}

fn internal_invert_batch(inputs: &mut [FieldElement], scratch: &mut [FieldElement]) {
// Montgomery’s Trick and Fast Implementation of Masked AES
// Genelle, Prouff and Quisquater
// Section 3.2

let n = inputs.len();
let mut scratch = vec![FieldElement::ONE; n];
debug_assert_eq!(inputs.len(), scratch.len());

// Keep an accumulator of all of the previous products
let mut acc = FieldElement::ONE;
Expand All @@ -234,12 +249,12 @@ impl FieldElement {

// Pass through the vector backwards to compute the inverses
// in place
for (input, scratch) in inputs.iter_mut().rev().zip(scratch.into_iter().rev()) {
for (input, scratch) in inputs.iter_mut().rev().zip(scratch.iter_mut().rev()) {
let tmp = &acc * input;
// input <- acc * scratch, then acc <- tmp
// Again, we skip zeros in a constant-time way
let nz = !input.is_zero();
input.conditional_assign(&(&acc * &scratch), nz);
input.conditional_assign(&(&acc * scratch), nz);
acc.conditional_assign(&tmp, nz);
}
}
Expand Down Expand Up @@ -553,7 +568,7 @@ mod test {

#[test]
#[cfg(feature = "alloc")]
fn batch_invert_a_matches_nonbatched() {
fn invert_batch_a_matches_nonbatched() {
let a = FieldElement::from_bytes(&A_BYTES);
let ap58 = FieldElement::from_bytes(&AP58_BYTES);
let asq = FieldElement::from_bytes(&ASQ_BYTES);
Expand All @@ -562,7 +577,7 @@ mod test {
let a2 = &a + &a;
let a_list = vec![a, ap58, asq, ainv, a0, a2];
let mut ainv_list = a_list.clone();
FieldElement::batch_invert(&mut ainv_list[..]);
FieldElement::invert_batch_alloc(&mut ainv_list[..]);
for i in 0..6 {
assert_eq!(a_list[i].invert(), ainv_list[i]);
}
Expand Down Expand Up @@ -671,8 +686,8 @@ mod test {

#[test]
#[cfg(feature = "alloc")]
fn batch_invert_empty() {
FieldElement::batch_invert(&mut []);
fn invert_batch_empty() {
FieldElement::invert_batch_alloc(&mut []);
}

// The following two consts were generated with the following sage script:
Expand Down
174 changes: 111 additions & 63 deletions curve25519-dalek/src/ristretto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
#[cfg(feature = "alloc")]
use alloc::vec::Vec;

use core::array::TryFromSliceError;
use core::array::{self, TryFromSliceError};
use core::borrow::Borrow;
use core::fmt::Debug;
use core::iter::Sum;
Expand Down Expand Up @@ -532,6 +532,47 @@ impl RistrettoPoint {
CompressedRistretto(s.to_bytes())
}

/// Double-and-compress a batch of points. The Ristretto encoding
/// is not batchable, since it requires an inverse square root.
///
/// However, given input points \\( P\_1, \ldots, P\_n, \\)
/// it is possible to compute the encodings of their doubles \\(
/// \mathrm{enc}( \[2\]P\_1), \ldots, \mathrm{enc}( \[2\]P\_n ) \\)
/// in a batch.
///
#[cfg_attr(feature = "rand_core", doc = "```")]
#[cfg_attr(not(feature = "rand_core"), doc = "```ignore")]
/// # use curve25519_dalek::ristretto::RistrettoPoint;
/// use rand_core::{OsRng, TryRngCore};
///
/// # // Need fn main() here in comment so the doctest compiles
/// # // See https://doc.rust-lang.org/book/documentation.html#documentation-as-tests
/// # fn main() {
/// let mut rng = OsRng.unwrap_err();
///
/// let points: [RistrettoPoint; 32] =
/// core::array::from_fn(|_| RistrettoPoint::random(&mut rng));
///
/// let compressed = RistrettoPoint::double_and_compress_batch(&points);
///
/// for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
/// assert_eq!(*P2_compressed, (P + P).compress());
/// }
/// # }
/// ```
pub fn double_and_compress_batch<const N: usize>(
points: &[RistrettoPoint; N],
) -> [CompressedRistretto; N] {
let states: [BatchCompressState; N] =
array::from_fn(|i| BatchCompressState::from(&points[i]));

let mut invs: [FieldElement; N] = array::from_fn(|i| states[i].efgh());

FieldElement::invert_batch(&mut invs);

array::from_fn(|i| Self::internal_double_and_compress_batch(&states[i], &invs[i]))
}

/// Double-and-compress a batch of points. The Ristretto encoding
/// is not batchable, since it requires an inverse square root.
///
Expand All @@ -553,97 +594,68 @@ impl RistrettoPoint {
/// let points: Vec<RistrettoPoint> =
/// (0..32).map(|_| RistrettoPoint::random(&mut rng)).collect();
///
/// let compressed = RistrettoPoint::double_and_compress_batch(&points);
/// let compressed = RistrettoPoint::double_and_compress_batch_alloc(&points);
///
/// for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
/// assert_eq!(*P2_compressed, (P + P).compress());
/// }
/// # }
/// ```
#[cfg(feature = "alloc")]
pub fn double_and_compress_batch<'a, I>(points: I) -> Vec<CompressedRistretto>
pub fn double_and_compress_batch_alloc<'a, I>(points: I) -> Vec<CompressedRistretto>
where
I: IntoIterator<Item = &'a RistrettoPoint>,
{
#[derive(Copy, Clone, Debug)]
struct BatchCompressState {
e: FieldElement,
f: FieldElement,
g: FieldElement,
h: FieldElement,
eg: FieldElement,
fh: FieldElement,
}

impl BatchCompressState {
fn efgh(&self) -> FieldElement {
&self.eg * &self.fh
}
}

impl<'a> From<&'a RistrettoPoint> for BatchCompressState {
#[rustfmt::skip] // keep alignment of explanatory comments
fn from(P: &'a RistrettoPoint) -> BatchCompressState {
let XX = P.0.X.square();
let YY = P.0.Y.square();
let ZZ = P.0.Z.square();
let dTT = &P.0.T.square() * &constants::EDWARDS_D;

let e = &P.0.X * &(&P.0.Y + &P.0.Y); // = 2*X*Y
let f = &ZZ + &dTT; // = Z^2 + d*T^2
let g = &YY + &XX; // = Y^2 - a*X^2
let h = &ZZ - &dTT; // = Z^2 - d*T^2

let eg = &e * &g;
let fh = &f * &h;

BatchCompressState{ e, f, g, h, eg, fh }
}
}

let states: Vec<BatchCompressState> =
points.into_iter().map(BatchCompressState::from).collect();

let mut invs: Vec<FieldElement> = states.iter().map(|state| state.efgh()).collect();

FieldElement::batch_invert(&mut invs[..]);
FieldElement::invert_batch_alloc(&mut invs[..]);

states
.iter()
.zip(invs.iter())
.map(|(state, inv): (&BatchCompressState, &FieldElement)| {
let Zinv = &state.eg * inv;
let Tinv = &state.fh * inv;
Self::internal_double_and_compress_batch(state, inv)
})
.collect()
}

let mut magic = constants::INVSQRT_A_MINUS_D;
fn internal_double_and_compress_batch(
state: &BatchCompressState,
inv: &FieldElement,
) -> CompressedRistretto {
let Zinv = &state.eg * inv;
let Tinv = &state.fh * inv;

let negcheck1 = (&state.eg * &Zinv).is_negative();
let mut magic = constants::INVSQRT_A_MINUS_D;

let mut e = state.e;
let mut g = state.g;
let mut h = state.h;
let negcheck1 = (&state.eg * &Zinv).is_negative();

let minus_e = -&e;
let f_times_sqrta = &state.f * &constants::SQRT_M1;
let mut e = state.e;
let mut g = state.g;
let mut h = state.h;

e.conditional_assign(&state.g, negcheck1);
g.conditional_assign(&minus_e, negcheck1);
h.conditional_assign(&f_times_sqrta, negcheck1);
let minus_e = -&e;
let f_times_sqrta = &state.f * &constants::SQRT_M1;

magic.conditional_assign(&constants::SQRT_M1, negcheck1);
e.conditional_assign(&state.g, negcheck1);
g.conditional_assign(&minus_e, negcheck1);
h.conditional_assign(&f_times_sqrta, negcheck1);

let negcheck2 = (&(&h * &e) * &Zinv).is_negative();
magic.conditional_assign(&constants::SQRT_M1, negcheck1);

g.conditional_negate(negcheck2);
let negcheck2 = (&(&h * &e) * &Zinv).is_negative();

let mut s = &(&h - &g) * &(&magic * &(&g * &Tinv));
g.conditional_negate(negcheck2);

let s_is_negative = s.is_negative();
s.conditional_negate(s_is_negative);
let mut s = &(&h - &g) * &(&magic * &(&g * &Tinv));

CompressedRistretto(s.to_bytes())
})
.collect()
let s_is_negative = s.is_negative();
s.conditional_negate(s_is_negative);

CompressedRistretto(s.to_bytes())
}

/// Return the coset self + E\[4\], for debugging.
Expand Down Expand Up @@ -1156,6 +1168,42 @@ impl RistrettoBasepointTable {
}
}

#[derive(Copy, Clone, Debug)]
struct BatchCompressState {
e: FieldElement,
f: FieldElement,
g: FieldElement,
h: FieldElement,
eg: FieldElement,
fh: FieldElement,
}

impl BatchCompressState {
fn efgh(&self) -> FieldElement {
&self.eg * &self.fh
}
}

impl<'a> From<&'a RistrettoPoint> for BatchCompressState {
#[rustfmt::skip] // keep alignment of explanatory comments
fn from(P: &'a RistrettoPoint) -> BatchCompressState {
let XX = P.0.X.square();
let YY = P.0.Y.square();
let ZZ = P.0.Z.square();
let dTT = &P.0.T.square() * &constants::EDWARDS_D;

let e = &P.0.X * &(&P.0.Y + &P.0.Y); // = 2*X*Y
let f = &ZZ + &dTT; // = Z^2 + d*T^2
let g = &YY + &XX; // = Y^2 - a*X^2
let h = &ZZ - &dTT; // = Z^2 - d*T^2

let eg = &e * &g;
let fh = &f * &h;

BatchCompressState{ e, f, g, h, eg, fh }
}
}

// ------------------------------------------------------------------------
// Constant-time conditional selection
// ------------------------------------------------------------------------
Expand Down Expand Up @@ -1862,7 +1910,7 @@ mod test {
.collect();
points[500] = <RistrettoPoint as Group>::identity();

let compressed = RistrettoPoint::double_and_compress_batch(&points);
let compressed = RistrettoPoint::double_and_compress_batch_alloc(&points);

for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
assert_eq!(*P2_compressed, (P + P).compress());
Expand Down Expand Up @@ -1893,7 +1941,7 @@ mod test {
];

let multiplied_points: [_; 3] =
core::array::from_fn(|i| scalars[i].div_by_2() * points[i]);
array::from_fn(|i| scalars[i].div_by_2() * points[i]);
let compressed = RistrettoPoint::double_and_compress_batch(&multiplied_points);

for ((s, P), P2_compressed) in scalars.iter().zip(points).zip(compressed) {
Expand Down