Skip to content

Commit d14bae6

Browse files
committed
feat: prover optimization for factorizable sumcheck (#218)
This is the `eq` optimization pointed out in Gruen Section 3.2, with a small generalization to front-loaded batch sumcheck. This reduces the degree of the polynomial the prover evaluates by 1. The univariate skip round 0 is also handled, where zerocheck and logup are handled separately. We will make a separate optimization around the quotient polynomial for zerocheck in a follow-up PR. Towards INT-5804.
1 parent 73b6db0 commit d14bae6

File tree

5 files changed

+517
-418
lines changed

5 files changed

+517
-418
lines changed

crates/stark-backend-v2/src/poly_common.rs

Lines changed: 119 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,29 @@ pub fn eval_eq_uni_at_one<F: Field>(l_skip: usize, x: F) -> F {
5151
res * F::ONE.halve().exp_u64(l_skip as u64)
5252
}
5353

54+
/// Returns `eq_D(x, Z)` as a polynomial in `Z` in coefficient form.
55+
/// Derived from `eq_D(x, Z)` being the Lagrange basis at `x`, which is the character sum over the
56+
/// roots of unity.
57+
///
58+
/// If z in D, then `eq_D(x, z) = 1/N sum_{k=0}^{N-1} (x/z)^k = 1/N (1 + sum_{k=1}^{N-1} x^k
59+
/// z^{N-k})`, using `z^N = 1`.
60+
pub fn eq_uni_poly<F, EF>(l_skip: usize, x: EF) -> UnivariatePoly<EF>
61+
where
62+
F: Field,
63+
EF: ExtensionField<F>,
64+
{
65+
let n_inv = F::ONE.halve().exp_u64(l_skip as u64);
66+
let mut coeffs = x
67+
.powers()
68+
.skip(1)
69+
.take(1 << l_skip)
70+
.map(|x_pow| x_pow * n_inv)
71+
.collect_vec();
72+
coeffs.reverse();
73+
coeffs[0] = n_inv.into();
74+
UnivariatePoly::new(coeffs)
75+
}
76+
5477
pub fn eval_in_uni<F: Field>(l_skip: usize, n: isize, z: F) -> F {
5578
debug_assert!(n >= -(l_skip as isize));
5679
if n.is_negative() {
@@ -107,6 +130,11 @@ where
107130
res
108131
}
109132

133+
pub fn eq_sharp_uni_poly<EF: TwoAdicField>(xi_1: &[EF]) -> UnivariatePoly<EF> {
134+
let evals = evals_eq_hypercube_serial(xi_1);
135+
UnivariatePoly::from_evals_idft(&evals)
136+
}
137+
110138
/// `\kappa_\rot(x, y)` should equal `\delta_{x,rot(y)}` on hyperprism.
111139
///
112140
/// `omega_pows` must have length `2^{l_skip}`.
@@ -142,7 +170,7 @@ pub fn eval_eq_rot_cube<F: Field>(x: &[F], y: &[F]) -> (F, F) {
142170
#[derive(Clone, Debug)]
143171
pub struct UnivariatePoly<F>(pub(crate) Vec<F>);
144172

145-
impl<F: TwoAdicField> UnivariatePoly<F> {
173+
impl<F> UnivariatePoly<F> {
146174
pub fn new(coeffs: Vec<F>) -> Self {
147175
Self(coeffs)
148176
}
@@ -158,11 +186,101 @@ impl<F: TwoAdicField> UnivariatePoly<F> {
158186
pub fn into_coeffs(self) -> Vec<F> {
159187
self.0
160188
}
189+
}
161190

191+
impl<F: Field> UnivariatePoly<F> {
162192
pub fn eval_at_point<EF: ExtensionField<F>>(&self, x: EF) -> EF {
163193
horner_eval(&self.0, x)
164194
}
165195

196+
#[instrument(level = "debug", skip_all)]
197+
pub fn lagrange_interpolate<BF: Field>(points: &[BF], evals: &[F]) -> Self
198+
where
199+
F: ExtensionField<BF>,
200+
{
201+
assert_eq!(points.len(), evals.len());
202+
let len = points.len();
203+
204+
// Special case: empty or single evaluation
205+
if len == 0 {
206+
return Self(vec![]);
207+
}
208+
if len == 1 {
209+
return Self(vec![evals[0]]);
210+
}
211+
212+
// Lagrange interpolation algorithm
213+
// P(x) = sum_{i=0}^{len-1} evals[i] * L_i(x)
214+
// where L_i(x) = prod_{j != i} (x - points[j]) / (points[i] - points[j])
215+
216+
// Step 1: Compute all denominators (points[i] - points[j]) for i != j
217+
let mut denominators = Vec::with_capacity(len * (len - 1));
218+
for i in 0..len {
219+
for j in 0..len {
220+
if i != j {
221+
denominators.push(points[i] - points[j]);
222+
}
223+
}
224+
}
225+
226+
// Step 2: Batch invert all denominators
227+
let inv_denominators = batch_multiplicative_inverse_serial(&denominators);
228+
229+
// Step 3: Build coefficient form by accumulating Lagrange basis polynomials
230+
let mut coeffs = vec![F::ZERO; len];
231+
232+
// Reusable workspace for Lagrange polynomial computation
233+
let mut lagrange_poly = Vec::with_capacity(len);
234+
235+
#[allow(clippy::needless_range_loop)]
236+
for i in 0..len {
237+
// Skip if evaluation is zero (optimization)
238+
if evals[i] == F::ZERO {
239+
continue;
240+
}
241+
242+
// Build L_i(x) in coefficient form using polynomial multiplication
243+
// L_i(x) = prod_{j != i} (x - points[j]) / (points[i] - points[j])
244+
245+
// Start with constant polynomial 1
246+
lagrange_poly.clear();
247+
lagrange_poly.push(F::ONE);
248+
249+
// Get the precomputed inverse denominators for this i
250+
let inv_denom_start = i * (len - 1);
251+
let mut inv_idx = 0;
252+
253+
// Multiply by (x - points[j]) / (points[i] - points[j]) for each j != i
254+
#[allow(clippy::needless_range_loop)]
255+
for j in 0..len {
256+
if i != j {
257+
let scale = inv_denominators[inv_denom_start + inv_idx];
258+
inv_idx += 1;
259+
260+
// Multiply lagrange_poly by (x - points[j]) * scale in place
261+
// This is equivalent to: lagrange_poly * (x - points[j]) * scale
262+
// = lagrange_poly * x * scale - lagrange_poly * points[j] * scale
263+
264+
lagrange_poly.push(F::ZERO); // Extend by one for the new highest degree term
265+
for k in (1..lagrange_poly.len()).rev() {
266+
let prev_coeff = lagrange_poly[k - 1] * scale;
267+
lagrange_poly[k] += prev_coeff;
268+
lagrange_poly[k - 1] = -prev_coeff * points[j];
269+
}
270+
}
271+
}
272+
273+
// Add evals[i] * L_i(x) to the result
274+
for (k, &coeff) in lagrange_poly.iter().enumerate() {
275+
coeffs[k] += evals[i] * coeff;
276+
}
277+
}
278+
279+
Self(coeffs)
280+
}
281+
}
282+
283+
impl<F: TwoAdicField> UnivariatePoly<F> {
166284
/// Computes P(1), P(omega), ..., P(omega^{n-1}).
167285
fn chirp_z(poly: &[F], omega: F, n: usize) -> Vec<F> {
168286
if n == 0 {
@@ -302,89 +420,6 @@ impl<F: TwoAdicField> UnivariatePoly<F> {
302420
Self(res)
303421
}
304422

305-
#[instrument(level = "debug", skip_all)]
306-
pub fn lagrange_interpolate(points: &[F], evals: &[F]) -> Self {
307-
assert_eq!(points.len(), evals.len());
308-
let len = points.len();
309-
310-
// Special case: empty or single evaluation
311-
if len == 0 {
312-
return Self(vec![]);
313-
}
314-
if len == 1 {
315-
return Self(vec![evals[0]]);
316-
}
317-
318-
// Lagrange interpolation algorithm
319-
// P(x) = sum_{i=0}^{len-1} evals[i] * L_i(x)
320-
// where L_i(x) = prod_{j != i} (x - points[j]) / (points[i] - points[j])
321-
322-
// Step 1: Compute all denominators (points[i] - points[j]) for i != j
323-
let mut denominators = Vec::with_capacity(len * (len - 1));
324-
for i in 0..len {
325-
for j in 0..len {
326-
if i != j {
327-
denominators.push(points[i] - points[j]);
328-
}
329-
}
330-
}
331-
332-
// Step 2: Batch invert all denominators
333-
let inv_denominators = batch_multiplicative_inverse_serial(&denominators);
334-
335-
// Step 3: Build coefficient form by accumulating Lagrange basis polynomials
336-
let mut coeffs = vec![F::ZERO; len];
337-
338-
// Reusable workspace for Lagrange polynomial computation
339-
let mut lagrange_poly = Vec::with_capacity(len);
340-
341-
#[allow(clippy::needless_range_loop)]
342-
for i in 0..len {
343-
// Skip if evaluation is zero (optimization)
344-
if evals[i] == F::ZERO {
345-
continue;
346-
}
347-
348-
// Build L_i(x) in coefficient form using polynomial multiplication
349-
// L_i(x) = prod_{j != i} (x - points[j]) / (points[i] - points[j])
350-
351-
// Start with constant polynomial 1
352-
lagrange_poly.clear();
353-
lagrange_poly.push(F::ONE);
354-
355-
// Get the precomputed inverse denominators for this i
356-
let inv_denom_start = i * (len - 1);
357-
let mut inv_idx = 0;
358-
359-
// Multiply by (x - points[j]) / (points[i] - points[j]) for each j != i
360-
#[allow(clippy::needless_range_loop)]
361-
for j in 0..len {
362-
if i != j {
363-
let scale = inv_denominators[inv_denom_start + inv_idx];
364-
inv_idx += 1;
365-
366-
// Multiply lagrange_poly by (x - points[j]) * scale in place
367-
// This is equivalent to: lagrange_poly * (x - points[j]) * scale
368-
// = lagrange_poly * x * scale - lagrange_poly * points[j] * scale
369-
370-
lagrange_poly.push(F::ZERO); // Extend by one for the new highest degree term
371-
for k in (1..lagrange_poly.len()).rev() {
372-
let prev_coeff = lagrange_poly[k - 1] * scale;
373-
lagrange_poly[k] += prev_coeff;
374-
lagrange_poly[k - 1] = -prev_coeff * points[j];
375-
}
376-
}
377-
}
378-
379-
// Add evals[i] * L_i(x) to the result
380-
for (k, &coeff) in lagrange_poly.iter().enumerate() {
381-
coeffs[k] += evals[i] * coeff;
382-
}
383-
}
384-
385-
Self(coeffs)
386-
}
387-
388423
/// Constructs the polynomial in coefficient form from its evaluations on a smooth subgroup of
389424
/// `F^*` by performing inverse DFT.
390425
///

crates/stark-backend-v2/src/prover/cpu_backend.rs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ use crate::{
1414
stacked_reduction::{prove_stacked_opening_reduction, StackedReductionCpu},
1515
whir::WhirProver,
1616
ColMajorMatrix, CommittedTraceDataV2, DeviceDataTransporterV2,
17-
DeviceMultiStarkProvingKeyV2, DeviceStarkProvingKeyV2, LogupZerocheckCpu, MultiRapProver,
18-
OpeningProverV2, ProverBackendV2, ProverDeviceV2, ProvingContextV2, TraceCommitterV2,
17+
DeviceMultiStarkProvingKeyV2, DeviceStarkProvingKeyV2, MultiRapProver, OpeningProverV2,
18+
ProverBackendV2, ProverDeviceV2, ProvingContextV2, TraceCommitterV2,
1919
},
2020
Digest, SystemParams, D_EF, EF, F,
2121
};
@@ -69,16 +69,10 @@ impl<TS: FiatShamirTranscript> MultiRapProver<CpuBackendV2, TS> for CpuDeviceV2
6969
transcript: &mut TS,
7070
mpk: &DeviceMultiStarkProvingKeyV2<CpuBackendV2>,
7171
ctx: &ProvingContextV2<CpuBackendV2>,
72-
common_main_pcs_data: &StackedPcsData<F, Digest>,
72+
_common_main_pcs_data: &StackedPcsData<F, Digest>,
7373
) -> ((GkrProof, BatchConstraintProof), Vec<EF>) {
7474
let (gkr_proof, batch_constraint_proof, r) =
75-
prove_zerocheck_and_logup::<_, _, TS, LogupZerocheckCpu>(
76-
self,
77-
transcript,
78-
mpk,
79-
ctx,
80-
common_main_pcs_data,
81-
);
75+
prove_zerocheck_and_logup(transcript, mpk, ctx);
8276
((gkr_proof, batch_constraint_proof), r)
8377
}
8478
}

0 commit comments

Comments
 (0)