Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion curve25519-dalek/benches/dalek_benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ mod multiscalar_benches {
// rerandomize the scalars for every call just in case.
b.iter_batched(
|| construct_scalars(size),
|scalars| EdwardsPoint::multiscalar_mul(&scalars, &points),
|scalars| EdwardsPoint::multiscalar_mul_alloc(points.iter().zip(scalars)),
BatchSize::SmallInput,
);
},
Expand Down
41 changes: 30 additions & 11 deletions curve25519-dalek/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

use crate::EdwardsPoint;
use crate::Scalar;
use crate::traits::MultiscalarMul;

pub mod serial;

Expand Down Expand Up @@ -191,30 +192,48 @@ impl VartimePrecomputedStraus {
}
}

/// Straus multiscalar multiplication over a fixed-size array of
/// `(point, scalar)` pairs.
///
/// Selects the implementation with `get_selected_backend()` and forwards the
/// whole `points_and_scalars` array to that backend's
/// `Straus::multiscalar_mul`. The input length `N` is a const generic and the
/// function is not gated on the `alloc` feature.
///
/// Returns the `EdwardsPoint` produced by the selected backend.
#[allow(missing_docs)]
pub fn straus_multiscalar_mul<const N: usize>(
    points_and_scalars: &[(EdwardsPoint, Scalar); N],
) -> EdwardsPoint {
    match get_selected_backend() {
        #[cfg(curve25519_dalek_backend = "simd")]
        BackendKind::Avx2 => {
            vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul(points_and_scalars)
        }
        #[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
        BackendKind::Avx512 => {
            // This arm is only compiled for the nightly AVX-512 backend, so a
            // wrong argument list here goes unnoticed in ordinary builds. The
            // only binding in scope is `points_and_scalars`; pass it exactly
            // as the other arms do.
            vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul(
                points_and_scalars,
            )
        }
        BackendKind::Serial => {
            serial::scalar_mul::straus::Straus::multiscalar_mul(points_and_scalars)
        }
    }
}

