Skip to content

Commit 6b929cf

Browse files
committed
Implement BatchNormalize for EdwardsPoint
1 parent 02b4363 commit 6b929cf

File tree

3 files changed

+184
-7
lines changed

3 files changed

+184
-7
lines changed

ed448-goldilocks/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ assert_eq!(public_key, EdwardsPoint::GENERATOR + EdwardsPoint::GENERATOR);
2929

3030
let secret_key = EdwardsScalar::try_from_rng(&mut OsRng).unwrap();
3131
let public_key = EdwardsPoint::GENERATOR * &secret_key;
32-
let compressed_public_key = public_key.compress();
32+
let compressed_public_key = public_key.to_affine().compress();
3333

3434
assert_eq!(compressed_public_key.to_bytes().len(), 57);
3535

@@ -38,12 +38,12 @@ let input = hex_literal::hex!("c8c6c8f584e0c25efdb6af5ad234583c56dedd7c33e0c8934
3838
let expected_scalar = EdwardsScalar::from_canonical_bytes(&input.into()).unwrap();
3939
assert_eq!(hashed_scalar, expected_scalar);
4040

41-
let hashed_point = Ed448::hash_from_bytes::<ExpandMsgXof<Shake256>>(&[b"test"], &[b"edwards448_XOF:SHAKE256_ELL2_RO_"]).unwrap();
41+
let hashed_point = Ed448::hash_from_bytes::<ExpandMsgXof<Shake256>>(&[b"test"], &[b"edwards448_XOF:SHAKE256_ELL2_RO_"]).unwrap().to_affine();
4242
let expected = hex_literal::hex!("d15c4427b5c5611a53593c2be611fd3635b90272d331c7e6721ad3735e95dd8b9821f8e4e27501ce01aa3c913114052dce2e91e8ca050f4980");
4343
let expected_point = CompressedEdwardsY(expected).decompress().unwrap();
4444
assert_eq!(hashed_point, expected_point);
4545

46-
let hashed_point = EdwardsPoint::hash_with_defaults(b"test");
46+
let hashed_point = EdwardsPoint::hash_with_defaults(b"test").to_affine();
4747
assert_eq!(hashed_point, expected_point);
4848
```
4949

ed448-goldilocks/src/edwards/extended.rs

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ use crate::edwards::affine::PointBytes;
77
use crate::field::FieldElement;
88
use crate::*;
99
use elliptic_curve::{
10-
CurveGroup, Error,
10+
BatchNormalize, CurveGroup, Error,
1111
array::Array,
1212
group::{Group, GroupEncoding, cofactor::CofactorGroup, prime::PrimeGroup},
13-
ops::LinearCombination,
13+
ops::{BatchInvert, LinearCombination},
1414
point::NonIdentity,
1515
};
1616
use hash2curve::ExpandMsgXof;
@@ -290,6 +290,14 @@ impl CurveGroup for EdwardsPoint {
290290
fn to_affine(&self) -> AffinePoint {
291291
self.to_affine()
292292
}
293+
294+
#[cfg(feature = "alloc")]
295+
#[inline]
296+
fn batch_normalize(projective: &[Self], affine: &mut [Self::AffineRepr]) {
297+
assert_eq!(projective.len(), affine.len());
298+
let mut zs = alloc::vec![FieldElement::ONE; projective.len()];
299+
batch_normalize_generic(projective, zs.as_mut_slice(), affine);
300+
}
293301
}
294302

295303
impl EdwardsPoint {
@@ -671,11 +679,76 @@ impl<'de> serdect::serde::Deserialize<'de> for EdwardsPoint {
671679

672680
impl elliptic_curve::zeroize::DefaultIsZeroes for EdwardsPoint {}
673681

682+
impl<const N: usize> BatchNormalize<[EdwardsPoint; N]> for EdwardsPoint {
683+
type Output = [<Self as CurveGroup>::AffineRepr; N];
684+
685+
#[inline]
686+
fn batch_normalize(points: &[Self; N]) -> [<Self as CurveGroup>::AffineRepr; N] {
687+
let zs = [FieldElement::ONE; N];
688+
let mut affine_points = [AffinePoint::IDENTITY; N];
689+
batch_normalize_generic(points, zs, &mut affine_points);
690+
affine_points
691+
}
692+
}
693+
694+
#[cfg(feature = "alloc")]
695+
impl BatchNormalize<[EdwardsPoint]> for EdwardsPoint {
696+
type Output = Vec<<Self as CurveGroup>::AffineRepr>;
697+
698+
#[inline]
699+
fn batch_normalize(points: &[Self]) -> Vec<<Self as CurveGroup>::AffineRepr> {
700+
use alloc::vec;
701+
702+
let mut zs = vec![FieldElement::ONE; points.len()];
703+
let mut affine_points = vec![AffinePoint::IDENTITY; points.len()];
704+
batch_normalize_generic(points, zs.as_mut_slice(), &mut affine_points);
705+
affine_points
706+
}
707+
}
708+
709+
/// Generic implementation of batch normalization.
710+
fn batch_normalize_generic<P, Z, I, O>(points: &P, mut zs: Z, out: &mut O)
711+
where
712+
FieldElement: BatchInvert<Z, Output = CtOption<I>>,
713+
P: AsRef<[EdwardsPoint]> + ?Sized,
714+
Z: AsMut<[FieldElement]>,
715+
I: AsRef<[FieldElement]>,
716+
O: AsMut<[AffinePoint]> + ?Sized,
717+
{
718+
let points = points.as_ref();
719+
let out = out.as_mut();
720+
721+
for (i, point) in points.iter().enumerate() {
722+
// Even a single zero value will fail inversion for the entire batch.
723+
// Put a dummy value (above `FieldElement::ONE`) so inversion succeeds
724+
// and treat that case specially later-on.
725+
zs.as_mut()[i].conditional_assign(&point.Z, !point.Z.ct_eq(&FieldElement::ZERO));
726+
}
727+
728+
// This is safe to unwrap since we assured that all elements are non-zero
729+
let zs_inverses = <FieldElement as BatchInvert<Z>>::batch_invert(zs)
730+
.expect("all elements should be non-zero");
731+
732+
for i in 0..out.len() {
733+
// If the `z` coordinate is non-zero, we can use it to invert;
734+
// otherwise it defaults to the `IDENTITY` value.
735+
out[i] = AffinePoint::conditional_select(
736+
&AffinePoint {
737+
x: points[i].X * zs_inverses.as_ref()[i],
738+
y: points[i].Y * zs_inverses.as_ref()[i],
739+
},
740+
&AffinePoint::IDENTITY,
741+
points[i].Z.ct_eq(&FieldElement::ZERO),
742+
);
743+
}
744+
}
745+
674746
#[cfg(test)]
675747
mod tests {
676748
use super::*;
677749
use elliptic_curve::Field;
678750
use hex_literal::hex;
751+
use rand_core::OsRng;
679752

680753
fn hex_to_field(hex: &'static str) -> FieldElement {
681754
assert_eq!(hex.len(), 56 * 2);
@@ -970,4 +1043,33 @@ mod tests {
9701043

9711044
assert_eq!(computed_commitment, expected_commitment);
9721045
}
1046+
1047+
#[test]
1048+
fn batch_normalize() {
1049+
let points: [EdwardsPoint; 2] = [
1050+
EdwardsPoint::try_from_rng(&mut OsRng).unwrap(),
1051+
EdwardsPoint::try_from_rng(&mut OsRng).unwrap(),
1052+
];
1053+
1054+
let affine_points = <EdwardsPoint as BatchNormalize<_>>::batch_normalize(&points);
1055+
1056+
for (point, affine_point) in points.into_iter().zip(affine_points) {
1057+
assert_eq!(affine_point, point.to_affine());
1058+
}
1059+
}
1060+
1061+
#[test]
1062+
#[cfg(feature = "alloc")]
1063+
fn batch_normalize_alloc() {
1064+
let points = alloc::vec![
1065+
EdwardsPoint::try_from_rng(&mut OsRng).unwrap(),
1066+
EdwardsPoint::try_from_rng(&mut OsRng).unwrap(),
1067+
];
1068+
1069+
let affine_points = <EdwardsPoint as BatchNormalize<_>>::batch_normalize(points.as_slice());
1070+
1071+
for (point, affine_point) in points.into_iter().zip(affine_points) {
1072+
assert_eq!(affine_point, point.to_affine());
1073+
}
1074+
}
9731075
}

ed448-goldilocks/src/field/element.rs

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,29 @@
11
use core::fmt::{self, Debug, Display, Formatter, LowerHex, UpperHex};
2+
use core::iter::{Product, Sum};
23
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
34

4-
use super::ConstMontyType;
5+
use super::{ConstMontyType, MODULUS};
56
use crate::{
67
AffinePoint, Decaf448, DecafPoint, Ed448, EdwardsPoint,
78
curve::twedwards::extended::ExtendedPoint as TwistedExtendedPoint,
89
};
910
use elliptic_curve::{
11+
Field,
1012
array::Array,
1113
bigint::{
1214
Integer, NonZero, U448, U704,
1315
consts::{U56, U84, U88},
16+
modular::ConstMontyParams,
1417
},
1518
group::cofactor::CofactorGroup,
1619
zeroize::DefaultIsZeroes,
1720
};
1821
use hash2curve::{FromOkm, MapToCurve};
19-
use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq};
22+
use rand_core::TryRngCore;
23+
use subtle::{
24+
Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq, ConstantTimeLess,
25+
CtOption,
26+
};
2027

2128
#[derive(Clone, Copy, Default)]
2229
pub struct FieldElement(pub(crate) ConstMontyType);
@@ -225,6 +232,68 @@ impl MapToCurve for Decaf448 {
225232
}
226233
}
227234

235+
impl Sum for FieldElement {
236+
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
237+
iter.reduce(Add::add).unwrap_or(Self::ZERO)
238+
}
239+
}
240+
241+
impl<'a> Sum<&'a FieldElement> for FieldElement {
242+
fn sum<I: Iterator<Item = &'a FieldElement>>(iter: I) -> Self {
243+
iter.copied().sum()
244+
}
245+
}
246+
247+
impl Product for FieldElement {
248+
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
249+
iter.reduce(Mul::mul).unwrap_or(Self::ONE)
250+
}
251+
}
252+
253+
impl<'a> Product<&'a FieldElement> for FieldElement {
254+
fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
255+
iter.copied().product()
256+
}
257+
}
258+
259+
impl Field for FieldElement {
260+
const ZERO: Self = Self::ZERO;
261+
const ONE: Self = Self::ONE;
262+
263+
fn try_from_rng<R: TryRngCore + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
264+
let mut bytes = [0; 56];
265+
266+
loop {
267+
rng.try_fill_bytes(&mut bytes)?;
268+
if let Some(fe) = Self::from_repr(&bytes).into() {
269+
return Ok(fe);
270+
}
271+
}
272+
}
273+
274+
fn square(&self) -> Self {
275+
self.square()
276+
}
277+
278+
fn double(&self) -> Self {
279+
self.double()
280+
}
281+
282+
fn invert(&self) -> CtOption<Self> {
283+
CtOption::from(self.0.invert()).map(Self)
284+
}
285+
286+
fn sqrt(&self) -> CtOption<Self> {
287+
let sqrt = self.sqrt();
288+
CtOption::new(sqrt, sqrt.square().ct_eq(self))
289+
}
290+
291+
fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
292+
let (result, is_square) = Self::sqrt_ratio(num, div);
293+
(is_square, result)
294+
}
295+
}
296+
228297
impl FieldElement {
229298
pub const A_PLUS_TWO_OVER_FOUR: Self = Self(ConstMontyType::new(&U448::from_be_hex(
230299
"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000098aa",
@@ -315,6 +384,12 @@ impl FieldElement {
315384
Self(ConstMontyType::new(&U448::from_le_slice(bytes)))
316385
}
317386

387+
pub fn from_repr(bytes: &[u8; 56]) -> CtOption<Self> {
388+
let integer = U448::from_le_slice(bytes);
389+
let is_some = integer.ct_lt(MODULUS::PARAMS.modulus());
390+
CtOption::new(Self(ConstMontyType::new(&integer)), is_some)
391+
}
392+
318393
pub fn double(&self) -> Self {
319394
Self(self.0.double())
320395
}

0 commit comments

Comments
 (0)