Skip to content

Commit 94224cc

Browse files
committed
Optimize the lagrange_interpolation
1 parent 73e3d31 commit 94224cc

File tree

4 files changed

+117
-56
lines changed

4 files changed

+117
-56
lines changed

crates/dkg/src/dkg_math.rs

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,9 @@ pub fn lagrange_interpolation<C: Curve>(
204204
"zero secret share id",
205205
)));
206206
}
207-
let mut r = C::Point::identity();
207+
// Pre-allocate vectors for batch processing
208+
let mut terms = Vec::with_capacity(k);
209+
208210
for i in 0..k {
209211
let mut b = x_vec[i];
210212
for j in 0..k {
@@ -220,23 +222,34 @@ pub fn lagrange_interpolation<C: Curve>(
220222
}
221223
}
222224
let li0 = a.mul(&b.invert());
223-
let tmp = y_vec[i].mul_scalar(&li0);
224-
r = r.add(&tmp);
225+
terms.push(y_vec[i].mul_scalar(&li0));
225226
}
226-
Ok(r)
227+
228+
// Batch add all terms
229+
Ok(batch_add_points::<C>(&terms))
227230
}
228231

229232
#[allow(clippy::assign_op_pattern)]
230233
pub fn agg_coefficients<C: Curve>(
231234
verification_vectors: &[Vec<C::Point>],
232235
ids: &[C::Scalar],
233236
) -> Vec<C::Point> {
234-
let mut final_cfs = Vec::new();
235-
for i in 0..verification_vectors[0].len() {
236-
let mut sum = C::Point::identity();
237+
let num_vectors = verification_vectors.len();
238+
let vector_len = verification_vectors[0].len();
239+
240+
// Pre-allocate the result vector
241+
let mut final_cfs = Vec::with_capacity(vector_len);
242+
243+
// Batch point additions for better cache locality and performance
244+
for i in 0..vector_len {
245+
// Collect all points at position i for batch addition
246+
let mut points_to_sum = Vec::with_capacity(num_vectors);
237247
for v in verification_vectors {
238-
sum = sum.add(&v[i]);
248+
points_to_sum.push(v[i]);
239249
}
250+
251+
// Perform batched addition
252+
let sum = batch_add_points::<C>(&points_to_sum);
240253
final_cfs.push(sum);
241254
}
242255
let mut final_keys = Vec::new();
@@ -247,6 +260,48 @@ pub fn agg_coefficients<C: Curve>(
247260
final_keys
248261
}
249262

263+
// Optimized batch point addition function for elliptic curve operations
264+
pub fn batch_add_points<C: Curve>(points: &[C::Point]) -> C::Point {
265+
if points.is_empty() {
266+
return C::Point::identity();
267+
}
268+
269+
if points.len() == 1 {
270+
return points[0];
271+
}
272+
273+
// For small numbers of points, use sequential addition
274+
if points.len() <= 4 {
275+
let mut sum = points[0];
276+
for point in &points[1..] {
277+
sum = sum.add(point);
278+
}
279+
return sum;
280+
}
281+
282+
// For larger numbers, use a binary tree approach to reduce depth
283+
let mut current_points = points.to_vec();
284+
285+
while current_points.len() > 1 {
286+
let mut next_level = Vec::new();
287+
let mut i = 0;
288+
while i < current_points.len() {
289+
if i + 1 < current_points.len() {
290+
// Add pairs
291+
next_level.push(current_points[i].add(&current_points[i + 1]));
292+
i += 2;
293+
} else {
294+
// Handle odd element
295+
next_level.push(current_points[i]);
296+
i += 1;
297+
}
298+
}
299+
current_points = next_level;
300+
}
301+
302+
current_points[0]
303+
}
304+
250305
#[cfg(test)]
251306
mod tests {
252307
use crate::crypto::*;

crates/dkg/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ pub use verification::{
99
};
1010

1111
pub use crypto::*;
12+
pub use dkg_math::batch_add_points;
1213
pub use types::*;

crates/finalization_prove/src/main.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,12 @@ where
8282
// Method 1: Compute P(0) directly as the constant term of aggregated polynomial
8383
// P(x) = Σ_{j=0}^t c_j x^j where c_j = Σ_{k=1}^n PK(a_{k,j})
8484
// P(0) = c_0 = Σ_{k=1}^n PK(a_{k,0})
85-
let mut p0 = Setup::Point::identity();
86-
for verification_vector in &verification_vectors {
87-
p0 = p0.add(&verification_vector[0]); // Sum of constant terms
88-
}
85+
86+
// Extract constant terms for batch addition
87+
let constant_terms: Vec<Setup::Point> = verification_vectors.iter().map(|v| v[0]).collect();
88+
89+
// Use optimized batch addition
90+
let p0 = dkg::batch_add_points::<Setup::Curve>(&constant_terms);
8991

9092
// Method 2: Use Lagrange interpolation on partial public keys (spec requirement)
9193
// L(PK_1, ..., PK_n) should equal P(0)
@@ -119,7 +121,9 @@ fn lagrange_interpolation_at_zero<C: dkg::Curve>(
119121
// For each point (x_i, y_i), compute the Lagrange basis polynomial l_i(0)
120122
// l_i(0) = Π_{j≠i} (0 - x_j) / (x_i - x_j) = Π_{j≠i} (-x_j) / (x_i - x_j)
121123

122-
let mut result = C::Point::identity();
124+
// Pre-allocate vectors for batch processing
125+
let mut terms = Vec::with_capacity(k);
126+
123127
for i in 0..k {
124128
let mut numerator = C::Scalar::from_u32(1); // This will be Π_{j≠i} (-x_j)
125129
let mut denominator = C::Scalar::from_u32(1); // This will be Π_{j≠i} (x_i - x_j)
@@ -138,8 +142,9 @@ fn lagrange_interpolation_at_zero<C: dkg::Curve>(
138142

139143
let li0 = numerator.mul(&denominator.invert());
140144
let term = y_vec[i].mul_scalar(&li0);
141-
result = result.add(&term);
145+
terms.push(term);
142146
}
143147

144-
Ok(result)
148+
// Batch add all terms instead of accumulating sequentially
149+
Ok(dkg::batch_add_points::<C>(&terms))
145150
}

perf/baseline.json

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -101,29 +101,29 @@
101101
}
102102
},
103103
"bad-partial-key_bad_partial_key.json": {
104-
"total_opcodes": 25815757,
104+
"total_opcodes": 25814891,
105105
"opcodes": {
106-
"add": 9927990,
106+
"add": 9927134,
107107
"sltu": 4806127,
108-
"lw": 2914798,
108+
"lw": 2914792,
109109
"mul": 1952285,
110110
"mulhu": 1896338,
111-
"sw": 1577033,
112-
"beq": 1080590,
111+
"sw": 1577028,
112+
"beq": 1080593,
113113
"and": 764557,
114114
"or": 233881,
115115
"sub": 174034,
116116
"sll": 64726,
117117
"srl": 64259,
118118
"sra": 61131,
119-
"jalr": 59259,
119+
"jalr": 59257,
120120
"bltu": 58140,
121-
"bne": 50799,
121+
"bne": 50796,
122122
"xor": 45900,
123-
"auipc": 30228,
123+
"auipc": 30227,
124124
"lbu": 16666,
125125
"sb": 16076,
126-
"jal": 10682,
126+
"jal": 10686,
127127
"bgeu": 5583,
128128
"ecall": 1565,
129129
"blt": 1547,
@@ -154,53 +154,53 @@
154154
}
155155
},
156156
"finalization_finalization_test.json": {
157-
"total_opcodes": 60388524,
157+
"total_opcodes": 60332240,
158158
"opcodes": {
159-
"add": 16282532,
160-
"sw": 11991752,
161-
"lw": 11967507,
162-
"and": 4634085,
163-
"bltu": 3558455,
164-
"bne": 2844125,
165-
"jalr": 2694135,
166-
"sltu": 1478387,
167-
"beq": 1463754,
168-
"auipc": 1349325,
169-
"or": 773366,
170-
"ecall": 549590,
171-
"xor": 409811,
172-
"sb": 160378,
173-
"lbu": 80098,
174-
"sub": 54086,
175-
"sll": 24614,
176-
"mul": 21458,
177-
"srl": 18058,
178-
"mulhu": 8605,
179-
"bgeu": 8499,
180-
"jal": 7316,
181-
"blt": 3840,
159+
"add": 16267715,
160+
"sw": 11982112,
161+
"lw": 11957780,
162+
"and": 4629027,
163+
"bltu": 3555662,
164+
"bne": 2841325,
165+
"jalr": 2691236,
166+
"sltu": 1476623,
167+
"beq": 1462263,
168+
"auipc": 1347865,
169+
"or": 772070,
170+
"ecall": 549206,
171+
"xor": 409355,
172+
"sb": 159922,
173+
"lbu": 79504,
174+
"sub": 53956,
175+
"sll": 24251,
176+
"mul": 21466,
177+
"srl": 17914,
178+
"mulhu": 8609,
179+
"bgeu": 8507,
180+
"jal": 7309,
181+
"blt": 3839,
182182
"slt": 1920,
183183
"lb": 1653,
184-
"sra": 1142,
184+
"sra": 1118,
185185
"sh": 22,
186186
"bge": 9,
187187
"lh": 1,
188188
"lhu": 1
189189
},
190-
"total_syscalls": 549590,
190+
"total_syscalls": 549206,
191191
"syscalls": {
192-
"bls12381_fp_mul": 240420,
193-
"bls12381_fp_add": 110144,
192+
"bls12381_fp_mul": 240216,
193+
"bls12381_fp_add": 110012,
194194
"bls12381_fp2_add": 84845,
195195
"bls12381_fp2_mul": 50257,
196196
"bls12381_fp2_sub": 35618,
197-
"bls12381_fp_sub": 24689,
197+
"bls12381_fp_sub": 24659,
198198
"bls12381_double": 2772,
199199
"bls12381_add": 220,
200200
"uint256_mul": 146,
201-
"hint_len": 141,
202-
"hint_read": 141,
203-
"enter_unconstrained": 118,
201+
"hint_len": 135,
202+
"hint_read": 135,
203+
"enter_unconstrained": 112,
204204
"sha_compress": 25,
205205
"sha_extend": 25,
206206
"write": 12,

0 commit comments

Comments
 (0)