Skip to content

Commit 215e601

Browse files
committed
Implement BatchNormalize for EdwardsPoint
1 parent f283f78 commit 215e601

File tree

2 files changed

+181
-4
lines changed

2 files changed

+181
-4
lines changed

ed448-goldilocks/src/edwards/extended.rs

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ use crate::edwards::affine::PointBytes;
1010
use crate::field::FieldElement;
1111
use crate::*;
1212
use elliptic_curve::{
13-
CurveGroup, Error,
13+
BatchNormalize, CurveGroup, Error,
1414
array::Array,
1515
group::{Group, GroupEncoding, cofactor::CofactorGroup, prime::PrimeGroup},
16-
ops::LinearCombination,
16+
ops::{BatchInvert, LinearCombination},
1717
point::NonIdentity,
1818
};
1919
use hash2curve::ExpandMsgXof;
@@ -296,6 +296,14 @@ impl CurveGroup for EdwardsPoint {
296296
fn to_affine(&self) -> AffinePoint {
297297
self.to_affine()
298298
}
299+
300+
#[cfg(feature = "alloc")]
301+
#[inline]
302+
fn batch_normalize(projective: &[Self], affine: &mut [Self::AffineRepr]) {
303+
assert_eq!(projective.len(), affine.len());
304+
let mut zs = alloc::vec![FieldElement::ONE; projective.len()];
305+
batch_normalize_generic(projective, zs.as_mut_slice(), affine);
306+
}
299307
}
300308

