Skip to content

Commit 7869dbc

Browse files
committed
feat: prover optimization for factorizable sumcheck (MLE rounds)
univariate skip will be handled in a follow-up PR that does zerocheck and logup separately
1 parent b51141e commit 7869dbc

File tree

3 files changed

+172
-108
lines changed

3 files changed

+172
-108
lines changed

crates/stark-backend-v2/src/poly_common.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,10 @@ impl<F: TwoAdicField> UnivariatePoly<F> {
303303
}
304304

305305
#[instrument(level = "debug", skip_all)]
306-
pub fn lagrange_interpolate(points: &[F], evals: &[F]) -> Self {
306+
pub fn lagrange_interpolate<BF: Field>(points: &[BF], evals: &[F]) -> Self
307+
where
308+
F: ExtensionField<BF>,
309+
{
307310
assert_eq!(points.len(), evals.len());
308311
let len = points.len();
309312

crates/stark-backend-v2/src/prover/logup_zerocheck/cpu.rs

Lines changed: 75 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::{
2121
poly_common::{eval_eq_mle, eval_eq_sharp_uni, eval_eq_uni, UnivariatePoly},
2222
prover::{
2323
logup_zerocheck::EvalHelper,
24-
poly::evals_eq_hypercube,
24+
poly::evals_eq_hypercubes,
2525
stacked_pcs::StackedLayout,
2626
sumcheck::{
2727
batch_fold_mle_evals, batch_fold_ple_evals, fold_ple_evals, sumcheck_round0_deg,
@@ -54,21 +54,22 @@ pub struct LogupZerocheckCpu<'a> {
5454
// Available after GKR:
5555
pub xi: Vec<EF>,
5656
lambda_pows: Vec<EF>,
57-
58-
pub eq_xi_per_trace: Vec<ColMajorMatrix<EF>>,
59-
pub eq_sharp_per_trace: Vec<ColMajorMatrix<EF>>,
57+
// T -> segment tree of eq(xi[j..1+n_T]) for j=1..=n_T in _reverse_ layout
58+
eq_xi_per_trace: Vec<Vec<EF>>,
6059
eq_3b_per_trace: Vec<Vec<EF>>,
6160
sels_per_trace_base: Vec<ColMajorMatrix<F>>,
6261
// After univariate round 0:
6362
pub mat_evals_per_trace: Vec<Vec<ColMajorMatrix<EF>>>,
6463
pub sels_per_trace: Vec<ColMajorMatrix<EF>>,
6564
// Stores \hat{f}(\vec r_n) * r_{n+1} .. r_{round-1} for polys f that are "done" in the batch
6665
// sumcheck
67-
zerocheck_tilde_evals: Vec<EF>,
68-
logup_tilde_evals: Vec<[EF; 2]>,
66+
pub(crate) zerocheck_tilde_evals: Vec<EF>,
67+
pub(crate) logup_tilde_evals: Vec<[EF; 2]>,
6968

7069
// In round `j`, contains `s_{j-1}(r_{j-1})`
71-
prev_s_eval: EF,
70+
pub(crate) prev_s_eval: EF,
71+
pub(crate) eq_ns: Vec<EF>,
72+
pub(crate) eq_sharp_ns: Vec<EF>,
7273
}
7374

7475
impl<'a> LogupZerocheckCpu<'a> {
@@ -192,13 +193,14 @@ impl<'a> LogupZerocheckCpu<'a> {
192193
lambda_pows: vec![],
193194
sels_per_trace_base: vec![],
194195
eq_xi_per_trace: vec![],
195-
eq_sharp_per_trace: vec![],
196196
eq_3b_per_trace: vec![],
197197
mat_evals_per_trace: vec![],
198198
sels_per_trace: vec![],
199199
zerocheck_tilde_evals,
200200
logup_tilde_evals,
201201
prev_s_eval: EF::ZERO,
202+
eq_ns: Vec::with_capacity(n_max + 1),
203+
eq_sharp_ns: Vec::with_capacity(n_max + 1),
202204
}
203205
}
204206

@@ -264,8 +266,7 @@ impl<'a> LogupZerocheckCpu<'a> {
264266
// PERF[jpw]: might be able to share computations between the eq_xi and eq_sharp
265267
// computations. Here: the eq(xi, -) evaluations on the hyperprism for
266268
// zerocheck
267-
let eq_xi = evals_eq_hypercube(&xi[l_skip..l_skip + n_lift]);
268-
ColMajorMatrix::new(eq_xi, 1)
269+
evals_eq_hypercubes(n_lift, xi[l_skip..l_skip + n_lift].iter().rev())
269270
})
270271
.collect();
271272

@@ -306,7 +307,7 @@ impl<'a> LogupZerocheckCpu<'a> {
306307
let trace_ctx = &ctx.per_trace[trace_idx].1;
307308
let n_lift = log2_strict_usize(trace_ctx.height()).saturating_sub(l_skip);
308309
let mats = &helper.view_mats(trace_ctx);
309-
let eq_xi = self.eq_xi_per_trace[trace_idx].column(0);
310+
let eq_xi = &self.eq_xi_per_trace[trace_idx][(1 << n_lift) - 1..(2 << n_lift) - 1];
310311
let sels = self.sels_per_trace_base[trace_idx].as_view();
311312
let mut parts = vec![(sels.into(), false)];
312313
parts.extend_from_slice(mats);
@@ -349,7 +350,7 @@ impl<'a> LogupZerocheckCpu<'a> {
349350
let log_height = log2_strict_usize(trace_ctx.height());
350351
let n_lift = log_height.saturating_sub(l_skip);
351352
let mats = &helper.view_mats(trace_ctx);
352-
let eq_xi = self.eq_xi_per_trace[trace_idx].column(0);
353+
let eq_xi = &self.eq_xi_per_trace[trace_idx][(1 << n_lift) - 1..(2 << n_lift) - 1];
353354
let eq_3bs = &self.eq_3b_per_trace[trace_idx];
354355
let sels = self.sels_per_trace_base[trace_idx].as_view();
355356
let mut parts = vec![(sels.into(), false)];
@@ -400,39 +401,31 @@ impl<'a> LogupZerocheckCpu<'a> {
400401
batch_fold_ple_evals(l_skip, take(&mut self.sels_per_trace_base), false, r_0);
401402
let eq_r0 = eval_eq_uni(l_skip, self.xi[0], r_0);
402403
let eq_sharp_r0 = eval_eq_sharp_uni(&self.omega_skip_pows, &self.xi[..l_skip], r_0);
403-
// Define eq^\sharp_D(xi[0], r0) * eq_{H_n}(xi[1..1+n], x) and also update eq_D(xi[0], r0) *
404-
// eq_{H_n}(xi[1..1+n], x)
405-
self.eq_sharp_per_trace = self
406-
.eq_xi_per_trace
407-
.par_iter_mut()
408-
.map(|eq| {
409-
let eq_sharp_evals = eq
410-
.values
411-
.par_iter_mut()
412-
.map(|x| {
413-
let eq = *x;
414-
*x *= eq_r0;
415-
eq * eq_sharp_r0
416-
})
417-
.collect();
418-
ColMajorMatrix::new(eq_sharp_evals, 1)
419-
})
420-
.collect();
404+
self.eq_ns.push(eq_r0);
405+
self.eq_sharp_ns.push(eq_sharp_r0);
406+
self.eq_xi_per_trace.iter_mut().for_each(|eq| {
407+
// trim the back (which corresponds to r_{j-1}) because we don't need it anymore
408+
if eq.len() > 1 {
409+
eq.truncate(eq.len() / 2);
410+
}
411+
});
421412
}
422413

423-
/// Returns length `3 * num_airs_present` polynomials, each evaluated at `1..=s_deg`.
414+
/// Returns `3 * num_airs_present` polynomials, each given either by its evaluations at
415+
/// `1,...,deg(s')`, or by its evaluation at `1` only if it is a linear term (a term in the
416+
/// front-loaded sumcheck that has reached exhaustion)
424417
pub fn sumcheck_polys_eval(&mut self, round: usize, r_prev: EF) -> Vec<Vec<EF>> {
425-
// PERF[jpw]: use per AIR s_deg
426-
let s_deg = self.s_deg;
427-
let s_zerocheck_evals: Vec<Vec<EF>> = parizip!(
418+
// sp = s'
419+
let sp_deg = self.constraint_degree;
420+
let sp_zerocheck_evals: Vec<Vec<EF>> = parizip!(
428421
&self.eval_helpers,
429422
&mut self.zerocheck_tilde_evals,
430423
&self.n_per_trace,
431424
&self.mat_evals_per_trace,
432425
&self.sels_per_trace,
433426
&self.eq_xi_per_trace
434427
)
435-
.map(|(helper, tilde_eval, &n, mats, sels, eq_xi)| {
428+
.map(|(helper, tilde_eval, &n, mats, sels, eq_xi_tree)| {
436429
let n_lift = n.max(0) as usize;
437430
if round > n_lift {
438431
if round == n_lift + 1 {
@@ -441,48 +434,44 @@ impl<'a> LogupZerocheckCpu<'a> {
441434
.chain(mats)
442435
.map(|mat| mat.columns().map(|c| c[0]).collect_vec())
443436
.collect_vec();
444-
*tilde_eval =
445-
eq_xi.column(0)[0] * helper.acc_constraints(&parts, &self.lambda_pows);
437+
// eq(xi, \vec r_{round-1})
438+
let eq_r_acc = *self.eq_ns.last().unwrap();
439+
*tilde_eval = eq_r_acc * helper.acc_constraints(&parts, &self.lambda_pows);
446440
} else {
447441
*tilde_eval *= r_prev;
448442
};
449-
(1..=s_deg)
450-
.map(|x| *tilde_eval * F::from_canonical_usize(x))
451-
.collect()
443+
vec![*tilde_eval]
452444
} else {
453-
let parts = iter::empty()
454-
.chain([eq_xi, sels])
445+
let log_num_y = n_lift - round;
446+
let num_y = 1 << log_num_y;
447+
let eq_xi = &eq_xi_tree[num_y - 1..];
448+
let parts = iter::once(sels)
455449
.chain(mats)
456450
.map(|m| m.as_view())
457451
.collect_vec();
458-
let [s] = sumcheck_round_poly_evals(
459-
n_lift - (round - 1),
460-
s_deg,
461-
&parts,
462-
|_x, _y, row_parts| {
463-
let eq = row_parts[0][0];
464-
let constraint_eval =
465-
helper.acc_constraints(&row_parts[1..], &self.lambda_pows);
452+
let [s] =
453+
sumcheck_round_poly_evals(log_num_y + 1, sp_deg, &parts, |_x, y, row_parts| {
454+
let eq = eq_xi[y];
455+
let constraint_eval = helper.acc_constraints(row_parts, &self.lambda_pows);
466456
[eq * constraint_eval]
467-
},
468-
);
457+
});
469458
s
470459
}
471460
})
472461
.collect();
473462

474-
let s_logup_evals: Vec<Vec<EF>> = parizip!(
463+
let sp_logup_evals: Vec<Vec<EF>> = parizip!(
475464
&self.eval_helpers,
476465
&mut self.logup_tilde_evals,
477466
&self.n_per_trace,
478467
&self.mat_evals_per_trace,
479468
&self.sels_per_trace,
480-
&self.eq_sharp_per_trace,
469+
&self.eq_xi_per_trace,
481470
&self.eq_3b_per_trace
482471
)
483-
.flat_map(|(helper, tilde_eval, &n, mats, sels, eq_sharp, eq_3bs)| {
472+
.flat_map(|(helper, tilde_eval, &n, mats, sels, eq_xi_tree, eq_3bs)| {
484473
if helper.interactions.is_empty() {
485-
return [vec![EF::ZERO; s_deg], vec![EF::ZERO; s_deg]];
474+
return [vec![EF::ZERO; sp_deg], vec![EF::ZERO; sp_deg]];
486475
}
487476
let n_lift = n.max(0) as usize;
488477
let norm_factor_denom = 1 << (-n).max(0);
@@ -494,38 +483,32 @@ impl<'a> LogupZerocheckCpu<'a> {
494483
.chain(mats)
495484
.map(|mat| mat.columns().map(|c| c[0]).collect_vec())
496485
.collect_vec();
497-
let eq = eq_sharp.column(0)[0];
486+
let eq_sharp_r_acc = *self.eq_sharp_ns.last().unwrap();
498487
*tilde_eval = helper
499488
.acc_interactions(&parts, &self.beta_pows, eq_3bs)
500-
.map(|x| eq * x);
489+
.map(|x| eq_sharp_r_acc * x);
501490
tilde_eval[0] *= norm_factor;
502491
} else {
503492
for x in tilde_eval.iter_mut() {
504493
*x *= r_prev;
505494
}
506495
};
507-
tilde_eval.map(|tilde_eval| {
508-
(1..=s_deg)
509-
.map(|x| tilde_eval * F::from_canonical_usize(x))
510-
.collect()
511-
})
496+
tilde_eval.map(|tilde_eval| vec![tilde_eval])
512497
} else {
513-
let parts = iter::empty()
514-
.chain([eq_sharp, sels])
498+
let parts = iter::once(sels)
515499
.chain(mats)
516500
.map(|m| m.as_view())
517501
.collect_vec();
518-
let [mut numer, denom] = sumcheck_round_poly_evals(
519-
n_lift - (round - 1),
520-
s_deg,
521-
&parts,
522-
|_x, _y, row_parts| {
523-
let eq_sharp = row_parts[0][0];
502+
let log_num_y = n_lift - round;
503+
let num_y = 1 << log_num_y;
504+
let eq_xi = &eq_xi_tree[num_y - 1..];
505+
let [mut numer, denom] =
506+
sumcheck_round_poly_evals(log_num_y + 1, sp_deg, &parts, |_x, y, row_parts| {
507+
let eq = eq_xi[y];
524508
helper
525-
.acc_interactions(&row_parts[1..], &self.beta_pows, eq_3bs)
526-
.map(|eval| eq_sharp * eval)
527-
},
528-
);
509+
.acc_interactions(row_parts, &self.beta_pows, eq_3bs)
510+
.map(|eval| eq * eval)
511+
});
529512
for p in &mut numer {
530513
*p *= norm_factor;
531514
}
@@ -534,21 +517,32 @@ impl<'a> LogupZerocheckCpu<'a> {
534517
})
535518
.collect();
536519

537-
s_logup_evals.into_iter().chain(s_zerocheck_evals).collect()
520+
sp_logup_evals
521+
.into_iter()
522+
.chain(sp_zerocheck_evals)
523+
.collect()
538524
}
539525

540-
pub fn fold_mle_evals(&mut self, _round: usize, r_round: EF) {
526+
pub fn fold_mle_evals(&mut self, round: usize, r_round: EF) {
541527
self.mat_evals_per_trace = take(&mut self.mat_evals_per_trace)
542528
.into_iter()
543529
.map(|mats| batch_fold_mle_evals(mats, r_round))
544530
.collect_vec();
545531
self.sels_per_trace = batch_fold_mle_evals(take(&mut self.sels_per_trace), r_round);
546-
self.eq_xi_per_trace = batch_fold_mle_evals(take(&mut self.eq_xi_per_trace), r_round);
547-
self.eq_sharp_per_trace = batch_fold_mle_evals(take(&mut self.eq_sharp_per_trace), r_round);
532+
self.eq_xi_per_trace.par_iter_mut().for_each(|eq| {
533+
// trim the back (which corresponds to r_{j-1}) because we don't need it anymore
534+
if eq.len() > 1 {
535+
eq.truncate(eq.len() / 2);
536+
}
537+
});
538+
let xi = self.xi[self.l_skip + round - 1];
539+
let eq_r = eval_eq_mle(&[xi], &[r_round]);
540+
self.eq_ns.push(self.eq_ns[round - 1] * eq_r);
541+
self.eq_sharp_ns.push(self.eq_sharp_ns[round - 1] * eq_r);
548542

549543
#[allow(unused_variables)]
550544
#[cfg(debug_assertions)]
551-
if tracing::enabled!(tracing::Level::DEBUG) && _round == self.n_max {
545+
if tracing::enabled!(tracing::Level::DEBUG) && round == self.n_max {
552546
use itertools::izip;
553547

554548
for (trace_idx, (helper, &n, mats, sels, eq_xi)) in izip!(
@@ -568,12 +562,11 @@ impl<'a> LogupZerocheckCpu<'a> {
568562
debug!(%trace_idx, %expr, "constraints_eval");
569563
}
570564

571-
for (trace_idx, (helper, &n, mats, sels, eq_sharp, eq_3bs)) in izip!(
565+
for (trace_idx, (helper, &n, mats, sels, eq_3bs)) in izip!(
572566
&self.eval_helpers,
573567
&self.n_per_trace,
574568
&self.mat_evals_per_trace,
575569
&self.sels_per_trace,
576-
&self.eq_sharp_per_trace,
577570
&self.eq_3b_per_trace
578571
)
579572
.enumerate()

0 commit comments

Comments
 (0)