Skip to content

Commit c6a0fdf

Browse files
bug: remove witness hinting (#39)
* bug: remove witness hinting * lint
1 parent dcd55d5 commit c6a0fdf

File tree

12 files changed

+196
-128
lines changed

12 files changed

+196
-128
lines changed

fri/fri.go

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"github.com/consensys/gnark/std/conversion"
1313
"github.com/consensys/gnark/std/lookup/logderivlookup"
1414
"github.com/consensys/gnark/std/math/bits"
15-
"github.com/consensys/gnark/std/math/cmp"
1615
"github.com/consensys/gnark/std/math/uints"
1716
)
1817

@@ -276,7 +275,7 @@ func (f *FriVerifier) verifyFirstLayer(queries []logderivlookup.Table, evaluatio
276275
decommitmentPositions := make([]logderivlookup.Table, 32)
277276
sparseEvaluationsFlattened := make([]m31.M31, 0)
278277
sparseEvaluations := make([]SparseEvaluations, 0)
279-
previousFriWitnessIndex := frontend.Variable(0)
278+
previousFriWitnessIndex := 0
280279

281280
columnBoundsIndex := 0
282281
maxLogSize := f.circuitData.ColumnBounds[0]
@@ -364,7 +363,11 @@ func (f *FriVerifier) verifyFirstLayer(queries []logderivlookup.Table, evaluatio
364363

365364
// verify the merkle decommitment
366365
merkleVerifier := NewMerkleVerifier(f.api, f.uapi, f.FirstLayerVerifier.proof.Commitment, columnLogSizes, nColumnsPerLogSize)
367-
merkleVerifier.Verify(decommitmentPositions, sparseEvaluationsFlattened, f.FirstLayerVerifier.proof.Decommitment, queriesShape)
366+
firstLayerBranching := f.circuitData.FriFirstLayerBranching
367+
if len(firstLayerBranching) == 0 {
368+
panic("missing FRI first layer branching data")
369+
}
370+
merkleVerifier.Verify(decommitmentPositions, sparseEvaluationsFlattened, f.FirstLayerVerifier.proof.Decommitment, queriesShape, firstLayerBranching)
368371

369372
return sparseEvaluations
370373
}
@@ -487,7 +490,11 @@ func (f *FriVerifier) verifyInnerLayers(queries []logderivlookup.Table, firstLay
487490
nColumnsPerLogSize[logSize+1] = 4
488491

489492
merkleVerifier := NewMerkleVerifier(f.api, f.uapi, f.InnerLayerVerifiers[innerLayerVerifier.layerIndex].proof.Commitment, columnLogSizes, nColumnsPerLogSize)
490-
merkleVerifier.Verify(decommitmentPositions, sparseEvaluationsFlattened, f.InnerLayerVerifiers[innerLayerVerifier.layerIndex].proof.Decommitment, queryShape)
493+
if innerLayerVerifier.layerIndex >= len(f.circuitData.FriInnerLayerBranching) {
494+
panic("missing FRI inner layer branching data")
495+
}
496+
innerBranching := f.circuitData.FriInnerLayerBranching[innerLayerVerifier.layerIndex]
497+
merkleVerifier.Verify(decommitmentPositions, sparseEvaluationsFlattened, f.InnerLayerVerifiers[innerLayerVerifier.layerIndex].proof.Decommitment, queryShape, innerBranching)
491498

492499
// currentLayerEvals contains g_{i-1}(x_j) folded
493500
currentLayerEvals = make([]m31.QM31, f.circuitData.DedupedQueriesShape[logSize])
@@ -535,23 +542,34 @@ func (f *FriVerifier) computeDecommitmentPositionsAndRebuildEvals(
535542
layerQueries logderivlookup.Table,
536543
evalAtQueries logderivlookup.Table,
537544
witnessEvals []m31.QM31,
538-
previousFriWitnessIndex frontend.Variable,
545+
previousFriWitnessIndex int,
539546
layerQueriesShape int,
540547
logSize int,
541-
) (logderivlookup.Table, []m31.M31, SparseEvaluations, frontend.Variable) {
548+
) (logderivlookup.Table, []m31.M31, SparseEvaluations, int) {
542549
layerDecommitmentPositions := logderivlookup.New(f.api)
543550
pairedEvalsFlattened := make([]m31.M31, 0)
544551
pairedEvals := make([][2]m31.QM31, 0)
545552
queryInitials := make([]uints.U32, 0)
546-
offset := frontend.Variable(0)
553+
offset := 0
547554
witnessIndex := previousFriWitnessIndex
555+
if logSize <= 0 {
556+
panic("log size must be positive for decommitment branching")
557+
}
558+
if logSize-1 >= len(f.circuitData.QueriesBranching) {
559+
panic("queries branching missing layer data")
560+
}
561+
branching := f.circuitData.QueriesBranching[logSize-1]
562+
if layerQueriesShape > len(branching) {
563+
panic("queries branching length mismatch")
564+
}
548565

549566
for i := 0; i < layerQueriesShape; i++ {
550-
base := f.api.Add(offset, frontend.Variable(i))
567+
base := offset + i
551568

552569
// get the query initial (query >> 1)
553-
leftQuery := layerQueries.Lookup(base)[0]
554-
rightQuery := layerQueries.Lookup(f.api.Add(base, frontend.Variable(1)))[0]
570+
leftQuery := layerQueries.Lookup(frontend.Variable(base))[0]
571+
rightQuery := layerQueries.Lookup(frontend.Variable(base + 1))[0]
572+
_ = rightQuery // keep lookup constraints even though branching is static
555573
queryU32 := f.uapi.ValueOf(leftQuery)
556574
queryInitialU32 := f.uapi.Rshift(queryU32, 1)
557575
queryInitial := f.uapi.ToValue(queryInitialU32)
@@ -562,40 +580,39 @@ func (f *FriVerifier) computeDecommitmentPositionsAndRebuildEvals(
562580
layerDecommitmentPositions.Insert(leftCandidate)
563581
layerDecommitmentPositions.Insert(rightCandidate)
564582

565-
isLeftQueried := cmp.IsEqual(f.api, leftQuery, leftCandidate)
566-
isRightQueried := cmp.IsEqual(f.api, rightQuery, rightCandidate)
583+
branchCode := branching[i]
584+
leftPresent := branchCode&1 == 1
585+
rightPresent := branchCode&2 == 2
567586

568-
// aggregate the arguments to comply with the hint signature
569-
args := []frontend.Variable{isLeftQueried, isRightQueried, witnessIndex}
570-
for _, witnessEval := range witnessEvals {
571-
args = append(args, witnessEval.AReal.Limb, witnessEval.AImag.Limb, witnessEval.BReal.Limb, witnessEval.BImag.Limb)
572-
}
573-
// the soundness relies on the fact that it is too costly to forge a valid witness for a given query, so it doesn't need to be checked
574-
hintedFriWitness, err := f.api.Compiler().NewHint(friWitnessHint, 4+1, args...)
575-
if err != nil {
576-
panic(err)
587+
witness := f.qm31Chip.Zero()
588+
if !leftPresent || !rightPresent {
589+
witness = witnessEvals[witnessIndex]
590+
witnessIndex++
577591
}
578-
witness := m31.NewQM31FromComponents(
579-
m31.NewM31Unchecked(hintedFriWitness[0]),
580-
m31.NewM31Unchecked(hintedFriWitness[1]),
581-
m31.NewM31Unchecked(hintedFriWitness[2]),
582-
m31.NewM31Unchecked(hintedFriWitness[3]),
583-
)
584-
witnessIndex = hintedFriWitness[4]
585592

586593
eval0Native := evalAtQueries.Lookup(base)[0]
587594
eval0 := f.qm31Chip.DecodeNative(eval0Native)
588-
eval1Native := evalAtQueries.Lookup(f.api.Add(base, frontend.Variable(1)))[0]
595+
eval1Native := evalAtQueries.Lookup(frontend.Variable(base + 1))[0]
589596
eval1 := f.qm31Chip.DecodeNative(eval1Native)
590597

591-
leftEval := f.qm31Chip.Select(isLeftQueried, eval0, witness)
592-
intermediate := f.qm31Chip.Select(isRightQueried, eval1, witness)
593-
rightEval := f.qm31Chip.Select(isLeftQueried, intermediate, eval0)
598+
var leftEval m31.QM31
599+
var rightEval m31.QM31
600+
if leftPresent {
601+
leftEval = eval0
602+
if rightPresent {
603+
rightEval = eval1
604+
} else {
605+
rightEval = witness
606+
}
607+
} else {
608+
leftEval = witness
609+
rightEval = eval0
610+
}
594611

595612
// update the offset
596-
offsetPlusOne := f.api.Add(offset, frontend.Variable(1))
597-
intermediate2 := f.api.Select(isRightQueried, offsetPlusOne, offset)
598-
offset = f.api.Select(isLeftQueried, intermediate2, offset)
613+
if leftPresent && rightPresent {
614+
offset++
615+
}
599616

600617
// flatten the evaluations into 4 M31 elements for use in the merkle decommitment verifier
601618
leftEvalComponents := leftEval.Components()

fri/utils.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,13 @@ func WitnessHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
5050
for i := 0; i < 32; i++ {
5151
results[32+i] = hashWitness[32*witnessIndex+i]
5252
}
53-
witnessIndex++
5453
}
5554
} else {
5655
if isJustRightQueried == 1 {
5756
// left witness needed
5857
for i := 0; i < 32; i++ {
5958
results[i] = hashWitness[32*witnessIndex+i]
6059
}
61-
witnessIndex++
6260
} else {
6361
// both witnesses needed
6462
for i := 0; i < 32; i++ {
@@ -67,13 +65,9 @@ func WitnessHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
6765
for i := 0; i < 32; i++ {
6866
results[32+i] = hashWitness[32*(witnessIndex+1)+i]
6967
}
70-
witnessIndex += 2
7168
}
7269
}
7370

74-
// set the witness index to the next unused hash witness entry
75-
results[64] = big.NewInt(int64(witnessIndex))
76-
7771
return nil
7872
}
7973

fri/vcs.go

Lines changed: 53 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"github.com/HerodotusDev/stwo-gnark-verifier/variables"
88
"github.com/consensys/gnark/frontend"
99
"github.com/consensys/gnark/std/lookup/logderivlookup"
10-
"github.com/consensys/gnark/std/math/cmp"
1110
"github.com/consensys/gnark/std/math/uints"
1211
)
1312

@@ -58,27 +57,31 @@ func NewMerkleVerifier(api frontend.API, uapi *uints.BinaryField[uints.U32], roo
5857
// - queries[l] contains the queries for the layer of log size l, len(queries) should be the number of layers (from root to leaves).
5958
// For all l, queries[l] is sorted ascending with a dummy query at the end. IMPORTANT: queries[l] is never empty and should be generated from GenerateQueries.
6059
// - queriesShape[l] = len(queries[l]) - 1 (-1 for the dummy query)
61-
func (v *MerkleVerifier) Verify(queries []logderivlookup.Table, queriedValues []m31.M31, decommitment variables.MerkleDecommitment, queriesShape []int) {
60+
// - queriesBranching[l][i] encodes which children are present for queries[l][i]:
61+
// bit 0 for left, bit 1 for right
62+
func (v *MerkleVerifier) Verify(queries []logderivlookup.Table, queriedValues []m31.M31, decommitment variables.MerkleDecommitment, queriesShape []int, queriesBranching [][]uint8) {
6263
remainingValues := queriedValues
63-
// there is no constraints on witnessIndex, it's just used as an external pointer for the hint function (frontend.Variable is
64-
// not needed but convenient for hint signature)
65-
witnessIndex := frontend.Variable(0)
64+
witnessIndex := 0
65+
zeroHash := [32]uints.U8{}
66+
for i := range zeroHash {
67+
zeroHash[i] = uints.NewU8(0)
68+
}
6669

6770
// storage for the hashes per layer, keyed by log size
6871
layerHashes := make([]logderivlookup.Table, v.maxLogSize+1)
69-
7072
// decommit layer by layer, doing all queries at once
7173
for layerLog := v.maxLogSize; ; layerLog-- {
7274
// initialize the layer hashes
7375
layerHashes[layerLog] = logderivlookup.New(v.api)
7476
// get the number of columns in the layer
7577
nColumnsInLayer := v.nColumnsPerLogSize[layerLog]
7678
// j is a pointer to the previous layer query.
77-
j := frontend.Variable(0)
79+
j := 0
7880

7981
// go through all query positions of the current layer
8082
for queryIndex := 0; queryIndex < queriesShape[layerLog]; queryIndex++ {
8183
query := queries[layerLog].Lookup(queryIndex)[0]
84+
_ = query // keep lookup constraints even though branching is static
8285
// pop the front of the queried values if any
8386
var columnValues []m31.M31
8487
if nColumnsInLayer > 0 {
@@ -93,72 +96,61 @@ func (v *MerkleVerifier) Verify(queries []logderivlookup.Table, queriedValues []
9396
layerHashes[layerLog].Insert(lo)
9497
layerHashes[layerLog].Insert(hi)
9598
} else {
96-
jPlusOne := v.api.Add(j, frontend.Variable(1))
97-
jPlusTwo := v.api.Add(j, frontend.Variable(2))
98-
twoJ := v.api.Mul(j, frontend.Variable(2))
99-
twoJPlusOne := v.api.Add(twoJ, frontend.Variable(1))
100-
twoJPlusTwo := v.api.Add(twoJ, frontend.Variable(2))
101-
twoJPlusThree := v.api.Add(twoJ, frontend.Variable(3))
102-
103-
// derive the children queries candidates
104-
queryMulTwo := v.api.Mul(query, frontend.Variable(2))
105-
leftCandidate := queryMulTwo
106-
rightCandidate := v.api.Add(queryMulTwo, frontend.Variable(1))
99+
if layerLog >= len(queriesBranching) {
100+
panic("queries branching missing layer data")
101+
}
102+
if queryIndex >= len(queriesBranching[layerLog]) {
103+
panic("queries branching length mismatch")
104+
}
105+
branchCode := queriesBranching[layerLog][queryIndex]
106+
leftPresent := branchCode&1 == 1
107+
rightPresent := branchCode&2 == 2
107108

108109
var leftHash [32]uints.U8
109110
var rightHash [32]uints.U8
110111

111112
// rebuild the children hashes candidates from the previous layer
112-
h0Lo := layerHashes[layerLog+1].Lookup(twoJ)[0]
113-
h0Hi := layerHashes[layerLog+1].Lookup(twoJPlusOne)[0]
114-
h1Lo := layerHashes[layerLog+1].Lookup(twoJPlusTwo)[0]
115-
h1Hi := layerHashes[layerLog+1].Lookup(twoJPlusThree)[0]
113+
twoJ := 2 * j
114+
h0Lo := layerHashes[layerLog+1].Lookup(frontend.Variable(twoJ))[0]
115+
h0Hi := layerHashes[layerLog+1].Lookup(frontend.Variable(twoJ + 1))[0]
116+
h1Lo := layerHashes[layerLog+1].Lookup(frontend.Variable(twoJ + 2))[0]
117+
h1Hi := layerHashes[layerLog+1].Lookup(frontend.Variable(twoJ + 3))[0]
116118
h0 := utils.RebuildHash(v.api, h0Lo, h0Hi)
117119
h1 := utils.RebuildHash(v.api, h1Lo, h1Hi)
118120

119-
isLeftQueried := cmp.IsEqual(v.api, leftCandidate, queries[layerLog+1].Lookup(j)[0])
120-
isRightAlsoQueried := cmp.IsEqual(v.api, rightCandidate, queries[layerLog+1].Lookup(jPlusOne)[0])
121-
isJustRightQueried := cmp.IsEqual(v.api, rightCandidate, queries[layerLog+1].Lookup(j)[0])
122-
123-
// aggregate the arguments to comply with the hint signature
124-
args := []frontend.Variable{isLeftQueried, isRightAlsoQueried, isJustRightQueried, witnessIndex}
125-
for _, hash := range decommitment.HashWitness {
126-
hashNative := [32]frontend.Variable{}
127-
for i := 0; i < 32; i++ {
128-
hashNative[i] = v.bapi.Value(hash[i])
129-
}
130-
args = append(args, hashNative[:]...)
131-
}
132-
// the soundness relies on the fact that it is too costly to forge a valid witness for a given query, so it doesn't need to be checked
133-
hintedWitness, err := v.api.Compiler().NewHint(WitnessHint, 2*32+1, args...)
134-
if err != nil {
135-
panic(err)
136-
}
137-
138-
// extract the witness values from the hinted witness
139121
var w0 [32]uints.U8
140122
var w1 [32]uints.U8
141-
for i := 0; i < 32; i++ {
142-
w0[i] = v.bapi.ValueOf(hintedWitness[i])
143-
}
144-
for i := 0; i < 32; i++ {
145-
w1[i] = v.bapi.ValueOf(hintedWitness[32+i])
123+
if leftPresent {
124+
leftHash = h0
125+
if rightPresent {
126+
rightHash = h1
127+
} else {
128+
w1 = decommitment.HashWitness[witnessIndex]
129+
witnessIndex++
130+
rightHash = w1
131+
}
132+
} else if rightPresent {
133+
w0 = decommitment.HashWitness[witnessIndex]
134+
witnessIndex++
135+
leftHash = w0
136+
rightHash = h0
137+
} else {
138+
w0 = decommitment.HashWitness[witnessIndex]
139+
w1 = decommitment.HashWitness[witnessIndex+1]
140+
witnessIndex += 2
141+
leftHash = w0
142+
rightHash = w1
146143
}
147144

148-
// update the witness index from the hinted witness
149-
witnessIndex = hintedWitness[2*32]
150-
151-
leftHash = utils.SelectHash(v.api, isLeftQueried, h0, w0)
152-
intermediate1 := utils.SelectHash(v.api, isRightAlsoQueried, h1, w1)
153-
intermediate2 := utils.SelectHash(v.api, isJustRightQueried, h0, w1)
154-
rightHash = utils.SelectHash(v.api, isLeftQueried, intermediate1, intermediate2)
155-
156-
// update the pointer to the previous layer query
157-
j = v.api.Select(
158-
isLeftQueried,
159-
v.api.Select(isRightAlsoQueried, jPlusTwo, jPlusOne),
160-
v.api.Select(isJustRightQueried, jPlusOne, j),
161-
)
145+
if leftPresent {
146+
if rightPresent {
147+
j += 2
148+
} else {
149+
j++
150+
}
151+
} else if rightPresent {
152+
j++
153+
}
162154

163155
// update the current layer hashes
164156
hash := v.blake2sChip.HashNode(leftHash[:], rightHash[:], columnValues)

0 commit comments

Comments
 (0)