301309
impl EdwardsPoint {
@@ -714,11 +722,76 @@ impl<'de> serdect::serde::Deserialize<'de> for EdwardsPoint {
714722

715723
impl elliptic_curve::zeroize::DefaultIsZeroes for EdwardsPoint {}
716724

725+
impl<const N: usize> BatchNormalize<[EdwardsPoint; N]> for EdwardsPoint {
726+
type Output = [<Self as CurveGroup>::AffineRepr; N];
727+
728+
#[inline]
729+
fn batch_normalize(points: &[Self; N]) -> [<Self as CurveGroup>::AffineRepr; N] {
730+
let zs = [FieldElement::ONE; N];
731+
let mut affine_points = [AffinePoint::IDENTITY; N];
732+
batch_normalize_generic(points, zs, &mut affine_points);
733+
affine_points
734+
}
735+
}
736+
737+
#[cfg(feature = "alloc")]
738+
impl BatchNormalize<[EdwardsPoint]> for EdwardsPoint {
739+
type Output = Vec<<Self as CurveGroup>::AffineRepr>;
740+
741+
#[inline]
742+
fn batch_normalize(points: &[Self]) -> Vec<<Self as CurveGroup>::AffineRepr> {
743+
use alloc::vec;
744+
745+
let mut zs = vec![FieldElement::ONE; points.len()];
746+
let mut affine_points = vec![AffinePoint::IDENTITY; points.len()];
747+
batch_normalize_generic(points, zs.as_mut_slice(), &mut affine_points);
748+
affine_points
749+
}
750+
}
751+
752+
/// Generic implementation of batch normalization.
753+
fn batch_normalize_generic<P, Z, I, O>(points: &P, mut zs: Z, out: &mut O)
754+
where
755+
FieldElement: BatchInvert<Z, Output = CtOption<I>>,
756+
P: AsRef<[EdwardsPoint]> + ?Sized,
757+
Z: AsMut<[FieldElement]>,
758+
I: AsRef<[FieldElement]>,
759+
O: AsMut<[AffinePoint]> + ?Sized,
760+
{
761+
let points = points.as_ref();
762+
let out = out.as_mut();
763+
764+
for (i, point) in points.iter().enumerate() {
765+
// Even a single zero value will fail inversion for the entire batch.
766+
// Put a dummy value (above `FieldElement::ONE`) so inversion succeeds
767+
// and treat that case specially later-on.
768+
zs.as_mut()[i].conditional_assign(&point.Z, !point.Z.ct_eq(&FieldElement::ZERO));
769+
}
770+
771+
// This is safe to unwrap since we assured that all elements are non-zero
772+
let zs_inverses = <FieldElement as BatchInvert<Z>>::batch_invert(zs)
773+
.expect("all elements should be non-zero");
774+
775+
for i in 0..out.len() {
776+
// If the `z` coordinate is non-zero, we can use it to invert;
777+
// otherwise it defaults to the `IDENTITY` value.
778+
out[i] = AffinePoint::conditional_select(
779+
&AffinePoint {
780+
x: points[i].X * zs_inverses.as_ref()[i],
781+
y: points[i].Y * zs_inverses.as_ref()[i],
782+
},
783+
&AffinePoint::IDENTITY,
784+
points[i].Z.ct_eq(&FieldElement::ZERO),
785+
);
786+
}
787+
}
788+
717789
#[cfg(test)]
718790
mod tests {
719791
use super::*;
720792
use elliptic_curve::Field;
721793
use hex_literal::hex;
794+
use rand_core::OsRng;
722795

723796
fn hex_to_field(hex: &'static str) -> FieldElement {
724797
assert_eq!(hex.len(), 56 * 2);
@@ -1013,4 +1086,33 @@ mod tests {
10131086

10141087
assert_eq!(computed_commitment, expected_commitment);
10151088
}
1089+
1090+
#[test]
1091+
fn batch_normalize() {
1092+
let points: [EdwardsPoint; 2] = [
1093+
EdwardsPoint::try_from_rng(&mut OsRng).unwrap(),
1094+
EdwardsPoint::try_from_rng(&mut OsRng).unwrap(),
1095+
];
1096+
1097+
let affine_points = <EdwardsPoint as BatchNormalize<_>>::batch_normalize(&points);
1098+
1099+
for (point, affine_point) in points.into_iter().zip(affine_points) {
1100+
assert_eq!(affine_point, point.to_affine());
1101+
}
1102+
}
1103+
1104+
#[test]
1105+
#[cfg(feature = "alloc")]
1106+
fn batch_normalize_alloc() {
1107+
let points = alloc::vec![
1108+
EdwardsPoint::try_from_rng(&mut OsRng).unwrap(),
1109+
EdwardsPoint::try_from_rng(&mut OsRng).unwrap(),
1110+
];
1111+
1112+
let affine_points = <EdwardsPoint as BatchNormalize<_>>::batch_normalize(points.as_slice());
1113+
1114+
for (point, affine_point) in points.into_iter().zip(affine_points) {
1115+
assert_eq!(affine_point, point.to_affine());
1116+
}
1117+
}
10161118
}

ed448-goldilocks/src/field/element.rs

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,29 @@
11
use core::fmt::{self, Debug, Display, Formatter, LowerHex, UpperHex};
2+
use core::iter::{Product, Sum};
23
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
34

4-
use super::ConstMontyType;
5+
use super::{ConstMontyType, MODULUS};
56
use crate::{
67
AffinePoint, Decaf448, DecafPoint, Ed448, EdwardsPoint,
78
curve::twedwards::extended::ExtendedPoint as TwistedExtendedPoint,
89
};
910
use elliptic_curve::{
11+
Field,
1012
array::Array,
1113
bigint::{
1214
Integer, NonZero, U448, U704,
1315
consts::{U56, U84, U88},
16+
modular::ConstMontyParams,
1417
},
1518
group::cofactor::CofactorGroup,
1619
zeroize::DefaultIsZeroes,
1720
};
1821
use hash2curve::{FromOkm, MapToCurve};
19-
use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq};
22+
use rand_core::TryRngCore;
23+
use subtle::{
24+
Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq, ConstantTimeLess,
25+
CtOption,
26+
};
2027

2128
#[derive(Clone, Copy, Default)]
2229
pub struct FieldElement(pub(crate) ConstMontyType);
@@ -225,6 +232,68 @@ impl MapToCurve for Decaf448 {
225232
}
226233
}
227234

235+
impl Sum for FieldElement {
236+
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
237+
iter.reduce(Add::add).unwrap_or(Self::ZERO)
238+
}
239+
}
240+
241+
impl<'a> Sum<&'a FieldElement> for FieldElement {
242+
fn sum<I: Iterator<Item = &'a FieldElement>>(iter: I) -> Self {
243+
iter.copied().sum()
244+
}
245+
}
246+
247+
impl Product for FieldElement {
248+
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
249+
iter.reduce(Mul::mul).unwrap_or(Self::ONE)
250+
}
251+
}
252+
253+
impl<'a> Product<&'a FieldElement> for FieldElement {
254+
fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
255+
iter.copied().product()
256+
}
257+
}
258+
259+
impl Field for FieldElement {
260+
const ZERO: Self = Self::ZERO;
261+
const ONE: Self = Self::ONE;
262+
263+
fn try_from_rng<R: TryRngCore + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
264+
let mut bytes = [0; 56];
265+
266+
loop {
267+
rng.try_fill_bytes(&mut bytes)?;
268+
if let Some(fe) = Self::from_repr(&bytes).into() {
269+
return Ok(fe);
270+
}
271+
}
272+
}
273+
274+
fn square(&self) -> Self {
275+
self.square()
276+
}
277+
278+
fn double(&self) -> Self {
279+
self.double()
280+
}
281+
282+
fn invert(&self) -> CtOption<Self> {
283+
CtOption::from(self.0.invert()).map(Self)
284+
}
285+
286+
fn sqrt(&self) -> CtOption<Self> {
287+
let sqrt = self.sqrt();
288+
CtOption::new(sqrt, sqrt.square().ct_eq(self))
289+
}
290+
291+
fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
292+
let (result, is_square) = Self::sqrt_ratio(num, div);
293+
(is_square, result)
294+
}
295+
}
296+
228297
impl FieldElement {
229298
pub const A_PLUS_TWO_OVER_FOUR: Self = Self(ConstMontyType::new(&U448::from_be_hex(
230299
"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000098aa",
@@ -316,6 +385,12 @@ impl FieldElement {
316385
Self(ConstMontyType::new(&U448::from_le_slice(bytes)))
317386
}
318387

388+
pub fn from_repr(bytes: &[u8; 56]) -> CtOption<Self> {
389+
let integer = U448::from_le_slice(bytes);
390+
let is_some = integer.ct_lt(MODULUS::PARAMS.modulus());
391+
CtOption::new(Self(ConstMontyType::new(&integer)), is_some)
392+
}
393+
319394
pub fn double(&self) -> Self {
320395
Self(self.0.double())
321396
}

0 commit comments

Comments
 (0)