@@ -21,7 +21,7 @@ use crate::{
2121 poly_common:: { eval_eq_mle, eval_eq_sharp_uni, eval_eq_uni, UnivariatePoly } ,
2222 prover:: {
2323 logup_zerocheck:: EvalHelper ,
24- poly:: evals_eq_hypercube ,
24+ poly:: evals_eq_hypercubes ,
2525 stacked_pcs:: StackedLayout ,
2626 sumcheck:: {
2727 batch_fold_mle_evals, batch_fold_ple_evals, fold_ple_evals, sumcheck_round0_deg,
@@ -54,21 +54,22 @@ pub struct LogupZerocheckCpu<'a> {
5454 // Available after GKR:
5555 pub xi : Vec < EF > ,
5656 lambda_pows : Vec < EF > ,
57-
58- pub eq_xi_per_trace : Vec < ColMajorMatrix < EF > > ,
59- pub eq_sharp_per_trace : Vec < ColMajorMatrix < EF > > ,
57+ // T -> segment tree of eq(xi[j..1+n_T]) for j=1..=n_T in _reverse_ layout
58+ eq_xi_per_trace : Vec < Vec < EF > > ,
6059 eq_3b_per_trace : Vec < Vec < EF > > ,
6160 sels_per_trace_base : Vec < ColMajorMatrix < F > > ,
6261 // After univariate round 0:
6362 pub mat_evals_per_trace : Vec < Vec < ColMajorMatrix < EF > > > ,
6463 pub sels_per_trace : Vec < ColMajorMatrix < EF > > ,
6564 // Stores \hat{f}(\vec r_n) * r_{n+1} .. r_{round-1} for polys f that are "done" in the batch
6665 // sumcheck
67- zerocheck_tilde_evals : Vec < EF > ,
68- logup_tilde_evals : Vec < [ EF ; 2 ] > ,
66+ pub ( crate ) zerocheck_tilde_evals : Vec < EF > ,
67+ pub ( crate ) logup_tilde_evals : Vec < [ EF ; 2 ] > ,
6968
7069 // In round `j`, contains `s_{j-1}(r_{j-1})`
71- prev_s_eval : EF ,
70+ pub ( crate ) prev_s_eval : EF ,
71+ pub ( crate ) eq_ns : Vec < EF > ,
72+ pub ( crate ) eq_sharp_ns : Vec < EF > ,
7273}
7374
7475impl < ' a > LogupZerocheckCpu < ' a > {
@@ -192,13 +193,14 @@ impl<'a> LogupZerocheckCpu<'a> {
192193 lambda_pows : vec ! [ ] ,
193194 sels_per_trace_base : vec ! [ ] ,
194195 eq_xi_per_trace : vec ! [ ] ,
195- eq_sharp_per_trace : vec ! [ ] ,
196196 eq_3b_per_trace : vec ! [ ] ,
197197 mat_evals_per_trace : vec ! [ ] ,
198198 sels_per_trace : vec ! [ ] ,
199199 zerocheck_tilde_evals,
200200 logup_tilde_evals,
201201 prev_s_eval : EF :: ZERO ,
202+ eq_ns : Vec :: with_capacity ( n_max + 1 ) ,
203+ eq_sharp_ns : Vec :: with_capacity ( n_max + 1 ) ,
202204 }
203205 }
204206
@@ -264,8 +266,7 @@ impl<'a> LogupZerocheckCpu<'a> {
264266 // PERF[jpw]: might be able to share computations between eq_xi, eq_sharp
265267 // computations the eq(xi, -) evaluations on hyperprism for
266268 // zerocheck
267- let eq_xi = evals_eq_hypercube ( & xi[ l_skip..l_skip + n_lift] ) ;
268- ColMajorMatrix :: new ( eq_xi, 1 )
269+ evals_eq_hypercubes ( n_lift, xi[ l_skip..l_skip + n_lift] . iter ( ) . rev ( ) )
269270 } )
270271 . collect ( ) ;
271272
@@ -306,7 +307,7 @@ impl<'a> LogupZerocheckCpu<'a> {
306307 let trace_ctx = & ctx. per_trace [ trace_idx] . 1 ;
307308 let n_lift = log2_strict_usize ( trace_ctx. height ( ) ) . saturating_sub ( l_skip) ;
308309 let mats = & helper. view_mats ( trace_ctx) ;
309- let eq_xi = self . eq_xi_per_trace [ trace_idx] . column ( 0 ) ;
310+ let eq_xi = & self . eq_xi_per_trace [ trace_idx] [ ( 1 << n_lift ) - 1 .. ( 2 << n_lift ) - 1 ] ;
310311 let sels = self . sels_per_trace_base [ trace_idx] . as_view ( ) ;
311312 let mut parts = vec ! [ ( sels. into( ) , false ) ] ;
312313 parts. extend_from_slice ( mats) ;
@@ -349,7 +350,7 @@ impl<'a> LogupZerocheckCpu<'a> {
349350 let log_height = log2_strict_usize ( trace_ctx. height ( ) ) ;
350351 let n_lift = log_height. saturating_sub ( l_skip) ;
351352 let mats = & helper. view_mats ( trace_ctx) ;
352- let eq_xi = self . eq_xi_per_trace [ trace_idx] . column ( 0 ) ;
353+ let eq_xi = & self . eq_xi_per_trace [ trace_idx] [ ( 1 << n_lift ) - 1 .. ( 2 << n_lift ) - 1 ] ;
353354 let eq_3bs = & self . eq_3b_per_trace [ trace_idx] ;
354355 let sels = self . sels_per_trace_base [ trace_idx] . as_view ( ) ;
355356 let mut parts = vec ! [ ( sels. into( ) , false ) ] ;
@@ -400,39 +401,31 @@ impl<'a> LogupZerocheckCpu<'a> {
400401 batch_fold_ple_evals ( l_skip, take ( & mut self . sels_per_trace_base ) , false , r_0) ;
401402 let eq_r0 = eval_eq_uni ( l_skip, self . xi [ 0 ] , r_0) ;
402403 let eq_sharp_r0 = eval_eq_sharp_uni ( & self . omega_skip_pows , & self . xi [ ..l_skip] , r_0) ;
403- // Define eq^\sharp_D(xi[0], r0) * eq_{H_n}(xi[1..1+n], x) and also update eq_D(xi[0], r0) *
404- // eq_{H_n}(xi[1..1+n], x)
405- self . eq_sharp_per_trace = self
406- . eq_xi_per_trace
407- . par_iter_mut ( )
408- . map ( |eq| {
409- let eq_sharp_evals = eq
410- . values
411- . par_iter_mut ( )
412- . map ( |x| {
413- let eq = * x;
414- * x *= eq_r0;
415- eq * eq_sharp_r0
416- } )
417- . collect ( ) ;
418- ColMajorMatrix :: new ( eq_sharp_evals, 1 )
419- } )
420- . collect ( ) ;
404+ self . eq_ns . push ( eq_r0) ;
405+ self . eq_sharp_ns . push ( eq_sharp_r0) ;
406+ self . eq_xi_per_trace . iter_mut ( ) . for_each ( |eq| {
407+ // trim the back (which corresponds to r_{j-1}) because we don't need it anymore
408+ if eq. len ( ) > 1 {
409+ eq. truncate ( eq. len ( ) / 2 ) ;
410+ }
411+ } ) ;
421412 }
422413
423- /// Returns length `3 * num_airs_present` polynomials, each evaluated at `1..=s_deg`.
414+ /// Returns length `3 * num_airs_present` polynomials, each polynomial either evaluated at
415+ /// `1,...,deg(s')` or at `1` if a linear term (terms in front-loaded sumcheck that have reached
416+ /// exhaustion)
424417 pub fn sumcheck_polys_eval ( & mut self , round : usize , r_prev : EF ) -> Vec < Vec < EF > > {
425- // PERF[jpw]: use per AIR s_deg
426- let s_deg = self . s_deg ;
427- let s_zerocheck_evals : Vec < Vec < EF > > = parizip ! (
418+ // sp = s'
419+ let sp_deg = self . constraint_degree ;
420+ let sp_zerocheck_evals : Vec < Vec < EF > > = parizip ! (
428421 & self . eval_helpers,
429422 & mut self . zerocheck_tilde_evals,
430423 & self . n_per_trace,
431424 & self . mat_evals_per_trace,
432425 & self . sels_per_trace,
433426 & self . eq_xi_per_trace
434427 )
435- . map ( |( helper, tilde_eval, & n, mats, sels, eq_xi ) | {
428+ . map ( |( helper, tilde_eval, & n, mats, sels, eq_xi_tree ) | {
436429 let n_lift = n. max ( 0 ) as usize ;
437430 if round > n_lift {
438431 if round == n_lift + 1 {
@@ -441,48 +434,44 @@ impl<'a> LogupZerocheckCpu<'a> {
441434 . chain ( mats)
442435 . map ( |mat| mat. columns ( ) . map ( |c| c[ 0 ] ) . collect_vec ( ) )
443436 . collect_vec ( ) ;
444- * tilde_eval =
445- eq_xi. column ( 0 ) [ 0 ] * helper. acc_constraints ( & parts, & self . lambda_pows ) ;
437+ // eq(xi, \vect r_{round-1})
438+ let eq_r_acc = * self . eq_ns . last ( ) . unwrap ( ) ;
439+ * tilde_eval = eq_r_acc * helper. acc_constraints ( & parts, & self . lambda_pows ) ;
446440 } else {
447441 * tilde_eval *= r_prev;
448442 } ;
449- ( 1 ..=s_deg)
450- . map ( |x| * tilde_eval * F :: from_canonical_usize ( x) )
451- . collect ( )
443+ vec ! [ * tilde_eval]
452444 } else {
453- let parts = iter:: empty ( )
454- . chain ( [ eq_xi, sels] )
445+ let log_num_y = n_lift - round;
446+ let num_y = 1 << log_num_y;
447+ let eq_xi = & eq_xi_tree[ num_y - 1 ..] ;
448+ let parts = iter:: once ( sels)
455449 . chain ( mats)
456450 . map ( |m| m. as_view ( ) )
457451 . collect_vec ( ) ;
458- let [ s] = sumcheck_round_poly_evals (
459- n_lift - ( round - 1 ) ,
460- s_deg,
461- & parts,
462- |_x, _y, row_parts| {
463- let eq = row_parts[ 0 ] [ 0 ] ;
464- let constraint_eval =
465- helper. acc_constraints ( & row_parts[ 1 ..] , & self . lambda_pows ) ;
452+ let [ s] =
453+ sumcheck_round_poly_evals ( log_num_y + 1 , sp_deg, & parts, |_x, y, row_parts| {
454+ let eq = eq_xi[ y] ;
455+ let constraint_eval = helper. acc_constraints ( row_parts, & self . lambda_pows ) ;
466456 [ eq * constraint_eval]
467- } ,
468- ) ;
457+ } ) ;
469458 s
470459 }
471460 } )
472461 . collect ( ) ;
473462
474- let s_logup_evals : Vec < Vec < EF > > = parizip ! (
463+ let sp_logup_evals : Vec < Vec < EF > > = parizip ! (
475464 & self . eval_helpers,
476465 & mut self . logup_tilde_evals,
477466 & self . n_per_trace,
478467 & self . mat_evals_per_trace,
479468 & self . sels_per_trace,
480- & self . eq_sharp_per_trace ,
469+ & self . eq_xi_per_trace ,
481470 & self . eq_3b_per_trace
482471 )
483- . flat_map ( |( helper, tilde_eval, & n, mats, sels, eq_sharp , eq_3bs) | {
472+ . flat_map ( |( helper, tilde_eval, & n, mats, sels, eq_xi_tree , eq_3bs) | {
484473 if helper. interactions . is_empty ( ) {
485- return [ vec ! [ EF :: ZERO ; s_deg ] , vec ! [ EF :: ZERO ; s_deg ] ] ;
474+ return [ vec ! [ EF :: ZERO ; sp_deg ] , vec ! [ EF :: ZERO ; sp_deg ] ] ;
486475 }
487476 let n_lift = n. max ( 0 ) as usize ;
488477 let norm_factor_denom = 1 << ( -n) . max ( 0 ) ;
@@ -494,38 +483,32 @@ impl<'a> LogupZerocheckCpu<'a> {
494483 . chain ( mats)
495484 . map ( |mat| mat. columns ( ) . map ( |c| c[ 0 ] ) . collect_vec ( ) )
496485 . collect_vec ( ) ;
497- let eq = eq_sharp . column ( 0 ) [ 0 ] ;
486+ let eq_sharp_r_acc = * self . eq_sharp_ns . last ( ) . unwrap ( ) ;
498487 * tilde_eval = helper
499488 . acc_interactions ( & parts, & self . beta_pows , eq_3bs)
500- . map ( |x| eq * x) ;
489+ . map ( |x| eq_sharp_r_acc * x) ;
501490 tilde_eval[ 0 ] *= norm_factor;
502491 } else {
503492 for x in tilde_eval. iter_mut ( ) {
504493 * x *= r_prev;
505494 }
506495 } ;
507- tilde_eval. map ( |tilde_eval| {
508- ( 1 ..=s_deg)
509- . map ( |x| tilde_eval * F :: from_canonical_usize ( x) )
510- . collect ( )
511- } )
496+ tilde_eval. map ( |tilde_eval| vec ! [ tilde_eval] )
512497 } else {
513- let parts = iter:: empty ( )
514- . chain ( [ eq_sharp, sels] )
498+ let parts = iter:: once ( sels)
515499 . chain ( mats)
516500 . map ( |m| m. as_view ( ) )
517501 . collect_vec ( ) ;
518- let [ mut numer , denom ] = sumcheck_round_poly_evals (
519- n_lift - ( round - 1 ) ,
520- s_deg ,
521- & parts ,
522- |_x, _y , row_parts| {
523- let eq_sharp = row_parts [ 0 ] [ 0 ] ;
502+ let log_num_y = n_lift - round ;
503+ let num_y = 1 << log_num_y ;
504+ let eq_xi = & eq_xi_tree [ num_y - 1 .. ] ;
505+ let [ mut numer , denom ] =
506+ sumcheck_round_poly_evals ( log_num_y + 1 , sp_deg , & parts , |_x, y , row_parts| {
507+ let eq = eq_xi [ y ] ;
524508 helper
525- . acc_interactions ( & row_parts[ 1 ..] , & self . beta_pows , eq_3bs)
526- . map ( |eval| eq_sharp * eval)
527- } ,
528- ) ;
509+ . acc_interactions ( row_parts, & self . beta_pows , eq_3bs)
510+ . map ( |eval| eq * eval)
511+ } ) ;
529512 for p in & mut numer {
530513 * p *= norm_factor;
531514 }
@@ -534,21 +517,32 @@ impl<'a> LogupZerocheckCpu<'a> {
534517 } )
535518 . collect ( ) ;
536519
537- s_logup_evals. into_iter ( ) . chain ( s_zerocheck_evals) . collect ( )
520+ sp_logup_evals
521+ . into_iter ( )
522+ . chain ( sp_zerocheck_evals)
523+ . collect ( )
538524 }
539525
540- pub fn fold_mle_evals ( & mut self , _round : usize , r_round : EF ) {
526+ pub fn fold_mle_evals ( & mut self , round : usize , r_round : EF ) {
541527 self . mat_evals_per_trace = take ( & mut self . mat_evals_per_trace )
542528 . into_iter ( )
543529 . map ( |mats| batch_fold_mle_evals ( mats, r_round) )
544530 . collect_vec ( ) ;
545531 self . sels_per_trace = batch_fold_mle_evals ( take ( & mut self . sels_per_trace ) , r_round) ;
546- self . eq_xi_per_trace = batch_fold_mle_evals ( take ( & mut self . eq_xi_per_trace ) , r_round) ;
547- self . eq_sharp_per_trace = batch_fold_mle_evals ( take ( & mut self . eq_sharp_per_trace ) , r_round) ;
532+ self . eq_xi_per_trace . par_iter_mut ( ) . for_each ( |eq| {
533+ // trim the back (which corresponds to r_{j-1}) because we don't need it anymore
534+ if eq. len ( ) > 1 {
535+ eq. truncate ( eq. len ( ) / 2 ) ;
536+ }
537+ } ) ;
538+ let xi = self . xi [ self . l_skip + round - 1 ] ;
539+ let eq_r = eval_eq_mle ( & [ xi] , & [ r_round] ) ;
540+ self . eq_ns . push ( self . eq_ns [ round - 1 ] * eq_r) ;
541+ self . eq_sharp_ns . push ( self . eq_sharp_ns [ round - 1 ] * eq_r) ;
548542
549543 #[ allow( unused_variables) ]
550544 #[ cfg( debug_assertions) ]
551- if tracing:: enabled!( tracing:: Level :: DEBUG ) && _round == self . n_max {
545+ if tracing:: enabled!( tracing:: Level :: DEBUG ) && round == self . n_max {
552546 use itertools:: izip;
553547
554548 for ( trace_idx, ( helper, & n, mats, sels, eq_xi) ) in izip ! (
@@ -568,12 +562,11 @@ impl<'a> LogupZerocheckCpu<'a> {
568562 debug ! ( %trace_idx, %expr, "constraints_eval" ) ;
569563 }
570564
571- for ( trace_idx, ( helper, & n, mats, sels, eq_sharp , eq_3bs) ) in izip ! (
565+ for ( trace_idx, ( helper, & n, mats, sels, eq_3bs) ) in izip ! (
572566 & self . eval_helpers,
573567 & self . n_per_trace,
574568 & self . mat_evals_per_trace,
575569 & self . sels_per_trace,
576- & self . eq_sharp_per_trace,
577570 & self . eq_3b_per_trace
578571 )
579572 . enumerate ( )
0 commit comments