Skip to content

Commit 65462c8

Browse files
authored
Merge pull request #3086 from ljedrz/perf/faster_sum_of_products
[Perf] Speed up sum_of_products
2 parents: 7a44df8 + 45ef56e; commit: 65462c8

File tree

9 files changed: 19 additions (+19) and 28 deletions (-28)

9 files changed: 19 additions (+19) and 28 deletions (-28)

algorithms/src/crypto_hash/poseidon.rs

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -233,8 +233,9 @@ impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> {
233233
#[inline]
234234
fn apply_mds(&mut self) {
235235
let mut new_state = State::default();
236+
let curr_state: Vec<F> = self.state.iter().copied().collect::<Vec<_>>();
236237
new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| {
237-
*new_elem = F::sum_of_products(self.state.iter(), mds_row.iter());
238+
*new_elem = F::sum_of_products(&curr_state, mds_row);
238239
});
239240
self.state = new_state;
240241
}

console/algorithms/src/poseidon/helpers/sponge.rs

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -126,8 +126,9 @@ impl<E: Environment, const RATE: usize, const CAPACITY: usize> PoseidonSponge<E,
126126
#[inline]
127127
fn apply_mds(&mut self) {
128128
let mut new_state = State::default();
129+
let curr_state: Vec<<E as Environment>::Field> = self.state.iter().map(|e| *e.deref()).collect::<Vec<_>>();
129130
new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| {
130-
*new_elem = Field::new(E::Field::sum_of_products(self.state.iter().map(|e| e.deref()), mds_row.iter()));
131+
*new_elem = Field::new(E::Field::sum_of_products(curr_state.as_slice(), mds_row));
131132
});
132133
self.state = new_state;
133134
}