#[allow(missing_docs)]
#[cfg(feature = "alloc")]
pub fn straus_multiscalar_mul<I, J>(scalars: I, points: J) -> EdwardsPoint
pub fn straus_multiscalar_mul_alloc<I, P, S>(points_and_scalars: I) -> EdwardsPoint
where
I: IntoIterator,
I::Item: core::borrow::Borrow<Scalar>,
J: IntoIterator,
J::Item: core::borrow::Borrow<EdwardsPoint>,
I: IntoIterator<Item = (P, S)>,
P: core::borrow::Borrow<EdwardsPoint>,
S: core::borrow::Borrow<Scalar>,
{
use crate::traits::MultiscalarMul;

match get_selected_backend() {
#[cfg(curve25519_dalek_backend = "simd")]
BackendKind::Avx2 => {
vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul::<I, J>(scalars, points)
vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul_alloc(points_and_scalars)
}
#[cfg(all(curve25519_dalek_backend = "unstable_avx512", nightly))]
BackendKind::Avx512 => {
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul::<I, J>(
scalars, points,
vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul_alloc(
points_and_scalars,
)
}
BackendKind::Serial => {
serial::scalar_mul::straus::Straus::multiscalar_mul::<I, J>(scalars, points)
serial::scalar_mul::straus::Straus::multiscalar_mul_alloc(points_and_scalars)
}
}
}
Expand Down
1 change: 0 additions & 1 deletion curve25519-dalek/src/backend/serial/scalar_mul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ pub mod variable_base;
#[allow(missing_docs)]
pub mod vartime_double_base;

#[cfg(feature = "alloc")]
pub mod straus;

#[cfg(feature = "alloc")]
Expand Down
187 changes: 106 additions & 81 deletions curve25519-dalek/src/backend/serial/scalar_mul/straus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@

#![allow(non_snake_case)]

#[cfg(feature = "alloc")]
use alloc::vec::Vec;

#[cfg(feature = "alloc")]
use core::borrow::Borrow;
use core::cmp::Ordering;

use crate::backend::serial::curve_models::ProjectiveNielsPoint;
use crate::edwards::EdwardsPoint;
use crate::scalar::Scalar;
use crate::traits::Identity;
use crate::traits::MultiscalarMul;
#[cfg(feature = "alloc")]
use crate::traits::VartimeMultiscalarMul;
use crate::window::LookupTable;

/// Perform multiscalar multiplication by the interleaved window
/// method, also known as Straus' method (since it was apparently
Expand Down Expand Up @@ -49,101 +54,120 @@ pub struct Straus {}
impl MultiscalarMul for Straus {
type Point = EdwardsPoint;

/// Constant-time Straus using a fixed window of size \\(4\\).
///
/// Our goal is to compute
/// \\[
/// Q = s_1 P_1 + \cdots + s_n P_n.
/// \\]
///
/// For each point \\( P_i \\), precompute a lookup table of
/// \\[
/// P_i, 2P_i, 3P_i, 4P_i, 5P_i, 6P_i, 7P_i, 8P_i.
/// \\]
///
/// For each scalar \\( s_i \\), compute its radix-\\(2^4\\)
/// signed digits \\( s_{i,j} \\), i.e.,
/// \\[
/// s_i = s_{i,0} + s_{i,1} 16^1 + ... + s_{i,63} 16^{63},
/// \\]
/// with \\( -8 \leq s_{i,j} < 8 \\). Since \\( 0 \leq |s_{i,j}|
/// \leq 8 \\), we can retrieve \\( s_{i,j} P_i \\) from the
/// lookup table with a conditional negation: using signed
/// digits halves the required table size.
///
/// Then as in the single-base fixed window case, we have
/// \\[
/// \begin{aligned}
/// s_i P_i &= P_i (s_{i,0} + s_{i,1} 16^1 + \cdots + s_{i,63} 16^{63}) \\\\
/// s_i P_i &= P_i s_{i,0} + P_i s_{i,1} 16^1 + \cdots + P_i s_{i,63} 16^{63} \\\\
/// s_i P_i &= P_i s_{i,0} + 16(P_i s_{i,1} + 16( \cdots +16P_i s_{i,63})\cdots )
/// \end{aligned}
/// \\]
/// so each \\( s_i P_i \\) can be computed by alternately adding
/// a precomputed multiple \\( P_i s_{i,j} \\) of \\( P_i \\) and
/// repeatedly doubling.
///
/// Now consider the two-dimensional sum
/// \\[
/// \begin{aligned}
/// s\_1 P\_1 &=& P\_1 s\_{1,0} &+& 16 (P\_1 s\_{1,1} &+& 16 ( \cdots &+& 16 P\_1 s\_{1,63}&) \cdots ) \\\\
/// + & & + & & + & & & & + & \\\\
/// s\_2 P\_2 &=& P\_2 s\_{2,0} &+& 16 (P\_2 s\_{2,1} &+& 16 ( \cdots &+& 16 P\_2 s\_{2,63}&) \cdots ) \\\\
/// + & & + & & + & & & & + & \\\\
/// \vdots & & \vdots & & \vdots & & & & \vdots & \\\\
/// + & & + & & + & & & & + & \\\\
/// s\_n P\_n &=& P\_n s\_{n,0} &+& 16 (P\_n s\_{n,1} &+& 16 ( \cdots &+& 16 P\_n s\_{n,63}&) \cdots )
/// \end{aligned}
/// \\]
/// The sum of the left-hand column is the result \\( Q \\); by
/// computing the two-dimensional sum on the right column-wise,
/// top-to-bottom, then right-to-left, we need to multiply by \\(
/// 16\\) only once per column, sharing the doublings across all
/// of the input points.
fn multiscalar_mul<I, J>(scalars: I, points: J) -> EdwardsPoint
where
I: IntoIterator,
I::Item: Borrow<Scalar>,
J: IntoIterator,
J::Item: Borrow<EdwardsPoint>,
{
use crate::backend::serial::curve_models::ProjectiveNielsPoint;
use crate::traits::Identity;
use crate::window::LookupTable;
fn multiscalar_mul<const N: usize>(
points_and_scalars: &[(EdwardsPoint, Scalar); N],
) -> EdwardsPoint {
let lookup_tables: [_; N] = core::array::from_fn(|index| {
LookupTable::<ProjectiveNielsPoint>::from(&points_and_scalars[index].0)
});

let lookup_tables: Vec<_> = points
.into_iter()
.map(|point| LookupTable::<ProjectiveNielsPoint>::from(point.borrow()))
.collect();
let scalar_digits: [_; N] =
core::array::from_fn(|index| points_and_scalars[index].1.as_radix_16());

multiscalar_mul(&scalar_digits, &lookup_tables)
}

#[cfg(feature = "alloc")]
fn multiscalar_mul_alloc<I, P, S>(points_and_scalars: I) -> EdwardsPoint
where
I: IntoIterator<Item = (P, S)>,
P: Borrow<EdwardsPoint>,
S: Borrow<Scalar>,
{
// This puts the scalar digits into a heap-allocated Vec.
// To ensure that these are erased, pass ownership of the Vec into a
// Zeroizing wrapper.
#[cfg_attr(not(feature = "zeroize"), allow(unused_mut))]
let mut scalar_digits: Vec<_> = scalars
let (lookup_tables, mut scalar_digits): (Vec<_>, Vec<_>) = points_and_scalars
.into_iter()
.map(|s| s.borrow().as_radix_16())
.collect();
.map(|(p, s)| {
(
LookupTable::<ProjectiveNielsPoint>::from(p.borrow()),
s.borrow().as_radix_16(),
)
})
.unzip();

let mut Q = EdwardsPoint::identity();
for j in (0..64).rev() {
Q = Q.mul_by_pow_2(4);
let it = scalar_digits.iter().zip(lookup_tables.iter());
for (s_i, lookup_table_i) in it {
// R_i = s_{i,j} * P_i
let R_i = lookup_table_i.select(s_i[j]);
// Q = Q + R_i
Q = (&Q + &R_i).as_extended();
}
}
let Q = multiscalar_mul(&scalar_digits, &lookup_tables);

#[cfg(feature = "zeroize")]
zeroize::Zeroize::zeroize(&mut scalar_digits);
zeroize::Zeroize::zeroize(&mut scalar_digits.iter_mut());

Q
}
}

/// Shared core of constant-time Straus with a fixed window of size \\(4\\).
///
/// `scalar_digits[i]` holds the 64 radix-\\(2^4\\) digits of scalar
/// \\( s_i \\) and `lookup_tables[i]` is the lookup table built for point
/// \\( P_i \\). The two slices are paired element by element with `zip`, so
/// any entries beyond the length of the shorter slice are ignored.
///
/// The value computed is
/// \\[
/// Q = s_1 P_1 + \cdots + s_n P_n.
/// \\]
///
/// The table for each point \\( P_i \\) contains the multiples
/// \\[
/// P_i, 2P_i, 3P_i, 4P_i, 5P_i, 6P_i, 7P_i, 8P_i.
/// \\]
///
/// Each scalar \\( s_i \\) is written in radix-\\(2^4\\) signed
/// digits \\( s_{i,j} \\), i.e.,
/// \\[
/// s_i = s_{i,0} + s_{i,1} 16^1 + ... + s_{i,63} 16^{63},
/// \\]
/// with \\( -8 \leq s_{i,j} < 8 \\). Because \\( 0 \leq |s_{i,j}|
/// \leq 8 \\), the term \\( s_{i,j} P_i \\) is obtained from the
/// lookup table with a conditional negation: signed digits halve
/// the table size that would otherwise be needed.
///
/// As in the single-base fixed window case,
/// \\[
/// \begin{aligned}
/// s_i P_i &= P_i (s_{i,0} + s_{i,1} 16^1 + \cdots + s_{i,63} 16^{63}) \\\\
/// s_i P_i &= P_i s_{i,0} + P_i s_{i,1} 16^1 + \cdots + P_i s_{i,63} 16^{63} \\\\
/// s_i P_i &= P_i s_{i,0} + 16(P_i s_{i,1} + 16( \cdots +16P_i s_{i,63})\cdots )
/// \end{aligned}
/// \\]
/// so a single \\( s_i P_i \\) is evaluated by alternating between
/// adding a precomputed multiple \\( P_i s_{i,j} \\) of \\( P_i \\)
/// and repeated doubling.
///
/// Stacking these expansions gives the two-dimensional sum
/// \\[
/// \begin{aligned}
/// s\_1 P\_1 &=& P\_1 s\_{1,0} &+& 16 (P\_1 s\_{1,1} &+& 16 ( \cdots &+& 16 P\_1 s\_{1,63}&) \cdots ) \\\\
/// + & & + & & + & & & & + & \\\\
/// s\_2 P\_2 &=& P\_2 s\_{2,0} &+& 16 (P\_2 s\_{2,1} &+& 16 ( \cdots &+& 16 P\_2 s\_{2,63}&) \cdots ) \\\\
/// + & & + & & + & & & & + & \\\\
/// \vdots & & \vdots & & \vdots & & & & \vdots & \\\\
/// + & & + & & + & & & & + & \\\\
/// s\_n P\_n &=& P\_n s\_{n,0} &+& 16 (P\_n s\_{n,1} &+& 16 ( \cdots &+& 16 P\_n s\_{n,63}&) \cdots )
/// \end{aligned}
/// \\]
/// The left-hand column sums to the result \\( Q \\). Evaluating the
/// right-hand side column by column (top-to-bottom within a column,
/// columns taken right-to-left) means the multiplication by \\(
/// 16\\) happens only once per column, so the doublings are shared
/// across all of the input points.
fn multiscalar_mul(
    scalar_digits: &[[i8; 64]],
    lookup_tables: &[LookupTable<ProjectiveNielsPoint>],
) -> EdwardsPoint {
    // Walk the digit columns from the most significant (63) down to 0.
    (0usize..64)
        .rev()
        .fold(EdwardsPoint::identity(), |accumulator, column| {
            // One multiplication by 16 per column, shared by every point.
            let shifted = accumulator.mul_by_pow_2(4);

            scalar_digits
                .iter()
                .zip(lookup_tables)
                .fold(shifted, |partial_sum, (digits, table)| {
                    // term = s_{i,column} * P_i
                    let term = table.select(digits[column]);
                    // partial_sum = partial_sum + term
                    (&partial_sum + &term).as_extended()
                })
        })
}

#[cfg(feature = "alloc")]
impl VartimeMultiscalarMul for Straus {
type Point = EdwardsPoint;

Expand All @@ -167,6 +191,7 @@ impl VartimeMultiscalarMul for Straus {
};
use crate::traits::Identity;
use crate::window::NafLookupTable5;
use core::cmp::Ordering;

let nafs: Vec<_> = scalars
.into_iter()
Expand Down
1 change: 0 additions & 1 deletion curve25519-dalek/src/backend/vector/scalar_mul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ pub mod variable_base;
pub mod vartime_double_base;

#[allow(missing_docs)]
#[cfg(feature = "alloc")]
pub mod straus;

#[allow(missing_docs)]
Expand Down
Loading