Skip to content

Commit 5f27bb4

Browse files
authored
primeorder: add optimized implementation for LinearCombination (#1360)
1 parent 7f9b341 commit 5f27bb4

File tree

2 files changed

+125
-45
lines changed

2 files changed

+125
-45
lines changed

p256/tests/projective.rs

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,18 @@
33
#![cfg(all(feature = "arithmetic", feature = "test-vectors"))]
44

55
use elliptic_curve::{
6-
BatchNormalize,
6+
BatchNormalize, Group,
7+
array::Array,
78
group::{GroupEncoding, ff::PrimeField},
8-
ops::ReduceNonZero,
9+
ops::{LinearCombination, Reduce, ReduceNonZero},
910
point::NonIdentity,
1011
sec1::{self, ToEncodedPoint},
1112
};
1213
use p256::{
1314
AffinePoint, FieldBytes, NonZeroScalar, ProjectivePoint, Scalar,
1415
test_vectors::group::{ADD_TEST_VECTORS, MUL_TEST_VECTORS},
1516
};
16-
use primeorder::{Double, test_projective_arithmetic};
17+
use primeorder::test_projective_arithmetic;
1718
use proptest::{prelude::any, prop_compose, proptest};
1819

1920
test_projective_arithmetic!(
@@ -36,6 +37,18 @@ prop_compose! {
3637
}
3738
}
3839

40+
prop_compose! {
41+
fn projective()(bytes in any::<[u8; 32]>()) -> ProjectivePoint {
42+
ProjectivePoint::mul_by_generator(&Scalar::reduce(&Array::from(bytes)))
43+
}
44+
}
45+
46+
prop_compose! {
47+
fn scalar()(bytes in any::<[u8; 32]>()) -> Scalar {
48+
Scalar::reduce(&Array::from(bytes))
49+
}
50+
}
51+
3952
// TODO: move to `primeorder::test_projective_arithmetic`.
4053
proptest! {
4154
#[test]
@@ -66,4 +79,29 @@ proptest! {
6679
assert_eq!(affine_point, point.to_affine());
6780
}
6881
}
82+
83+
#[test]
84+
fn lincomb(
85+
p1 in projective(),
86+
s1 in scalar(),
87+
p2 in projective(),
88+
s2 in scalar(),
89+
) {
90+
let reference = p1 * s1 + p2 * s2;
91+
let test = ProjectivePoint::lincomb(&[(p1, s1), (p2, s2)]);
92+
assert_eq!(reference, test);
93+
}
94+
95+
#[test]
96+
#[cfg(feature = "alloc")]
97+
fn lincomb_alloc(
98+
p1 in projective(),
99+
s1 in scalar(),
100+
p2 in projective(),
101+
s2 in scalar(),
102+
) {
103+
let reference = p1 * s1 + p2 * s2;
104+
let test = ProjectivePoint::lincomb(vec![(p1, s1), (p2, s2)].as_slice());
105+
assert_eq!(reference, test);
106+
}
69107
}

primeorder/src/projective.rs