curves/src/bls12_377/tests.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -182,7 +182,7 @@ fn test_fr_sum_of_products() {
182182
for i in [2, 4, 8, 16, 32] {
183183
let a = (0..i).map(|_| rng.r#gen()).collect::<Vec<_>>();
184184
let b = (0..i).map(|_| rng.r#gen()).collect::<Vec<_>>();
185-
assert_eq!(Fr::sum_of_products(a.iter(), b.iter()), a.into_iter().zip(b).map(|(a, b)| a * b).sum());
185+
assert_eq!(Fr::sum_of_products(&a, &b), a.into_iter().zip(b).map(|(a, b)| a * b).sum());
186186
}
187187
}
188188

@@ -192,7 +192,7 @@ fn test_fq_sum_of_products() {
192192
for i in [2, 4, 8, 16, 32] {
193193
let a = (0..i).map(|_| rng.r#gen()).collect::<Vec<_>>();
194194
let b = (0..i).map(|_| rng.r#gen()).collect::<Vec<_>>();
195-
assert_eq!(Fq::sum_of_products(a.iter(), b.iter()), a.into_iter().zip(b).map(|(a, b)| a * b).sum());
195+
assert_eq!(Fq::sum_of_products(&a, &b), a.into_iter().zip(b).map(|(a, b)| a * b).sum());
196196
}
197197
}
198198

curves/src/templates/short_weierstrass_jacobian/projective.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -278,7 +278,7 @@ impl<P: Parameters> ProjectiveCurve for Projective<P> {
278278
self.x -= &v.double();
279279

280280
// Y3 = r*(V-X3)-2*Y1*J
281-
self.y = P::BaseField::sum_of_products([r, -self.y.double()].iter(), [(v - self.x), j].iter());
281+
self.y = P::BaseField::sum_of_products(&[r, -self.y.double()], &[(v - self.x), j]);
282282

283283
// Z3 = (Z1+H)^2-Z1Z1-HH
284284
self.z += &h;
@@ -457,7 +457,7 @@ impl<'a, P: Parameters> AddAssign<&'a Self> for Projective<P> {
457457
self.x = r.square() - j - (v.double());
458458

459459
// Y3 = r*(V - X3) - 2*S1*J
460-
self.y = P::BaseField::sum_of_products([r, -s1.double()].iter(), [(v - self.x), j].iter());
460+
self.y = P::BaseField::sum_of_products(&[r, -s1.double()], &[(v - self.x), j]);
461461

462462
// Z3 = ((Z1+Z2)^2 - Z1Z1 - Z2Z2)*H
463463
self.z = ((self.z + other.z).square() - z1z1 - z2z2) * h;

curves/src/traits/tests_field.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -495,7 +495,7 @@ pub fn field_test<F: Field>(a: F, b: F, rng: &mut TestRng) {
495495
for _ in 0..len {
496496
a.push(F::rand(rng));
497497
b.push(F::rand(rng));
498-
assert_eq!(F::sum_of_products(a.iter(), b.iter()), a.iter().zip(b.iter()).map(|(x, y)| *x * y).sum());
498+
assert_eq!(F::sum_of_products(&a, &b), a.iter().zip(b.iter()).map(|(x, y)| *x * y).sum());
499499
}
500500
}
501501

fields/src/fp2.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -402,8 +402,8 @@ impl<P: Fp2Parameters> MulAssign<&'_ Self> for Fp2<P> {
402402
#[allow(clippy::suspicious_op_assign_impl)]
403403
fn mul_assign(&mut self, other: &Self) {
404404
*self = Self::new(
405-
P::Fp::sum_of_products([self.c0, P::mul_fp_by_nonresidue(&self.c1)].iter(), [other.c0, other.c1].iter()),
406-
P::Fp::sum_of_products([self.c0, self.c1].iter(), [other.c1, other.c0].iter()),
405+
P::Fp::sum_of_products(&[self.c0, P::mul_fp_by_nonresidue(&self.c1)], &[other.c0, other.c1]),
406+
P::Fp::sum_of_products(&[self.c0, self.c1], &[other.c1, other.c0]),
407407
)
408408
}
409409
}

fields/src/fp_256.rs

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -164,10 +164,7 @@ impl<P: Fp256Parameters> Field for Fp256<P> {
164164
Self::from_bigint(two_inv).unwrap() // Guaranteed to be valid.
165165
}
166166

167-
fn sum_of_products<'a>(
168-
a: impl Iterator<Item = &'a Self> + Clone,
169-
b: impl Iterator<Item = &'a Self> + Clone,
170-
) -> Self {
167+
fn sum_of_products<'a>(a: &'a [Self], b: &'a [Self]) -> Self {
171168
// For a single `a x b` multiplication, operand scanning (schoolbook) takes each
172169
// limb of `a` in turn, and multiplies it by all of the limbs of `b` to compute
173170
// the result as a double-width intermediate representation, which is then fully
@@ -188,7 +185,7 @@ impl<P: Fp256Parameters> Field for Fp256<P> {
188185
// Algorithm 2, line 3
189186
// For each pair in the overall sum of products:
190187
let (t0, t1, t2, t3, mut t4) =
191-
a.clone().zip(b.clone()).fold((u0, u1, u2, u3, 0), |(t0, t1, t2, t3, mut t4), (a, b)| {
188+
a.iter().zip(b).fold((u0, u1, u2, u3, 0), |(t0, t1, t2, t3, mut t4), (a, b)| {
192189
// Compute digit_j x row and accumulate into `u`.
193190
let mut carry = 0;
194191
let t0 = fa::mac_with_carry(t0, a.0.0[j], b.0.0[0], &mut carry);

fields/src/fp_384.rs

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -197,10 +197,7 @@ impl<P: Fp384Parameters> Field for Fp384<P> {
197197
Self::from_bigint(two_inv).unwrap() // Guaranteed to be valid.
198198
}
199199

200-
fn sum_of_products<'a>(
201-
a: impl Iterator<Item = &'a Self> + Clone,
202-
b: impl Iterator<Item = &'a Self> + Clone,
203-
) -> Self {
200+
fn sum_of_products<'a>(a: &'a [Self], b: &'a [Self]) -> Self {
204201
// For a single `a x b` multiplication, operand scanning (schoolbook) takes each
205202
// limb of `a` in turn, and multiplies it by all of the limbs of `b` to compute
206203
// the result as a double-width intermediate representation, which is then fully
@@ -220,9 +217,8 @@ impl<P: Fp384Parameters> Field for Fp384<P> {
220217
let (u0, u1, u2, u3, u4, u5) = (0..6).fold((0, 0, 0, 0, 0, 0), |(u0, u1, u2, u3, u4, u5), j| {
221218
// Algorithm 2, line 3
222219
// For each pair in the overall sum of products:
223-
let (t0, t1, t2, t3, t4, t5, mut t6) = a.clone().zip(b.clone()).fold(
224-
(u0, u1, u2, u3, u4, u5, 0),
225-
|(t0, t1, t2, t3, t4, t5, mut t6), (a, b)| {
220+
let (t0, t1, t2, t3, t4, t5, mut t6) =
221+
a.iter().zip(b).fold((u0, u1, u2, u3, u4, u5, 0), |(t0, t1, t2, t3, t4, t5, mut t6), (a, b)| {
226222
// Compute digit_j x row and accumulate into `u`.
227223
let mut carry = 0;
228224
let t0 = fa::mac_with_carry(t0, a.0.0[j], b.0.0[0], &mut carry);
@@ -234,8 +230,7 @@ impl<P: Fp384Parameters> Field for Fp384<P> {
234230
let _ = fa::adc(&mut t6, 0, carry);
235231

236232
(t0, t1, t2, t3, t4, t5, t6)
237-
},
238-
);
233+
});
239234

240235
// Algorithm 2, lines 4-5
241236
// This is a single step of the usual Montgomery reduction process.

fields/src/traits/field.rs

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -119,11 +119,8 @@ pub trait Field:
119119
/// Squares `self` in place.
120120
fn square_in_place(&mut self) -> &mut Self;
121121

122-
fn sum_of_products<'a>(
123-
a: impl Iterator<Item = &'a Self> + Clone,
124-
b: impl Iterator<Item = &'a Self> + Clone,
125-
) -> Self {
126-
a.zip(b).map(|(a, b)| *a * b).sum::<Self>()
122+
fn sum_of_products<'a>(a: &'a [Self], b: &'a [Self]) -> Self {
123+
a.iter().zip(b).map(|(a, b)| *a * b).sum::<Self>()
127124
}
128125

129126
/// Computes the multiplicative inverse of `self` if `self` is nonzero.

0 commit comments

Comments
 (0)