Skip to content

Commit b02cd37

Browse files
committed
Revert "perf: per group eq"
This reverts commit 3ce291a.
1 parent 3ce291a commit b02cd37

File tree

6 files changed

+168
-258
lines changed

6 files changed

+168
-258
lines changed

internal/generator/backend/template/gkr/sumcheck.go.tmpl

Lines changed: 28 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -185,35 +185,29 @@ func (e *zeroCheckLazyClaims) verifyFinalEval(r []{{ .ElementType }}, purportedV
185185
// zeroCheckClaims is a claim for sumcheck (prover side).
186186
// It checks that the polynomial ∑ᵢ cⁱ eq(-, xᵢ) wᵢ(-) sums to the expected value,
187187
// where the sum runs over all (wire v, claim source s) pairs in the level.
188-
// Within each claim group, wire eq tables differ by a constant stride, so only one
189-
// eq table per group is stored. Wire contributions are accumulated via Horner's method.
188+
// Each wire has its own eq table with the batching coefficients baked in.
190189
type zeroCheckClaims struct {
191190
zeroCheckBase
192-
eqs []polynomial.MultiLin // one eq table per claim group
193-
strides []{{ .ElementType }} // per-group Horner stride (foldingCoeff^nbSources)
194-
nbGroupWires []int // number of wires in each group
191+
eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points
195192
}
196193

197-
// roundPolynomial computes gⱼ = ∑ₕ ∑_groups eq_g(Xⱼ, h...) · ∑_w stride^w · gate_w(inputs(Xⱼ, h...)).
198-
// Within each group, wire contributions are accumulated via Horner's method in stride,
199-
// then multiplied by the group's eq value. This avoids per-wire eq tables.
194+
// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)).
200195
// The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)).
201196
// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁).
202197
// By convention, g₀ is a constant polynomial equal to the claimed sum.
203198
func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial {
204199
level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel)
205200
degree := c.resources.circuit.ZeroCheckDegree(level)
206201
nbUniqueInputs := len(c.input)
207-
nbWires := len(c.gateEvaluatorPools)
208-
nbGroups := len(c.eqs)
202+
nbWires := len(c.eqs)
209203

210204
// Both eqs and input are multilinear, thus linear in Xⱼ.
211205
// For any such f, f(m) = m·(f(1) - f(0)) + f(0), and f(0), f(1) are read directly
212206
// from the bookkeeping tables. This allows stepwise evaluation at Xⱼ = 1, 2, ..., degree.
213-
// Layout: [groupEq₀, ..., groupEq_{G-1}, input₀, ..., input_{K-1}]
214-
ml := make([]polynomial.MultiLin, nbGroups+nbUniqueInputs)
207+
// Layout: [eq₀, eq₁, ..., eq_{nbWires-1}, input₀, input₁, ..., input_{nbUniqueInputs-1}]
208+
ml := make([]polynomial.MultiLin, nbWires+nbUniqueInputs)
215209
copy(ml, c.eqs)
216-
copy(ml[nbGroups:], c.input)
210+
copy(ml[nbWires:], c.input)
217211

218212
sumSize := len(c.eqs[0]) / 2
219213

@@ -254,26 +248,13 @@ func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial {
254248
eIndex := 0 // start of the current row's eq evaluations
255249
nextEIndex := len(ml)
256250
for d := range degree {
257-
wireI := 0
258-
for g, nbW := range c.nbGroupWires {
259-
// Horner accumulation: gate(0) + stride·(gate(1) + stride·(... + stride·gate(W-1)))
260-
lastW := wireI + nbW - 1
261-
for _, inputI := range c.inputIndices[lastW] {
262-
evaluators[lastW].pushInput(mlEvals[eIndex+nbGroups+inputI])
263-
}
264-
var groupSum {{ .ElementType }}
265-
groupSum.Set(evaluators[lastW].evaluate())
266-
for w := lastW - 1; w >= wireI; w-- {
267-
groupSum.Mul(&groupSum, &c.strides[g])
268-
for _, inputI := range c.inputIndices[w] {
269-
evaluators[w].pushInput(mlEvals[eIndex+nbGroups+inputI])
270-
}
271-
groupSum.Add(&groupSum, evaluators[w].evaluate())
251+
for w := range nbWires {
252+
for _, inputI := range c.inputIndices[w] {
253+
evaluators[w].pushInput(mlEvals[eIndex+nbWires+inputI])
272254
}
273-
274-
groupSum.Mul(&groupSum, &mlEvals[eIndex+g])
275-
res[d].Add(&res[d], &groupSum) // collect contributions into the sum from start to end
276-
wireI += nbW
255+
summand := evaluators[w].evaluate()
256+
summand.Mul(summand, &mlEvals[eIndex+w])
257+
res[d].Add(&res[d], summand) // collect contributions into the sum from start to end
277258
}
278259
eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml)
279260
}
@@ -295,7 +276,7 @@ func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial {
295276
return p
296277
}
297278

298-
// roundFold folds all input and per-group eq polynomials at the verifier challenge r.
279+
// roundFold folds all input and eq polynomials at the verifier challenge r.
299280
// After this call, j ← j+1 and rⱼ = r.
300281
func (c *zeroCheckClaims) roundFold(r {{ .ElementType }}) {
301282
const minBlockSize = 512
@@ -443,13 +424,11 @@ func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof {
443424

444425
level := r.schedule[levelI]
445426
eqLength := 1 << r.nbVars
446-
groups := level.ClaimGroups()
447-
claims.eqs = make([]polynomial.MultiLin, len(groups))
448-
claims.strides = make([]{{ .ElementType }}, len(groups))
449-
claims.nbGroupWires = make([]int, len(groups))
427+
claims.eqs = make([]polynomial.MultiLin, len(claims.gateEvaluatorPools))
450428
var alpha {{ .ElementType }}
451429
alpha.SetOne()
452-
for g, group := range groups {
430+
levelWireI := 0
431+
for _, group := range level.ClaimGroups() {
453432
nbSources := len(group.ClaimSources)
454433

455434
groupEq := polynomial.MultiLin(r.memPool.Make(eqLength))
@@ -473,12 +452,18 @@ func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof {
473452
stride.Mul(&stride, &claims.foldingCoeff)
474453
}
475454

476-
claims.eqs[g] = groupEq
477-
claims.strides[g] = stride
478-
claims.nbGroupWires[g] = len(group.Wires)
455+
claims.eqs[levelWireI] = groupEq
456+
levelWireI++
457+
alpha.Mul(&alpha, &stride)
479458

480-
// Advance alpha past all wires in this group
481-
for range len(group.Wires) {
459+
for w := 1; w < len(group.Wires); w++ {
460+
claims.eqs[levelWireI] = polynomial.MultiLin(r.memPool.Make(eqLength))
461+
r.workers.Submit(eqLength, func(start, end int) {
462+
for i := start; i < end; i++ {
463+
claims.eqs[levelWireI][i].Mul(&claims.eqs[levelWireI-1][i], &stride)
464+
}
465+
}, 512).Wait()
466+
levelWireI++
482467
alpha.Mul(&alpha, &stride)
483468
}
484469
}

internal/gkr/bls12-377/sumcheck.go

Lines changed: 28 additions & 43 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/gkr/bls12-381/sumcheck.go

Lines changed: 28 additions & 43 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)