Lines changed: 84 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
use crate::{AffinePoint, Field, PrimeCurveParams, point_arithmetic::PointArithmetic};
66
use core::{
7+
array,
78
borrow::Borrow,
89
iter::Sum,
910
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
@@ -12,7 +13,7 @@ use elliptic_curve::{
1213
BatchNormalize, CurveGroup, Error, FieldBytes, FieldBytesSize, PrimeField, PublicKey, Result,
1314
Scalar,
1415
array::ArraySize,
15-
bigint::ArrayEncoding,
16+
bigint::{ArrayEncoding, ByteArray},
1617
group::{
1718
Group, GroupEncoding,
1819
prime::{PrimeCurve, PrimeGroup},
@@ -109,46 +110,10 @@ where
109110
where
110111
Self: Double,
111112
{
112-
let k = Into::<C::Uint>::into(*k).to_le_byte_array();
113+
let mut k = Into::<C::Uint>::into(*k).to_le_byte_array();
114+
let mut pc = LookupTable::new(*self);
113115

114-
let mut pc = [Self::default(); 16];
115-
pc[0] = Self::IDENTITY;
116-
pc[1] = *self;
117-
118-
for i in 2..16 {
119-
pc[i] = if i % 2 == 0 {
120-
Double::double(&pc[i / 2])
121-
} else {
122-
pc[i - 1].add(self)
123-
};
124-
}
125-
126-
let mut q = Self::IDENTITY;
127-
let mut pos = (<Scalar<C> as PrimeField>::NUM_BITS.div_ceil(8) * 8) as usize - 4;
128-
129-
loop {
130-
let slot = (k[pos >> 3] >> (pos & 7)) & 0xf;
131-
132-
let mut t = ProjectivePoint::IDENTITY;
133-
134-
for i in 1..16 {
135-
t.conditional_assign(
136-
&pc[i],
137-
Choice::from(((slot as usize ^ i).wrapping_sub(1) >> 8) as u8 & 1),
138-
);
139-
}
140-
141-
q = q.add(&t);
142-
143-
if pos == 0 {
144-
break;
145-
}
146-
147-
q = Double::double(&Double::double(&Double::double(&Double::double(&q))));
148-
pos -= 4;
149-
}
150-
151-
q
116+
lincomb(array::from_mut(&mut k), array::from_mut(&mut pc))
152117
}
153118
}
154119

@@ -403,15 +368,92 @@ where
403368
C: PrimeCurveParams,
404369
FieldBytes<C>: Copy,
405370
{
406-
// TODO(tarcieri): optimized implementation
371+
#[cfg(feature = "alloc")]
372+
fn lincomb(points_and_scalars: &[(Self, Scalar<C>)]) -> Self {
373+
let (mut ks, mut pcs): (Vec<_>, Vec<_>) = points_and_scalars
374+
.iter()
375+
.map(|(point, scalar)| {
376+
(
377+
Into::<C::Uint>::into(*scalar).to_le_byte_array(),
378+
LookupTable::new(*point),
379+
)
380+
})
381+
.unzip();
382+
383+
lincomb::<C>(&mut ks, &mut pcs)
384+
}
407385
}
408386

409387
impl<C, const N: usize> LinearCombination<[(Self, Scalar<C>); N]> for ProjectivePoint<C>
410388
where
411389
C: PrimeCurveParams,
412390
FieldBytes<C>: Copy,
413391
{
414-
// TODO(tarcieri): optimized implementation
392+
fn lincomb(points_and_scalars: &[(Self, Scalar<C>); N]) -> Self {
393+
let mut ks: [_; N] = array::from_fn(|index| {
394+
Into::<C::Uint>::into(points_and_scalars[index].1).to_le_byte_array()
395+
});
396+
let mut pcs: [_; N] = array::from_fn(|index| LookupTable::new(points_and_scalars[index].0));
397+
398+
lincomb::<C>(&mut ks, &mut pcs)
399+
}
400+
}
401+
402+
fn lincomb<C: PrimeCurveParams>(
403+
ks: &mut [ByteArray<C::Uint>],
404+
pcs: &mut [LookupTable<C>],
405+
) -> ProjectivePoint<C> {
406+
let mut q = ProjectivePoint::IDENTITY;
407+
let mut pos = (<Scalar<C> as PrimeField>::NUM_BITS.div_ceil(8) * 8) as usize - 4;
408+
409+
loop {
410+
for (k, pc) in ks.iter().zip(pcs.iter()) {
411+
let slot = (k[pos >> 3] >> (pos & 7)) & 0xf;
412+
q = q.add(&pc.select(slot));
413+
}
414+
415+
if pos == 0 {
416+
break;
417+
}
418+
419+
q = Double::double(&Double::double(&Double::double(&Double::double(&q))));
420+
pos -= 4;
421+
}
422+
423+
q
424+
}
425+
426+
struct LookupTable<C: PrimeCurveParams>([ProjectivePoint<C>; 16]);
427+
428+
impl<C: PrimeCurveParams> LookupTable<C> {
429+
fn new(point: ProjectivePoint<C>) -> Self {
430+
let mut pc = [ProjectivePoint::default(); 16];
431+
pc[0] = ProjectivePoint::IDENTITY;
432+
pc[1] = point;
433+
434+
for i in 2..16 {
435+
pc[i] = if i % 2 == 0 {
436+
Double::double(&pc[i / 2])
437+
} else {
438+
pc[i - 1].add(point)
439+
};
440+
}
441+
442+
Self(pc)
443+
}
444+
445+
fn select(&self, slot: u8) -> ProjectivePoint<C> {
446+
let mut t = ProjectivePoint::IDENTITY;
447+
448+
for i in 1..16 {
449+
t.conditional_assign(
450+
&self.0[i],
451+
Choice::from(((slot as usize ^ i).wrapping_sub(1) >> 8) as u8 & 1),
452+
);
453+
}
454+
455+
t
456+
}
415457
}
416458

417459
impl<C> PrimeGroup for ProjectivePoint<C>

0 commit comments

Comments
 (0)