Skip to content

Commit 6e69608

Browse files
authored
perf: GKR Levels (#1735)
1 parent fb118be commit 6e69608

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+5787
-5165
lines changed

constraint/gkr.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package constraint
2+
3+
type (
4+
// GkrClaimSource identifies an incoming evaluation claim for a wire.
5+
// Level is the level that produced the claim.
6+
// OutgoingClaimIndex selects which of that level's outgoing evaluation points is referenced;
7+
// always 0 for SumcheckLevels, 0..M-1 for SkipLevels with M inherited evaluation points.
8+
// The initial verifier challenge is represented as {Level: len(schedule), OutgoingClaimIndex: 0}.
9+
GkrClaimSource struct {
10+
Level int `json:"level"`
11+
OutgoingClaimIndex int `json:"outgoingClaimIndex"`
12+
}
13+
14+
// GkrClaimGroup represents a set of wires sharing identical claim sources.
15+
// finalEvalProof index = pos(wire, srcLevel) * NbOutgoingEvalPoints(srcLevel) + ClaimSources[claimI].OutgoingClaimIndex,
16+
// where pos(wire, srcLevel) is the wire's position in srcLevel's UniqueGateInputs list.
17+
GkrClaimGroup struct {
18+
Wires []int `json:"wires"`
19+
ClaimSources []GkrClaimSource `json:"claimSources"`
20+
}
21+
22+
// GkrProvingLevel is a single level in the proving schedule.
23+
GkrProvingLevel interface {
24+
NbOutgoingEvalPoints() int
25+
// NbClaims returns the total number of claims at this level.
26+
NbClaims() int
27+
ClaimGroups() []GkrClaimGroup
28+
// FinalEvalProofIndex returns where to find the evaluationPointI'th evaluation claim for the wireI'th input wire to the layer,
29+
// in the layer's final evaluation proof.
30+
FinalEvalProofIndex(wireI, evaluationPointI int) int
31+
}
32+
33+
// GkrSkipLevel represents a level where zerocheck is skipped.
34+
// Claims propagate through at their existing evaluation points.
35+
GkrSkipLevel GkrClaimGroup
36+
37+
// GkrSumcheckLevel represents a level where one or more zerochecks are batched
38+
// together in a single sumcheck. Each GkrClaimGroup within may have different
39+
// claim sources (sumcheck-level batching), or the same source (enabling
40+
// zerocheck-level batching with shared eq tables).
41+
GkrSumcheckLevel []GkrClaimGroup
42+
43+
// GkrProvingSchedule is a sequence of levels defining how to prove a GKR circuit.
44+
GkrProvingSchedule []GkrProvingLevel
45+
)
46+
47+
func (g GkrClaimGroup) NbClaims() int { return len(g.Wires) * len(g.ClaimSources) }
48+
49+
func (l GkrSumcheckLevel) NbOutgoingEvalPoints() int { return 1 }
50+
func (l GkrSumcheckLevel) NbClaims() int {
51+
n := 0
52+
for _, g := range l {
53+
n += len(g.Wires) * len(g.ClaimSources)
54+
}
55+
return n
56+
}
57+
func (l GkrSumcheckLevel) ClaimGroups() []GkrClaimGroup { return l }
58+
func (l GkrSumcheckLevel) FinalEvalProofIndex(wireI, _ int) int { return wireI }
59+
60+
func (l GkrSkipLevel) NbOutgoingEvalPoints() int { return len(l.ClaimSources) }
61+
func (l GkrSkipLevel) NbClaims() int {
62+
return GkrClaimGroup(l).NbClaims()
63+
}
64+
func (l GkrSkipLevel) ClaimGroups() []GkrClaimGroup { return []GkrClaimGroup{GkrClaimGroup(l)} }
65+
func (l GkrSkipLevel) FinalEvalProofIndex(wireI, evaluationPointI int) int {
66+
return wireI*l.NbOutgoingEvalPoints() + evaluationPointI
67+
}
68+
69+
// BindGkrFinalEvalProof binds the non-input-wire entries of finalEvalProof into the transcript.
70+
// Input-wire evaluations are fully determined by the public assignment (and by evaluation points
71+
// already committed to the transcript), so hashing them contributes nothing to Fiat-Shamir security.
72+
// uniqueGateInputs is the deduplicated list of gate-input wire indices for the level in the same
73+
// order as the finalEvalProof entries (i.e. the order returned by UniqueGateInputs).
74+
func BindGkrFinalEvalProof[F any](transcript interface{ Bind(...F) }, finalEvalProof []F, uniqueGateInputs []int, isInput func(wireI int) bool, level GkrProvingLevel) {
75+
for i, inputWireI := range uniqueGateInputs {
76+
if !isInput(inputWireI) {
77+
transcript.Bind(finalEvalProof[level.FinalEvalProofIndex(i, 0):level.FinalEvalProofIndex(i+1, 0)]...)
78+
}
79+
}
80+
}

constraint/marshal.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,9 @@ func getTagSet() cbor.TagSet {
398398
addType(reflect.TypeOf(BlueprintBatchInverse[U32]{}))
399399
addType(reflect.TypeOf(BlueprintBatchInverse[U64]{}))
400400

401+
addType(reflect.TypeOf(GkrSkipLevel{}))
402+
addType(reflect.TypeOf(GkrSumcheckLevel{}))
403+
401404
// Add types registered by external packages (e.g., GKR blueprints)
402405
// These use explicit tag numbers to ensure stability regardless of init() order
403406
for _, rt := range registeredBlueprintTypes {

constraint/solver/gkrgates/registry.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"runtime"
1111

1212
"github.com/consensys/gnark-crypto/ecc"
13-
"github.com/consensys/gnark/internal/gkr/gkrtypes"
13+
"github.com/consensys/gnark/internal/gkr/gkrcore"
1414
"github.com/consensys/gnark/std/gkrapi/gkr"
1515
)
1616

@@ -119,7 +119,7 @@ func Register(f gkr.GateFunction, nbIn int, options ...RegisterOption) error {
119119
}
120120

121121
for _, curve := range s.curves {
122-
compiled, err := gkrtypes.CompileGateFunction(f, nbIn, curve.ScalarField())
122+
compiled, err := gkrcore.CompileGateFunction(f, nbIn, curve.ScalarField())
123123
if err != nil {
124124
return err
125125
}

internal/generator/backend/template/gkr/blueprint.go.tmpl

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@ import (
88
"github.com/consensys/gnark-crypto/ecc"
99
"{{ .FieldPackagePath }}"
1010
"{{ .FieldPackagePath }}/polynomial"
11-
fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
1211
"github.com/consensys/gnark-crypto/hash"
1312
"github.com/consensys/gnark/constraint"
14-
"github.com/consensys/gnark/internal/gkr/gkrtypes"
13+
"github.com/consensys/gnark/internal/gkr/gkrcore"
1514
)
1615

1716
func init() {
@@ -27,7 +26,7 @@ type circuitEvaluator struct {
2726
// BlueprintSolve is a {{.FieldID}}-specific blueprint for solving GKR circuit instances.
2827
type BlueprintSolve struct {
2928
// Circuit structure (serialized)
30-
Circuit gkrtypes.SerializableCircuit
29+
Circuit gkrcore.SerializableCircuit
3130
NbInstances uint32
3231

3332
// Not serialized - recreated lazily at solve time
@@ -41,10 +40,13 @@ type BlueprintSolve struct {
4140

4241
// Ensures BlueprintSolve implements BlueprintStateful
4342
var _ constraint.BlueprintStateful[constraint.U64] = (*BlueprintSolve)(nil)
43+
4444
// Equal returns true if the serialized fields of two BlueprintSolve are equal.
4545
// Used for testing serialization round-trips.
4646
func (b *BlueprintSolve) Equal(other constraint.BlueprintComparable) bool {
47-
if other == nil { return false }
47+
if other == nil {
48+
return false
49+
}
4850
o, ok := other.(*BlueprintSolve)
4951
if !ok {
5052
return false
@@ -107,7 +109,7 @@ func (b *BlueprintSolve) Solve(s constraint.Solver[constraint.U64], inst constra
107109
if w.IsInput() {
108110
val, delta := s.Read(calldata)
109111
calldata = calldata[delta:]
110-
// Copy directly from constraint.U64 to fr.Element (both in Montgomery form)
112+
// Copy directly from constraint.U64 to {{ .ElementType }} (both in Montgomery form)
111113
copy(b.assignments[wI][instanceI][:], val[:])
112114
} else {
113115
// Get evaluator for this wire from the circuit evaluator
@@ -123,7 +125,7 @@ func (b *BlueprintSolve) Solve(s constraint.Solver[constraint.U64], inst constra
123125
}
124126
}
125127

126-
// Set output wires (copy fr.Element to U64 in Montgomery form)
128+
// Set output wires (copy {{ .ElementType }} to U64 in Montgomery form)
127129
for outI, outWI := range b.outputWires {
128130
var val constraint.U64
129131
copy(val[:], b.assignments[outWI][instanceI][:])
@@ -150,9 +152,9 @@ func (b *BlueprintSolve) NbConstraints() int {
150152

151153
// NbOutputs implements Blueprint
152154
func (b *BlueprintSolve) NbOutputs(inst constraint.Instruction) int {
153-
if b.outputWires == nil {
154-
b.outputWires = b.Circuit.Outputs()
155-
}
155+
if b.outputWires == nil {
156+
b.outputWires = b.Circuit.Outputs()
157+
}
156158
return len(b.outputWires)
157159
}
158160

@@ -194,21 +196,25 @@ func (b *BlueprintSolve) UpdateInstructionTree(inst constraint.Instruction, tree
194196
type BlueprintProve struct {
195197
SolveBlueprintID constraint.BlueprintID
196198
SolveBlueprint *BlueprintSolve `cbor:"-"` // not serialized, set at compile time
199+
Schedule constraint.GkrProvingSchedule
197200
HashName string
198201

199202
lock sync.Mutex
200203
}
201204

202205
// Ensures BlueprintProve implements BlueprintSolvable
203206
var _ constraint.BlueprintSolvable[constraint.U64] = (*BlueprintProve)(nil)
207+
204208
// Equal returns true if the serialized fields of two BlueprintProve are equal.
205209
func (b *BlueprintProve) Equal(other constraint.BlueprintComparable) bool {
206-
if other == nil { return false }
210+
if other == nil {
211+
return false
212+
}
207213
o, ok := other.(*BlueprintProve)
208214
if !ok {
209215
return false
210216
}
211-
return b.SolveBlueprintID == o.SolveBlueprintID && b.HashName == o.HashName
217+
return b.SolveBlueprintID == o.SolveBlueprintID && b.HashName == o.HashName && reflect.DeepEqual(b.Schedule, o.Schedule)
212218
}
213219

214220
// Solve implements the BlueprintSolvable interface for proving.
@@ -243,28 +249,27 @@ func (b *BlueprintProve) Solve(s constraint.Solver[constraint.U64], inst constra
243249
}
244250
}
245251

252+
// Create hasher and write base challenges
253+
hsh := hash.NewHash(b.HashName + "_{{.FieldID}}")
254+
246255
// Read initial challenges from instruction calldata (parse dynamically, no metadata)
247256
// Format: [0]=totalSize, [1...]=challenge linear expressions
248-
insBytes := make([][]byte, 0) // first challenges
249257
calldata := inst.Calldata[1:] // skip size prefix
250258
for len(calldata) != 0 {
251259
val, delta := s.Read(calldata)
252260
calldata = calldata[delta:]
253261

254-
// Copy directly from constraint.U64 to fr.Element (both in Montgomery form)
262+
// Copy directly from constraint.U64 to {{ .ElementType }} (both in Montgomery form)
255263
var challenge {{ .ElementType }}
256264
copy(challenge[:], val[:])
257-
insBytes = append(insBytes, challenge.Marshal())
265+
challengeBytes := challenge.Bytes()
266+
hsh.Write(challengeBytes[:])
258267
}
259268

260-
// Create Fiat-Shamir settings
261-
hsh := hash.NewHash(b.HashName + "_{{.FieldID}}")
262-
fsSettings := fiatshamir.WithHash(hsh, insBytes...)
263-
264269
// Call the {{.FieldID}}-specific Prove function (assignments already WireAssignment type)
265-
proof, err := Prove(solveBlueprint.Circuit, assignments, fsSettings)
270+
proof, err := Prove(solveBlueprint.Circuit, b.Schedule, assignments, hsh)
266271
if err != nil {
267-
return fmt.Errorf("{{toLower .FieldID}} prove failed: %w", err)
272+
return fmt.Errorf("{{.FieldID}} prove failed: %w", err)
268273
}
269274

270275
for i, elem := range proof.flatten() {
@@ -292,7 +297,7 @@ func (b *BlueprintProve) proofSize() int {
292297
}
293298
nbPaddedInstances := ecc.NextPowerOfTwo(uint64(b.SolveBlueprint.NbInstances))
294299
logNbInstances := bits.TrailingZeros64(nbPaddedInstances)
295-
return b.SolveBlueprint.Circuit.ProofSize(logNbInstances)
300+
return b.SolveBlueprint.Circuit.ProofSize(b.Schedule, logNbInstances)
296301
}
297302

298303
// NbOutputs implements Blueprint
@@ -344,9 +349,12 @@ type BlueprintGetAssignment struct {
344349

345350
// Ensures BlueprintGetAssignment implements BlueprintSolvable
346351
var _ constraint.BlueprintSolvable[constraint.U64] = (*BlueprintGetAssignment)(nil)
352+
347353
// Equal returns true if the serialized fields of two BlueprintGetAssignment are equal.
348354
func (b *BlueprintGetAssignment) Equal(other constraint.BlueprintComparable) bool {
349-
if other == nil { return false }
355+
if other == nil {
356+
return false
357+
}
350358
o, ok := other.(*BlueprintGetAssignment)
351359
if !ok {
352360
return false
@@ -418,7 +426,7 @@ func (b *BlueprintGetAssignment) UpdateInstructionTree(inst constraint.Instructi
418426
}
419427

420428
// NewBlueprints creates and registers all GKR blueprints for {{.FieldID}}
421-
func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compiler constraint.CustomizableSystem) gkrtypes.Blueprints {
429+
func NewBlueprints(circuit gkrcore.SerializableCircuit, schedule constraint.GkrProvingSchedule, hashName string, compiler constraint.CustomizableSystem) gkrcore.Blueprints {
422430
// Create and register solve blueprint
423431
solve := &BlueprintSolve{Circuit: circuit}
424432
solveID := compiler.AddBlueprint(solve)
@@ -427,6 +435,7 @@ func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compil
427435
prove := &BlueprintProve{
428436
SolveBlueprintID: solveID,
429437
SolveBlueprint: solve,
438+
Schedule: schedule,
430439
HashName: hashName,
431440
}
432441
proveID := compiler.AddBlueprint(prove)
@@ -437,7 +446,7 @@ func NewBlueprints(circuit gkrtypes.SerializableCircuit, hashName string, compil
437446
}
438447
getAssignmentID := compiler.AddBlueprint(getAssignment)
439448

440-
return gkrtypes.Blueprints{
449+
return gkrcore.Blueprints{
441450
SolveID: solveID,
442451
Solve: solve,
443452
ProveID: proveID,

0 commit comments

Comments
 (0)