@@ -185,35 +185,29 @@ func (e *zeroCheckLazyClaims) verifyFinalEval(r []{{ .ElementType }}, purportedV
185185// zeroCheckClaims is a claim for sumcheck (prover side).
186186// It checks that the polynomial ∑ᵢ cⁱ eq (-, xᵢ) wᵢ(-) sums to the expected value,
187187// where the sum runs over all (wire v, claim source s) pairs in the level.
188- // Within each claim group, wire eq tables differ by a constant stride, so only one
189- // eq table per group is stored. Wire contributions are accumulated via Horner's method.
188+ // Each wire has its own eq table with the batching coefficients baked in.
190189type zeroCheckClaims struct {
191190 zeroCheckBase
192- eqs []polynomial.MultiLin // one eq table per claim group
193- strides []{{ .ElementType }} // per-group Horner stride (foldingCoeff^nbSources)
194- nbGroupWires []int // number of wires in each group
191+ eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points
195192}
196193
197- // roundPolynomial computes gⱼ = ∑ₕ ∑_groups eq_g(Xⱼ, h... ) · ∑_w stride^w · gate_w(inputs(Xⱼ, h... )).
198- // Within each group, wire contributions are accumulated via Horner's method in stride,
199- // then multiplied by the group's eq value. This avoids per-wire eq tables.
194+ // roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h... ) · gateᵥ(inputs(Xⱼ, h... )).
200195// The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ... , gⱼ(deg(gⱼ)).
201196// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁).
202197// By convention, g₀ is a constant polynomial equal to the claimed sum.
203198func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial {
204199 level := c.resources.schedule [c.levelI ]. (constraint.GkrSumcheckLevel )
205200 degree := c.resources.circuit.ZeroCheckDegree (level)
206201 nbUniqueInputs := len (c.input )
207- nbWires := len (c.gateEvaluatorPools )
208- nbGroups := len (c.eqs )
202+ nbWires := len (c.eqs )
209203
210204 // Both eqs and input are multilinear, thus linear in Xⱼ.
211205 // For any such f, f(m) = m·(f(1) - f(0)) + f(0), and f(0), f(1) are read directly
212206 // from the bookkeeping tables. This allows stepwise evaluation at Xⱼ = 1, 2, ... , degree.
213- // Layout: [groupEq ₀, ... , groupEq_{G -1}, input₀, ... , input_{K -1}]
214- ml := make([]polynomial.MultiLin , nbGroups +nbUniqueInputs)
207+ // Layout: [eq ₀, eq ₁, ... , eq_{nbWires -1}, input₀, input₁, ... , input_{nbUniqueInputs -1}]
208+ ml := make([]polynomial.MultiLin , nbWires +nbUniqueInputs)
215209 copy(ml, c.eqs )
216- copy(ml[nbGroups :], c.input )
210+ copy(ml[nbWires :], c.input )
217211
218212 sumSize := len (c.eqs [0]) / 2
219213
@@ -254,26 +248,13 @@ func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial {
254248 eIndex := 0 // start of the current row's eq evaluations
255249 nextEIndex := len (ml)
256250 for d := range degree {
257- wireI := 0
258- for g, nbW := range c.nbGroupWires {
259- // Horner accumulation: gate(0) + stride·(gate(1) + stride·(... + stride·gate(W-1)))
260- lastW := wireI + nbW - 1
261- for _, inputI := range c.inputIndices [lastW] {
262- evaluators[lastW].pushInput (mlEvals[eIndex+nbGroups+inputI])
263- }
264- var groupSum {{ .ElementType }}
265- groupSum.Set (evaluators[lastW].evaluate ())
266- for w := lastW - 1; w >= wireI; w-- {
267- groupSum.Mul (&groupSum, &c.strides [g])
268- for _, inputI := range c.inputIndices [w] {
269- evaluators[w].pushInput (mlEvals[eIndex+nbGroups+inputI])
270- }
271- groupSum.Add (&groupSum, evaluators[w].evaluate ())
251+ for w := range nbWires {
252+ for _, inputI := range c.inputIndices [w] {
253+ evaluators[w].pushInput (mlEvals[eIndex+nbWires+inputI])
272254 }
273-
274- groupSum.Mul (&groupSum, &mlEvals[eIndex+g])
275- res[d].Add (&res[d], &groupSum) // collect contributions into the sum from start to end
276- wireI += nbW
255+ summand := evaluators[w].evaluate ()
256+ summand.Mul (summand, &mlEvals[eIndex+w])
257+ res[d].Add (&res[d], summand) // collect contributions into the sum from start to end
277258 }
278259 eIndex, nextEIndex = nextEIndex, nextEIndex+len (ml)
279260 }
@@ -295,7 +276,7 @@ func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial {
295276 return p
296277}
297278
298- // roundFold folds all input and per-group eq polynomials at the verifier challenge r.
279+ // roundFold folds all input and eq polynomials at the verifier challenge r.
299280// After this call , j ← j+1 and rⱼ = r.
300281func (c *zeroCheckClaims) roundFold(r {{ .ElementType }}) {
301282 const minBlockSize = 512
@@ -443,13 +424,11 @@ func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof {
443424
444425 level := r.schedule [levelI]
445426 eqLength := 1 << r.nbVars
446- groups := level.ClaimGroups ()
447- claims.eqs = make([]polynomial.MultiLin , len (groups))
448- claims.strides = make([]{{ .ElementType }}, len (groups))
449- claims.nbGroupWires = make([]int, len (groups))
427+ claims.eqs = make([]polynomial.MultiLin , len (claims.gateEvaluatorPools ))
450428 var alpha {{ .ElementType }}
451429 alpha.SetOne ()
452- for g, group := range groups {
430+ levelWireI := 0
431+ for _, group := range level.ClaimGroups () {
453432 nbSources := len (group.ClaimSources )
454433
455434 groupEq := polynomial.MultiLin (r.memPool.Make (eqLength))
@@ -473,12 +452,18 @@ func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof {
473452 stride.Mul (&stride, &claims.foldingCoeff )
474453 }
475454
476- claims.eqs [g ] = groupEq
477- claims .strides [g] = stride
478- claims .nbGroupWires [g] = len (group .Wires )
455+ claims.eqs [levelWireI ] = groupEq
456+ levelWireI++
457+ alpha .Mul (&alpha, &stride )
479458
480- // Advance alpha past all wires in this group
481- for range len (group.Wires ) {
459+ for w := 1; w < len (group.Wires ); w++ {
460+ claims.eqs [levelWireI] = polynomial.MultiLin (r.memPool.Make (eqLength))
461+ r.workers.Submit (eqLength, func(start, end int) {
462+ for i := start; i < end ; i++ {
463+ claims.eqs [levelWireI][i].Mul (&claims.eqs [levelWireI-1][i], &stride)
464+ }
465+ }, 512).Wait ()
466+ levelWireI++
482467 alpha.Mul (&alpha, &stride)
483468 }
484469 }
0 commit comments