Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 52 additions & 35 deletions fri/fri.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/consensys/gnark/std/conversion"
"github.com/consensys/gnark/std/lookup/logderivlookup"
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/std/math/cmp"
"github.com/consensys/gnark/std/math/uints"
)

Expand Down Expand Up @@ -276,7 +275,7 @@ func (f *FriVerifier) verifyFirstLayer(queries []logderivlookup.Table, evaluatio
decommitmentPositions := make([]logderivlookup.Table, 32)
sparseEvaluationsFlattened := make([]m31.M31, 0)
sparseEvaluations := make([]SparseEvaluations, 0)
previousFriWitnessIndex := frontend.Variable(0)
previousFriWitnessIndex := 0

columnBoundsIndex := 0
maxLogSize := f.circuitData.ColumnBounds[0]
Expand Down Expand Up @@ -364,7 +363,11 @@ func (f *FriVerifier) verifyFirstLayer(queries []logderivlookup.Table, evaluatio

// verify the merkle decommitment
merkleVerifier := NewMerkleVerifier(f.api, f.uapi, f.FirstLayerVerifier.proof.Commitment, columnLogSizes, nColumnsPerLogSize)
merkleVerifier.Verify(decommitmentPositions, sparseEvaluationsFlattened, f.FirstLayerVerifier.proof.Decommitment, queriesShape)
firstLayerBranching := f.circuitData.FriFirstLayerBranching
if len(firstLayerBranching) == 0 {
panic("missing FRI first layer branching data")
}
merkleVerifier.Verify(decommitmentPositions, sparseEvaluationsFlattened, f.FirstLayerVerifier.proof.Decommitment, queriesShape, firstLayerBranching)

return sparseEvaluations
}
Expand Down Expand Up @@ -487,7 +490,11 @@ func (f *FriVerifier) verifyInnerLayers(queries []logderivlookup.Table, firstLay
nColumnsPerLogSize[logSize+1] = 4

merkleVerifier := NewMerkleVerifier(f.api, f.uapi, f.InnerLayerVerifiers[innerLayerVerifier.layerIndex].proof.Commitment, columnLogSizes, nColumnsPerLogSize)
merkleVerifier.Verify(decommitmentPositions, sparseEvaluationsFlattened, f.InnerLayerVerifiers[innerLayerVerifier.layerIndex].proof.Decommitment, queryShape)
if innerLayerVerifier.layerIndex >= len(f.circuitData.FriInnerLayerBranching) {
panic("missing FRI inner layer branching data")
}
innerBranching := f.circuitData.FriInnerLayerBranching[innerLayerVerifier.layerIndex]
merkleVerifier.Verify(decommitmentPositions, sparseEvaluationsFlattened, f.InnerLayerVerifiers[innerLayerVerifier.layerIndex].proof.Decommitment, queryShape, innerBranching)

// currentLayerEvals contains g_{i-1}(x_j) folded
currentLayerEvals = make([]m31.QM31, f.circuitData.DedupedQueriesShape[logSize])
Expand Down Expand Up @@ -535,23 +542,34 @@ func (f *FriVerifier) computeDecommitmentPositionsAndRebuildEvals(
layerQueries logderivlookup.Table,
evalAtQueries logderivlookup.Table,
witnessEvals []m31.QM31,
previousFriWitnessIndex frontend.Variable,
previousFriWitnessIndex int,
layerQueriesShape int,
logSize int,
) (logderivlookup.Table, []m31.M31, SparseEvaluations, frontend.Variable) {
) (logderivlookup.Table, []m31.M31, SparseEvaluations, int) {
layerDecommitmentPositions := logderivlookup.New(f.api)
pairedEvalsFlattened := make([]m31.M31, 0)
pairedEvals := make([][2]m31.QM31, 0)
queryInitials := make([]uints.U32, 0)
offset := frontend.Variable(0)
offset := 0
witnessIndex := previousFriWitnessIndex
if logSize <= 0 {
panic("log size must be positive for decommitment branching")
}
if logSize-1 >= len(f.circuitData.QueriesBranching) {
panic("queries branching missing layer data")
}
branching := f.circuitData.QueriesBranching[logSize-1]
if layerQueriesShape > len(branching) {
panic("queries branching length mismatch")
}

for i := 0; i < layerQueriesShape; i++ {
base := f.api.Add(offset, frontend.Variable(i))
base := offset + i

// get the query initial (query >> 1)
leftQuery := layerQueries.Lookup(base)[0]
rightQuery := layerQueries.Lookup(f.api.Add(base, frontend.Variable(1)))[0]
leftQuery := layerQueries.Lookup(frontend.Variable(base))[0]
rightQuery := layerQueries.Lookup(frontend.Variable(base + 1))[0]
_ = rightQuery // keep lookup constraints even though branching is static
queryU32 := f.uapi.ValueOf(leftQuery)
queryInitialU32 := f.uapi.Rshift(queryU32, 1)
queryInitial := f.uapi.ToValue(queryInitialU32)
Expand All @@ -562,40 +580,39 @@ func (f *FriVerifier) computeDecommitmentPositionsAndRebuildEvals(
layerDecommitmentPositions.Insert(leftCandidate)
layerDecommitmentPositions.Insert(rightCandidate)

isLeftQueried := cmp.IsEqual(f.api, leftQuery, leftCandidate)
isRightQueried := cmp.IsEqual(f.api, rightQuery, rightCandidate)
branchCode := branching[i]
leftPresent := branchCode&1 == 1
rightPresent := branchCode&2 == 2

// aggregate the arguments to comply with the hint signature
args := []frontend.Variable{isLeftQueried, isRightQueried, witnessIndex}
for _, witnessEval := range witnessEvals {
args = append(args, witnessEval.AReal.Limb, witnessEval.AImag.Limb, witnessEval.BReal.Limb, witnessEval.BImag.Limb)
}
// the soundness relies on the fact that it is too costly to forge a valid witness for a given query, so it doesn't need to be checked
hintedFriWitness, err := f.api.Compiler().NewHint(friWitnessHint, 4+1, args...)
if err != nil {
panic(err)
witness := f.qm31Chip.Zero()
if !leftPresent || !rightPresent {
witness = witnessEvals[witnessIndex]
witnessIndex++
}
witness := m31.NewQM31FromComponents(
m31.NewM31Unchecked(hintedFriWitness[0]),
m31.NewM31Unchecked(hintedFriWitness[1]),
m31.NewM31Unchecked(hintedFriWitness[2]),
m31.NewM31Unchecked(hintedFriWitness[3]),
)
witnessIndex = hintedFriWitness[4]

eval0Native := evalAtQueries.Lookup(base)[0]
eval0 := f.qm31Chip.DecodeNative(eval0Native)
eval1Native := evalAtQueries.Lookup(f.api.Add(base, frontend.Variable(1)))[0]
eval1Native := evalAtQueries.Lookup(frontend.Variable(base + 1))[0]
eval1 := f.qm31Chip.DecodeNative(eval1Native)

leftEval := f.qm31Chip.Select(isLeftQueried, eval0, witness)
intermediate := f.qm31Chip.Select(isRightQueried, eval1, witness)
rightEval := f.qm31Chip.Select(isLeftQueried, intermediate, eval0)
var leftEval m31.QM31
var rightEval m31.QM31
if leftPresent {
leftEval = eval0
if rightPresent {
rightEval = eval1
} else {
rightEval = witness
}
} else {
leftEval = witness
rightEval = eval0
}

// update the offset
offsetPlusOne := f.api.Add(offset, frontend.Variable(1))
intermediate2 := f.api.Select(isRightQueried, offsetPlusOne, offset)
offset = f.api.Select(isLeftQueried, intermediate2, offset)
if leftPresent && rightPresent {
offset++
}

// flatten the evaluations into 4 M31 elements for use in the merkle decommitment verifier
leftEvalComponents := leftEval.Components()
Expand Down
6 changes: 0 additions & 6 deletions fri/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,13 @@ func WitnessHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
for i := 0; i < 32; i++ {
results[32+i] = hashWitness[32*witnessIndex+i]
}
witnessIndex++
}
} else {
if isJustRightQueried == 1 {
// left witness needed
for i := 0; i < 32; i++ {
results[i] = hashWitness[32*witnessIndex+i]
}
witnessIndex++
} else {
// both witnesses needed
for i := 0; i < 32; i++ {
Expand All @@ -67,13 +65,9 @@ func WitnessHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
for i := 0; i < 32; i++ {
results[32+i] = hashWitness[32*(witnessIndex+1)+i]
}
witnessIndex += 2
}
}

// set the witness index to the next unused hash witness entry
results[64] = big.NewInt(int64(witnessIndex))

return nil
}

Expand Down
114 changes: 53 additions & 61 deletions fri/vcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"github.com/HerodotusDev/stwo-gnark-verifier/variables"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/lookup/logderivlookup"
"github.com/consensys/gnark/std/math/cmp"
"github.com/consensys/gnark/std/math/uints"
)

Expand Down Expand Up @@ -58,27 +57,31 @@ func NewMerkleVerifier(api frontend.API, uapi *uints.BinaryField[uints.U32], roo
// - queries[l] contains the queries for the layer of log size l, len(queries) should be the number of layers (from root to leaves).
// For all l, queries[l] is sorted ascending with a dummy query at the end. IMPORTANT: queries[l] is never empty and should be generated from GenerateQueries.
// - queriesShape[l] = len(queries[l]) - 1 (-1 for the dummy query)
func (v *MerkleVerifier) Verify(queries []logderivlookup.Table, queriedValues []m31.M31, decommitment variables.MerkleDecommitment, queriesShape []int) {
// - queriesBranching[l][i] encodes which children are present for queries[l][i]:
// bit 0 for left, bit 1 for right
func (v *MerkleVerifier) Verify(queries []logderivlookup.Table, queriedValues []m31.M31, decommitment variables.MerkleDecommitment, queriesShape []int, queriesBranching [][]uint8) {
remainingValues := queriedValues
// there is no constraints on witnessIndex, it's just used as an external pointer for the hint function (frontend.Variable is
// not needed but convenient for hint signature)
witnessIndex := frontend.Variable(0)
witnessIndex := 0
zeroHash := [32]uints.U8{}
for i := range zeroHash {
zeroHash[i] = uints.NewU8(0)
}

// storage for the hashes per layer, keyed by log size
layerHashes := make([]logderivlookup.Table, v.maxLogSize+1)

// decommit layer by layer, doing all queries at once
for layerLog := v.maxLogSize; ; layerLog-- {
// initialize the layer hashes
layerHashes[layerLog] = logderivlookup.New(v.api)
// get the number of columns in the layer
nColumnsInLayer := v.nColumnsPerLogSize[layerLog]
// j is a pointer to the previous layer query.
j := frontend.Variable(0)
j := 0

// go through all query positions of the current layer
for queryIndex := 0; queryIndex < queriesShape[layerLog]; queryIndex++ {
query := queries[layerLog].Lookup(queryIndex)[0]
_ = query // keep lookup constraints even though branching is static
// pop the front of the queried values if any
var columnValues []m31.M31
if nColumnsInLayer > 0 {
Expand All @@ -93,72 +96,61 @@ func (v *MerkleVerifier) Verify(queries []logderivlookup.Table, queriedValues []
layerHashes[layerLog].Insert(lo)
layerHashes[layerLog].Insert(hi)
} else {
jPlusOne := v.api.Add(j, frontend.Variable(1))
jPlusTwo := v.api.Add(j, frontend.Variable(2))
twoJ := v.api.Mul(j, frontend.Variable(2))
twoJPlusOne := v.api.Add(twoJ, frontend.Variable(1))
twoJPlusTwo := v.api.Add(twoJ, frontend.Variable(2))
twoJPlusThree := v.api.Add(twoJ, frontend.Variable(3))

// derive the children queries candidates
queryMulTwo := v.api.Mul(query, frontend.Variable(2))
leftCandidate := queryMulTwo
rightCandidate := v.api.Add(queryMulTwo, frontend.Variable(1))
if layerLog >= len(queriesBranching) {
panic("queries branching missing layer data")
}
if queryIndex >= len(queriesBranching[layerLog]) {
panic("queries branching length mismatch")
}
branchCode := queriesBranching[layerLog][queryIndex]
leftPresent := branchCode&1 == 1
rightPresent := branchCode&2 == 2

var leftHash [32]uints.U8
var rightHash [32]uints.U8

// rebuild the children hashes candidates from the previous layer
h0Lo := layerHashes[layerLog+1].Lookup(twoJ)[0]
h0Hi := layerHashes[layerLog+1].Lookup(twoJPlusOne)[0]
h1Lo := layerHashes[layerLog+1].Lookup(twoJPlusTwo)[0]
h1Hi := layerHashes[layerLog+1].Lookup(twoJPlusThree)[0]
twoJ := 2 * j
h0Lo := layerHashes[layerLog+1].Lookup(frontend.Variable(twoJ))[0]
h0Hi := layerHashes[layerLog+1].Lookup(frontend.Variable(twoJ + 1))[0]
h1Lo := layerHashes[layerLog+1].Lookup(frontend.Variable(twoJ + 2))[0]
h1Hi := layerHashes[layerLog+1].Lookup(frontend.Variable(twoJ + 3))[0]
h0 := utils.RebuildHash(v.api, h0Lo, h0Hi)
h1 := utils.RebuildHash(v.api, h1Lo, h1Hi)

isLeftQueried := cmp.IsEqual(v.api, leftCandidate, queries[layerLog+1].Lookup(j)[0])
isRightAlsoQueried := cmp.IsEqual(v.api, rightCandidate, queries[layerLog+1].Lookup(jPlusOne)[0])
isJustRightQueried := cmp.IsEqual(v.api, rightCandidate, queries[layerLog+1].Lookup(j)[0])

// aggregate the arguments to comply with the hint signature
args := []frontend.Variable{isLeftQueried, isRightAlsoQueried, isJustRightQueried, witnessIndex}
for _, hash := range decommitment.HashWitness {
hashNative := [32]frontend.Variable{}
for i := 0; i < 32; i++ {
hashNative[i] = v.bapi.Value(hash[i])
}
args = append(args, hashNative[:]...)
}
// the soundness relies on the fact that it is too costly to forge a valid witness for a given query, so it doesn't need to be checked
hintedWitness, err := v.api.Compiler().NewHint(WitnessHint, 2*32+1, args...)
if err != nil {
panic(err)
}

// extract the witness values from the hinted witness
var w0 [32]uints.U8
var w1 [32]uints.U8
for i := 0; i < 32; i++ {
w0[i] = v.bapi.ValueOf(hintedWitness[i])
}
for i := 0; i < 32; i++ {
w1[i] = v.bapi.ValueOf(hintedWitness[32+i])
if leftPresent {
leftHash = h0
if rightPresent {
rightHash = h1
} else {
w1 = decommitment.HashWitness[witnessIndex]
witnessIndex++
rightHash = w1
}
} else if rightPresent {
w0 = decommitment.HashWitness[witnessIndex]
witnessIndex++
leftHash = w0
rightHash = h0
} else {
w0 = decommitment.HashWitness[witnessIndex]
w1 = decommitment.HashWitness[witnessIndex+1]
witnessIndex += 2
leftHash = w0
rightHash = w1
}

// update the witness index from the hinted witness
witnessIndex = hintedWitness[2*32]

leftHash = utils.SelectHash(v.api, isLeftQueried, h0, w0)
intermediate1 := utils.SelectHash(v.api, isRightAlsoQueried, h1, w1)
intermediate2 := utils.SelectHash(v.api, isJustRightQueried, h0, w1)
rightHash = utils.SelectHash(v.api, isLeftQueried, intermediate1, intermediate2)

// update the pointer to the previous layer query
j = v.api.Select(
isLeftQueried,
v.api.Select(isRightAlsoQueried, jPlusTwo, jPlusOne),
v.api.Select(isJustRightQueried, jPlusOne, j),
)
if leftPresent {
if rightPresent {
j += 2
} else {
j++
}
} else if rightPresent {
j++
}

// update the current layer hashes
hash := v.blake2sChip.HashNode(leftHash[:], rightHash[:], columnValues)
Expand Down
Loading