From be747e027e35f5ceccaffe6ea835e75c1ba59f8c Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 5 Jun 2025 18:26:51 -0500 Subject: [PATCH 01/20] refactor: shorten import name --- std/permutation/poseidon2/gkr-poseidon2/gkr.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index d9a7dcfbbb..471c1f3887 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -9,7 +9,7 @@ import ( "github.com/consensys/gnark/std/gkrapi/gkr" "github.com/consensys/gnark-crypto/ecc" - poseidon2Bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" + bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" ) @@ -159,9 +159,9 @@ func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, er } // poseidon2 parameters - gateNamer := newRoundGateNamer(poseidon2Bls12377.GetDefaultParameters()) - rF := poseidon2Bls12377.GetDefaultParameters().NbFullRounds - rP := poseidon2Bls12377.GetDefaultParameters().NbPartialRounds + gateNamer := newRoundGateNamer(bls12377.GetDefaultParameters()) + rF := bls12377.GetDefaultParameters().NbFullRounds + rP := bls12377.GetDefaultParameters().NbPartialRounds halfRf := rF / 2 gkrApi = gkrapi.New() @@ -243,9 +243,9 @@ func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, er return } -var bls12377Permutation = sync.OnceValue(func() *poseidon2Bls12377.Permutation { - params := poseidon2Bls12377.GetDefaultParameters() - return poseidon2Bls12377.NewPermutation(2, params.NbFullRounds, params.NbPartialRounds) // TODO @Tabaie add NewDefaultPermutation to gnark-crypto +var bls12377Permutation = sync.OnceValue(func() *bls12377.Permutation { + params := bls12377.GetDefaultParameters() + return bls12377.NewPermutation(2, params.NbFullRounds, params.NbPartialRounds) // TODO @Tabaie add NewDefaultPermutation to gnark-crypto }) // RegisterGkrGates registers the GKR gates corresponding to the given curves for the solver @@ -271,7 +271,7 @@ func registerGatesBls12377() error { y ) - p := poseidon2Bls12377.GetDefaultParameters() + p := bls12377.GetDefaultParameters() halfRf := p.NbFullRounds / 2 gateNames := newRoundGateNamer(p) From 4b43ce33f59ef939396d9f5253f5f7e989c7e06b Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 11:12:18 -0500 Subject: [PATCH 02/20] perf: addition rather than multiplication in gates --- std/permutation/poseidon2/gkr-poseidon2/gkr.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index 471c1f3887..a9af651bda 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -20,7 +20,7 @@ func extKeyGate(roundKey frontend.Variable) gkr.GateFunction { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(api.Mul(x[0], 2), x[1], roundKey) + return api.Add(x[0], x[0], x[1], roundKey) } } @@ -71,7 +71,7 @@ func extGate2(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(api.Mul(x[1], 2), x[0]) + return api.Add(x[1], x[1], x[0]) } // intKeyGate2 applies the internal matrix mul, then adds the round key @@ -80,7 +80,7 @@ func intKeyGate2(roundKey frontend.Variable) gkr.GateFunction { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(api.Mul(x[1], 3), x[0], roundKey) + return api.Add(x[1], x[1], x[1], x[0], roundKey) } } @@ -89,7 +89,7 @@ func intGate2(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(api.Mul(x[1], 3), x[0]) + return api.Add(x[1], x[1], x[1], x[0]) } // extGate applies the first row of the external matrix @@ -97,7 +97,7 @@ func extGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(api.Mul(x[0], 2), x[1]) + return api.Add(x[0], x[0], x[1]) } // extAddGate applies the first row of the external matrix to the first two elements and adds the third @@ -105,7 +105,7 @@ func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 3 { panic("expected 3 inputs") } - return api.Add(api.Mul(x[0], 2), x[1], x[2]) + return api.Add(x[0], x[0], x[1], x[2]) } type GkrPermutations struct { From 6b86526d708e66cc5e3846452f3259a9a41b724f Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 12:57:43 -0500 Subject: [PATCH 03/20] Revert "perf: addition rather than multiplication in gates" This reverts commit 4b43ce33f59ef939396d9f5253f5f7e989c7e06b. --- std/permutation/poseidon2/gkr-poseidon2/gkr.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index a9af651bda..471c1f3887 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -20,7 +20,7 @@ func extKeyGate(roundKey frontend.Variable) gkr.GateFunction { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(x[0], x[0], x[1], roundKey) + return api.Add(api.Mul(x[0], 2), x[1], roundKey) } } @@ -71,7 +71,7 @@ func extGate2(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(x[1], x[1], x[0]) + return api.Add(api.Mul(x[1], 2), x[0]) } // intKeyGate2 applies the internal matrix mul, then adds the round key @@ -80,7 +80,7 @@ func intKeyGate2(roundKey frontend.Variable) gkr.GateFunction { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(x[1], x[1], x[1], x[0], roundKey) + return api.Add(api.Mul(x[1], 3), x[0], roundKey) } } @@ -89,7 +89,7 @@ func intGate2(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(x[1], x[1], x[1], x[0]) + return api.Add(api.Mul(x[1], 3), x[0]) } // extGate applies the first row of the external matrix @@ -97,7 +97,7 @@ func extGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } - return api.Add(x[0], x[0], x[1]) + return api.Add(api.Mul(x[0], 2), x[1]) } // extAddGate applies the first row of the external matrix to the first two elements and adds the third @@ -105,7 +105,7 @@ func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 3 { panic("expected 3 inputs") } - return api.Add(x[0], x[0], x[1], x[2]) + return api.Add(api.Mul(x[0], 2), x[1], x[2]) } type GkrPermutations struct { From c4b01b579e908db2f1f22de7a66fc3309fab2b16 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:14:28 -0500 Subject: [PATCH 04/20] feat: generify poseidon2-gkr (not all s-Boxes available yet) --- .../poseidon2/gkr-poseidon2/gkr.go | 114 ++++--- .../poseidon2/gkr-poseidon2/gkr_test.go | 4 +- std/permutation/poseidon2/poseidon2.go | 281 +++++++++++++----- 3 files changed, 253 insertions(+), 146 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index 471c1f3887..dcc51a4031 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -1,15 +1,16 @@ package gkr_poseidon2 import ( + "errors" "fmt" - "sync" "github.com/consensys/gnark/constraint/solver/gkrgates" + "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" + "github.com/consensys/gnark/std/permutation/poseidon2" "github.com/consensys/gnark-crypto/ecc" - bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" ) @@ -117,18 +118,18 @@ type GkrPermutations struct { // NewGkrPermutations returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) // which consists of a permutation along with the input fed forward. // The correctness of the compression functions is proven using GKR. -// Note that the solver will need the function RegisterGkrGates to be called with the desired curves +// Note that the solver will need the function RegisterGates to be called with the desired curves func NewGkrPermutations(api frontend.API) *GkrPermutations { if api.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) != 0 { panic("currently only BL12-377 is supported") } - gkrApi, in1, in2, out, err := defineCircuitBls12377() + gkrCircuit, in1, in2, out, err := defineCircuit(api) if err != nil { - panic(fmt.Errorf("failed to define GKR circuit: %v", err)) + panic(fmt.Errorf("failed to define GKR circuit: %w", err)) } return &GkrPermutations{ api: api, - gkrCircuit: gkrApi.Compile(api, "MIMC"), + gkrCircuit: gkrCircuit, in1: in1, in2: in2, out: out, @@ -144,27 +145,28 @@ func (p *GkrPermutations) Compress(a, b frontend.Variable) frontend.Variable { return outs[p.out] } -// defineCircuitBls12377 defines the GKR circuit for the Poseidon2 permutation over BLS12-377 +// defineCircuit defines the GKR circuit for the Poseidon2 permutation over BLS12-377 // insLeft and insRight are the inputs to the permutation // they must be padded to a power of 2 -func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, err error) { +func defineCircuit(api frontend.API) (gkrCircuit *gkrapi.Circuit, in1, in2, out gkr.Variable, err error) { // variable indexes const ( xI = iota yI ) - if err = registerGatesBls12377(); err != nil { + curve := utils.FieldToCurve(api.Compiler().Field()) + p, err := poseidon2.GetDefaultParameters(curve) + if err != nil { return } + gateNamer := newRoundGateNamer(&p, curve) - // poseidon2 parameters - gateNamer := newRoundGateNamer(bls12377.GetDefaultParameters()) - rF := bls12377.GetDefaultParameters().NbFullRounds - rP := bls12377.GetDefaultParameters().NbPartialRounds - halfRf := rF / 2 + if err = registerGates(&p, curve); err != nil { + return + } - gkrApi = gkrapi.New() + gkrApi := gkrapi.New() x := gkrApi.NewInput() y := gkrApi.NewInput() @@ -181,9 +183,17 @@ func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, er // apply the s-Box to u // the s-Box gates: u¹⁷ = (u⁴)⁴ * u - sBox := func(u gkr.Variable) gkr.Variable { - v := gkrApi.Gate(pow4Gate, u) // u⁴ - return gkrApi.Gate(pow4TimesGate, v, u) // u¹⁷ + + var sBox func(gkr.Variable) gkr.Variable + switch p.DegreeSBox { + case 17: + sBox = func(u gkr.Variable) gkr.Variable { + v := gkrApi.Gate(pow4Gate, u) // u⁴ + return gkrApi.Gate(pow4TimesGate, v, u) // u¹⁷ + } + default: + err = fmt.Errorf("unsupported s-Box degree %d", p.DegreeSBox) + return } // apply external matrix multiplication and round key addition @@ -208,89 +218,68 @@ func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, er // *** construct the circuit *** - for i := range halfRf { + for i := range p.NbFullRounds / 2 { fullRound(i) } { // i = halfRf: first partial round // still using the external matrix, since the linear operation still belongs to a full (canonical) round - x1 := extKeySBox(halfRf, xI, x, y) + x1 := extKeySBox(p.NbFullRounds/2, xI, x, y) x, y = x1, gkrApi.Gate(extGate2, x, y) } - for i := halfRf + 1; i < halfRf+rP; i++ { + for i := p.NbFullRounds/2 + 1; i < p.NbFullRounds/2+p.NbPartialRounds; i++ { x1 := extKeySBox(i, xI, x, y) // the first row of the internal matrix is the same as that of the external matrix x, y = x1, gkrApi.Gate(intGate2, x, y) } { - i := halfRf + rP + i := p.NbFullRounds/2 + p.NbPartialRounds // first iteration of the final batch of full rounds // still using the internal matrix, since the linear operation still belongs to a partial (canonical) round x1 := extKeySBox(i, xI, x, y) x, y = x1, intKeySBox2(i, x, y) } - for i := halfRf + rP + 1; i < rP+rF; i++ { + for i := p.NbFullRounds/2 + p.NbPartialRounds + 1; i < p.NbPartialRounds+p.NbFullRounds; i++ { fullRound(i) } // apply the external matrix one last time to obtain the final value of y - out = gkrApi.NamedGate(gateNamer.linear(yI, rP+rF), y, x, in2) + out = gkrApi.Gate(extAddGate, y, x, in2) + + gkrCircuit = gkrApi.Compile(api, "MIMC") return } -var bls12377Permutation = sync.OnceValue(func() *bls12377.Permutation { - params := bls12377.GetDefaultParameters() - return bls12377.NewPermutation(2, params.NbFullRounds, params.NbPartialRounds) // TODO @Tabaie add NewDefaultPermutation to gnark-crypto -}) - -// RegisterGkrGates registers the GKR gates corresponding to the given curves for the solver -func RegisterGkrGates(curves ...ecc.ID) { +// RegisterGates registers the GKR gates corresponding to the given curves for the solver. +func RegisterGates(curves ...ecc.ID) error { if len(curves) == 0 { - panic("expected at least one curve") + return errors.New("expected at least one curve") } for _, curve := range curves { - switch curve { - case ecc.BLS12_377: - if err := registerGatesBls12377(); err != nil { - panic(err) - } - default: - panic(fmt.Sprintf("curve %s not currently supported", curve)) + p, err := poseidon2.GetDefaultParameters(curve) + if err != nil { + return fmt.Errorf("failed to get default parameters for curve %s: %w", curve, err) + } + if err = registerGates(&p, curve); err != nil { + return fmt.Errorf("failed to register gates for curve %s: %w", curve, err) } } + return nil } -func registerGatesBls12377() error { +func registerGates(p *poseidon2.Parameters, curve ecc.ID) error { const ( x = iota y ) - p := bls12377.GetDefaultParameters() + gateNames := newRoundGateNamer(p, curve) halfRf := p.NbFullRounds / 2 - gateNames := newRoundGateNamer(p) - - if _, err := gkrgates.Register(pow2Gate, 1, gkrgates.WithUnverifiedDegree(2), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } - if _, err := gkrgates.Register(pow4Gate, 1, gkrgates.WithUnverifiedDegree(4), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } - if _, err := gkrgates.Register(pow2TimesGate, 2, gkrgates.WithUnverifiedDegree(3), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } - if _, err := gkrgates.Register(pow4TimesGate, 2, gkrgates.WithUnverifiedDegree(5), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } - - if _, err := gkrgates.Register(intGate2, 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } extKeySBox := func(round int, varIndex int) error { _, err := gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(ecc.BLS12_377)) @@ -343,15 +332,14 @@ func registerGatesBls12377() error { } } - _, err := gkrgates.Register(extAddGate, 3, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, p.NbPartialRounds+p.NbFullRounds)), gkrgates.WithCurves(ecc.BLS12_377)) - return err + return nil } type roundGateNamer string // newRoundGateNamer returns an object that returns standardized names for gates in the GKR circuit -func newRoundGateNamer(p fmt.Stringer) roundGateNamer { - return roundGateNamer(p.String()) +func newRoundGateNamer(p *poseidon2.Parameters, curve ecc.ID) roundGateNamer { + return roundGateNamer(fmt.Sprintf("Poseidon2-%s[t=%d,rF=%d,rP=%d,d=%d]", curve.String(), p.Width, p.NbFullRounds, p.NbPartialRounds, p.DegreeSBox)) } // linear is the name of a gate where a polynomial of total degree 1 is applied to the input diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 601d80cb70..2d9ad7ccb5 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -76,12 +76,12 @@ func TestGkrPermutationCompiles(t *testing.T) { } func BenchmarkGkrPermutations(b *testing.B) { - circuit, assignmment := gkrPermutationsCircuits(b, 50000) + circuit, assignment := gkrPermutationsCircuits(b, 50000) cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) require.NoError(b, err) - witness, err := frontend.NewWitness(&assignmment, ecc.BLS12_377.ScalarField()) + witness, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) require.NoError(b, err) // cpu profile diff --git a/std/permutation/poseidon2/poseidon2.go b/std/permutation/poseidon2/poseidon2.go index 55afe73be5..317cc977dd 100644 --- a/std/permutation/poseidon2/poseidon2.go +++ b/std/permutation/poseidon2/poseidon2.go @@ -23,38 +23,157 @@ var ( type Permutation struct { api frontend.API - params parameters + params Parameters } -// parameters describing the poseidon2 implementation -type parameters struct { +// Parameters describing the poseidon2 implementation +type Parameters struct { // len(preimage)+len(digest)=len(preimage)+ceil(log(2*/r)) - width int + Width int // sbox degree - degreeSBox int + DegreeSBox int // number of full rounds (even number) - nbFullRounds int + NbFullRounds int // number of partial rounds - nbPartialRounds int + NbPartialRounds int // round keys: ordered by round then variable - roundKeys [][]big.Int + RoundKeys [][]big.Int +} + +func GetDefaultParameters(curve ecc.ID) (Parameters, error) { + switch curve { // TODO: assumes pairing based builder, reconsider when supporting other backends + case ecc.BN254: + p := poseidonbn254.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbn254.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BLS12_381: + p := poseidonbls12381.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbls12381.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BLS12_377: + p := poseidonbls12377.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbls12377.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BW6_761: + p := poseidonbw6761.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbw6761.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BW6_633: + p := poseidonbw6633.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbw6633.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BLS24_315: + p := poseidonbls24315.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbls24315.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BLS24_317: + p := poseidonbls24317.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbls24317.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + default: + return Parameters{}, fmt.Errorf("curve %s not supported", curve) + } } // NewPoseidon2 returns a new Poseidon2 hasher with default parameters as // defined in the gnark-crypto library. func NewPoseidon2(api frontend.API) (*Permutation, error) { - switch utils.FieldToCurve(api.Compiler().Field()) { // TODO: assumes pairing based builder, reconsider when supporting other backends - case ecc.BLS12_377: - params := poseidonbls12377.GetDefaultParameters() - return NewPoseidon2FromParameters(api, 2, params.NbFullRounds, params.NbPartialRounds) - // TODO: we don't have default parameters for other curves yet. Update this when we do. - default: - return nil, fmt.Errorf("field %s not supported", api.Compiler().Field().String()) + params, err := GetDefaultParameters(utils.FieldToCurve(api.Compiler().Field())) + if err != nil { + return nil, err } + return &Permutation{ + api: api, + params: params, + }, nil } // NewPoseidon2FromParameters returns a new Poseidon2 hasher with the given parameters. @@ -62,76 +181,76 @@ func NewPoseidon2(api frontend.API) (*Permutation, error) { // is deterministic and depends on the curve ID. See the corresponding NewParameters // function in the gnark-crypto library poseidon2 packages for more details. func NewPoseidon2FromParameters(api frontend.API, width, nbFullRounds, nbPartialRounds int) (*Permutation, error) { - params := parameters{width: width, nbFullRounds: nbFullRounds, nbPartialRounds: nbPartialRounds} + params := Parameters{Width: width, NbFullRounds: nbFullRounds, NbPartialRounds: nbPartialRounds} switch utils.FieldToCurve(api.Compiler().Field()) { // TODO: assumes pairing based builder, reconsider when supporting other backends case ecc.BN254: - params.degreeSBox = poseidonbn254.DegreeSBox() + params.DegreeSBox = poseidonbn254.DegreeSBox() concreteParams := poseidonbn254.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BLS12_381: - params.degreeSBox = poseidonbls12381.DegreeSBox() + params.DegreeSBox = poseidonbls12381.DegreeSBox() concreteParams := poseidonbls12381.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BLS12_377: - params.degreeSBox = poseidonbls12377.DegreeSBox() + params.DegreeSBox = poseidonbls12377.DegreeSBox() concreteParams := poseidonbls12377.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BW6_761: - params.degreeSBox = poseidonbw6761.DegreeSBox() + params.DegreeSBox = poseidonbw6761.DegreeSBox() concreteParams := poseidonbw6761.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BW6_633: - params.degreeSBox = poseidonbw6633.DegreeSBox() + params.DegreeSBox = poseidonbw6633.DegreeSBox() concreteParams := poseidonbw6633.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BLS24_315: - params.degreeSBox = poseidonbls24315.DegreeSBox() + params.DegreeSBox = poseidonbls24315.DegreeSBox() concreteParams := poseidonbls24315.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BLS24_317: - params.degreeSBox = poseidonbls24317.DegreeSBox() + params.DegreeSBox = poseidonbls24317.DegreeSBox() concreteParams := poseidonbls24317.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } default: @@ -143,25 +262,25 @@ func NewPoseidon2FromParameters(api frontend.API, width, nbFullRounds, nbPartial // sBox applies the sBox on buffer[index] func (h *Permutation) sBox(index int, input []frontend.Variable) { tmp := input[index] - if h.params.degreeSBox == 3 { + if h.params.DegreeSBox == 3 { input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(tmp, input[index]) - } else if h.params.degreeSBox == 5 { + } else if h.params.DegreeSBox == 5 { input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], tmp) - } else if h.params.degreeSBox == 7 { + } else if h.params.DegreeSBox == 7 { input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], tmp) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], tmp) - } else if h.params.degreeSBox == 17 { + } else if h.params.DegreeSBox == 17 { input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], tmp) - } else if h.params.degreeSBox == -1 { + } else if h.params.DegreeSBox == -1 { input[index] = h.api.Inverse(input[index]) } } @@ -204,30 +323,30 @@ func (h *Permutation) matMulM4InPlace(s []frontend.Variable) { // see https://eprint.iacr.org/2023/323.pdf func (h *Permutation) matMulExternalInPlace(input []frontend.Variable) { - if h.params.width == 2 { + if h.params.Width == 2 { tmp := h.api.Add(input[0], input[1]) input[0] = h.api.Add(tmp, input[0]) input[1] = h.api.Add(tmp, input[1]) - } else if h.params.width == 3 { + } else if h.params.Width == 3 { tmp := h.api.Add(input[0], input[1]) tmp = h.api.Add(tmp, input[2]) input[0] = h.api.Add(input[0], tmp) input[1] = h.api.Add(input[1], tmp) input[2] = h.api.Add(input[2], tmp) - } else if h.params.width == 4 { + } else if h.params.Width == 4 { h.matMulM4InPlace(input) } else { // at this stage t is supposed to be a multiple of 4 // the MDS matrix is circ(2M4,M4,..,M4) h.matMulM4InPlace(input) tmp := make([]frontend.Variable, 4) - for i := 0; i < h.params.width/4; i++ { + for i := 0; i < h.params.Width/4; i++ { tmp[0] = h.api.Add(tmp[0], input[4*i]) tmp[1] = h.api.Add(tmp[1], input[4*i+1]) tmp[2] = h.api.Add(tmp[2], input[4*i+2]) tmp[3] = h.api.Add(tmp[3], input[4*i+3]) } - for i := 0; i < h.params.width/4; i++ { + for i := 0; i < h.params.Width/4; i++ { input[4*i] = h.api.Add(input[4*i], tmp[0]) input[4*i+1] = h.api.Add(input[4*i], tmp[1]) input[4*i+2] = h.api.Add(input[4*i], tmp[2]) @@ -239,12 +358,12 @@ func (h *Permutation) matMulExternalInPlace(input []frontend.Variable) { // when t=2,3 the matrix are respectively [[2,1][1,3]] and [[2,1,1][1,2,1][1,1,3]] // otherwise the matrix is filled with ones except on the diagonal, func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) { - if h.params.width == 2 { + if h.params.Width == 2 { sum := h.api.Add(input[0], input[1]) input[0] = h.api.Add(input[0], sum) input[1] = h.api.Mul(2, input[1]) input[1] = h.api.Add(input[1], sum) - } else if h.params.width == 3 { + } else if h.params.Width == 3 { sum := h.api.Add(input[0], input[1]) sum = h.api.Add(sum, input[2]) input[0] = h.api.Add(input[0], sum) @@ -259,10 +378,10 @@ func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) { // var sum frontend.Variable // sum = input[0] - // for i := 1; i < h.params.width; i++ { + // for i := 1; i < h.params.Width; i++ { // sum = api.Add(sum, input[i]) // } - // for i := 0; i < h.params.width; i++ { + // for i := 0; i < h.params.Width; i++ { // input[i] = api.Mul(input[i], h.params.diagInternalMatrices[i]) // input[i] = api.Add(input[i], sum) // } @@ -272,40 +391,40 @@ func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) { // addRoundKeyInPlace adds the round-th key to the buffer func (h *Permutation) addRoundKeyInPlace(round int, input []frontend.Variable) { - for i := 0; i < len(h.params.roundKeys[round]); i++ { - input[i] = h.api.Add(input[i], h.params.roundKeys[round][i]) + for i := 0; i < len(h.params.RoundKeys[round]); i++ { + input[i] = h.api.Add(input[i], h.params.RoundKeys[round][i]) } } // Permutation applies the permutation on input, and stores the result in input. func (h *Permutation) Permutation(input []frontend.Variable) error { - if len(input) != h.params.width { + if len(input) != h.params.Width { return ErrInvalidSizebuffer } // external matrix multiplication, cf https://eprint.iacr.org/2023/323.pdf page 14 (part 6) h.matMulExternalInPlace(input) - rf := h.params.nbFullRounds / 2 + rf := h.params.NbFullRounds / 2 for i := 0; i < rf; i++ { // one round = matMulExternal(sBox_Full(addRoundKey)) h.addRoundKeyInPlace(i, input) - for j := 0; j < h.params.width; j++ { + for j := 0; j < h.params.Width; j++ { h.sBox(j, input) } h.matMulExternalInPlace(input) } - for i := rf; i < rf+h.params.nbPartialRounds; i++ { + for i := rf; i < rf+h.params.NbPartialRounds; i++ { // one round = matMulInternal(sBox_sparse(addRoundKey)) h.addRoundKeyInPlace(i, input) h.sBox(0, input) h.matMulInternalInPlace(input) } - for i := rf + h.params.nbPartialRounds; i < h.params.nbFullRounds+h.params.nbPartialRounds; i++ { + for i := rf + h.params.NbPartialRounds; i < h.params.NbFullRounds+h.params.NbPartialRounds; i++ { // one round = matMulExternal(sBox_Full(addRoundKey)) h.addRoundKeyInPlace(i, input) - for j := 0; j < h.params.width; j++ { + for j := 0; j < h.params.Width; j++ { h.sBox(j, input) } h.matMulExternalInPlace(input) @@ -321,7 +440,7 @@ func (h *Permutation) Permutation(input []frontend.Variable) error { // Implements the [hash.Compressor] interface for building a Merkle-Damgard // hash construction. func (h *Permutation) Compress(left, right frontend.Variable) frontend.Variable { - if h.params.width != 2 { + if h.params.Width != 2 { panic("poseidon2: Compress can only be used when t=2") } vars := [2]frontend.Variable{left, right} From 5342952398b3c4aa7e45ea3edcc37b3eab85293d Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:20:46 -0500 Subject: [PATCH 05/20] fix test --- std/permutation/poseidon2/gkr-poseidon2/gkr_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 2d9ad7ccb5..9808d5f8dc 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + poseidonbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" _ "github.com/consensys/gnark/std/hash/all" @@ -19,6 +20,8 @@ func gkrPermutationsCircuits(t require.TestingT, n int) (circuit, assignment tes var k int64 ins := make([][2]frontend.Variable, n) outs := make([]frontend.Variable, n) + params := poseidonbls12377.GetDefaultParameters() + permutation := poseidonbls12377.NewPermutation(params.Width, params.NbFullRounds, params.NbPartialRounds) for i := range n { var x [2]fr.Element ins[i] = [2]frontend.Variable{k, k + 1} @@ -27,7 +30,7 @@ func gkrPermutationsCircuits(t require.TestingT, n int) (circuit, assignment tes x[1].SetInt64(k + 1) y0 := x[1] - require.NoError(t, bls12377Permutation().Permutation(x[:])) + require.NoError(t, permutation.Permutation(x[:])) x[1].Add(&x[1], &y0) outs[i] = x[1] From 9d9b13b0cc833e503fc9abb693e60959f410cd1e Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:26:10 -0500 Subject: [PATCH 06/20] refactor generify tests --- .../poseidon2/gkr-poseidon2/gkr.go | 8 +-- .../poseidon2/gkr-poseidon2/gkr_test.go | 55 ++++++++----------- 2 files changed, 27 insertions(+), 36 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index dcc51a4031..c87cd450e9 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -119,13 +119,13 @@ type GkrPermutations struct { // which consists of a permutation along with the input fed forward. // The correctness of the compression functions is proven using GKR. // Note that the solver will need the function RegisterGates to be called with the desired curves -func NewGkrPermutations(api frontend.API) *GkrPermutations { +func NewGkrPermutations(api frontend.API) (*GkrPermutations, error) { if api.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) != 0 { - panic("currently only BL12-377 is supported") + return nil, errors.New("currently only BL12-377 is supported") } gkrCircuit, in1, in2, out, err := defineCircuit(api) if err != nil { - panic(fmt.Errorf("failed to define GKR circuit: %w", err)) + return nil, fmt.Errorf("failed to define GKR circuit: %w", err) } return &GkrPermutations{ api: api, @@ -133,7 +133,7 @@ func NewGkrPermutations(api frontend.API) *GkrPermutations { in1: in1, in2: in2, out: out, - } + }, nil } func (p *GkrPermutations) Compress(a, b frontend.Variable) frontend.Variable { diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 9808d5f8dc..b510663959 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -7,62 +7,53 @@ import ( "testing" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - poseidonbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" _ "github.com/consensys/gnark/std/hash/all" + "github.com/consensys/gnark/std/permutation/poseidon2" "github.com/consensys/gnark/test" "github.com/stretchr/testify/require" ) -func gkrPermutationsCircuits(t require.TestingT, n int) (circuit, assignment testGkrPermutationCircuit) { - var k int64 +func gkrPermutationsCircuits(n int) (circuit, assignment testGkrPermutationCircuit) { ins := make([][2]frontend.Variable, n) - outs := make([]frontend.Variable, n) - params := poseidonbls12377.GetDefaultParameters() - permutation := poseidonbls12377.NewPermutation(params.Width, params.NbFullRounds, params.NbPartialRounds) for i := range n { - var x [2]fr.Element - ins[i] = [2]frontend.Variable{k, k + 1} - - x[0].SetInt64(k) - x[1].SetInt64(k + 1) - y0 := x[1] - - require.NoError(t, permutation.Permutation(x[:])) - x[1].Add(&x[1], &y0) - outs[i] = x[1] - - k += 2 + ins[i] = [2]frontend.Variable{i * 2, i*2 + 1} } return testGkrPermutationCircuit{ - Ins: make([][2]frontend.Variable, len(ins)), - Outs: make([]frontend.Variable, len(outs)), + Ins: make([][2]frontend.Variable, len(ins)), }, testGkrPermutationCircuit{ - Ins: ins, - Outs: outs, + Ins: ins, } } func TestGkrCompression(t *testing.T) { - circuit, assignment := gkrPermutationsCircuits(t, 2) + circuit, assignment := gkrPermutationsCircuits(2) test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BLS12_377)) } type testGkrPermutationCircuit struct { - Ins [][2]frontend.Variable - Outs []frontend.Variable + Ins [][2]frontend.Variable + skipCheck bool } func (c *testGkrPermutationCircuit) Define(api frontend.API) error { - pos2 := NewGkrPermutations(api) - api.AssertIsEqual(len(c.Ins), len(c.Outs)) + gkr, err := NewGkrPermutations(api) + if err != nil { + return err + } + pos2, err := poseidon2.NewPoseidon2(api) + if err != nil { + return err + } for i := range c.Ins { - api.AssertIsEqual(c.Outs[i], pos2.Compress(c.Ins[i][0], c.Ins[i][1])) + fromGkr := gkr.Compress(c.Ins[i][0], c.Ins[i][1]) + if !c.skipCheck { + api.AssertIsEqual(pos2.Compress(c.Ins[i][0], c.Ins[i][1]), fromGkr) + } } return nil @@ -71,15 +62,15 @@ func (c *testGkrPermutationCircuit) Define(api frontend.API) error { func TestGkrPermutationCompiles(t *testing.T) { // just measure the number of constraints cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &testGkrPermutationCircuit{ - Ins: make([][2]frontend.Variable, 52000), - Outs: make([]frontend.Variable, 52000), + Ins: make([][2]frontend.Variable, 52000), + skipCheck: true, }) require.NoError(t, err) fmt.Println(cs.GetNbConstraints(), "constraints") } func BenchmarkGkrPermutations(b *testing.B) { - circuit, assignment := gkrPermutationsCircuits(b, 50000) + circuit, assignment := gkrPermutationsCircuits(50000) cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) require.NoError(b, err) From 430661accfe2fc202f86c249d0de1f08d32b74df Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:36:56 -0500 Subject: [PATCH 07/20] feat: gkrposeidon2 compression for all curves --- .../poseidon2/gkr-poseidon2/gkr.go | 26 ++++++++++++++----- .../poseidon2/gkr-poseidon2/gkr_test.go | 2 +- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index c87cd450e9..fea7608ac6 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -47,6 +47,14 @@ func pow4TimesGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Mul(y, x[1]) } +// pow3Gate computes a -> a³ +func pow3Gate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + if len(x) != 1 { + panic("expected 1 input") + } + return api.Mul(x[0], x[0], x[0]) +} + // pow2Gate computes a -> a² func pow2Gate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 1 { @@ -120,9 +128,6 @@ type GkrPermutations struct { // The correctness of the compression functions is proven using GKR. // Note that the solver will need the function RegisterGates to be called with the desired curves func NewGkrPermutations(api frontend.API) (*GkrPermutations, error) { - if api.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) != 0 { - return nil, errors.New("currently only BL12-377 is supported") - } gkrCircuit, in1, in2, out, err := defineCircuit(api) if err != nil { return nil, fmt.Errorf("failed to define GKR circuit: %w", err) @@ -182,10 +187,19 @@ func defineCircuit(api frontend.API) (gkrCircuit *gkrapi.Circuit, in1, in2, out // in every round comes from the previous (canonical) round. // apply the s-Box to u - // the s-Box gates: u¹⁷ = (u⁴)⁴ * u var sBox func(gkr.Variable) gkr.Variable switch p.DegreeSBox { + case 5: + sBox = func(u gkr.Variable) gkr.Variable { + v := gkrApi.Gate(pow2Gate, u) // u² + return gkrApi.Gate(pow2TimesGate, v, u) // u⁵ + } + case 7: + sBox = func(u gkr.Variable) gkr.Variable { + v := gkrApi.Gate(pow3Gate, u) // u³ + return gkrApi.Gate(pow2TimesGate, v, u) // u⁷ + } case 17: sBox = func(u gkr.Variable) gkr.Variable { v := gkrApi.Gate(pow4Gate, u) // u⁴ @@ -282,12 +296,12 @@ func registerGates(p *poseidon2.Parameters, curve ecc.ID) error { halfRf := p.NbFullRounds / 2 extKeySBox := func(round int, varIndex int) error { - _, err := gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(ecc.BLS12_377)) + _, err := gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(curve)) return err } intKeySBox2 := func(round int) error { - _, err := gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round)), gkrgates.WithCurves(ecc.BLS12_377)) + _, err := gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round)), gkrgates.WithCurves(curve)) return err } diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index b510663959..592b9abf7c 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -31,7 +31,7 @@ func gkrPermutationsCircuits(n int) (circuit, assignment testGkrPermutationCircu func TestGkrCompression(t *testing.T) { circuit, assignment := gkrPermutationsCircuits(2) - test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BLS12_377)) + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) } type testGkrPermutationCircuit struct { From b75843dec8c8546984630e4d8372aad8fb6f94e7 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:44:18 -0500 Subject: [PATCH 08/20] feat gkr-poseidon2 hasher --- .../poseidon2/gkr-poseidon2/gkr-poseidon2.go | 18 ++++++++++++++++++ std/hash/poseidon2/poseidon2.go | 6 +++--- std/hash/poseidon2/poseidon2_test.go | 16 ++++++++++++---- 3 files changed, 33 insertions(+), 7 deletions(-) create mode 100644 std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go diff --git a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go new file mode 100644 index 0000000000..db1566bbc5 --- /dev/null +++ b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -0,0 +1,18 @@ +package gkr_poseidon2 + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" + _ "github.com/consensys/gnark/std/hash/all" + gkr_poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2" +) + +func NewGkrPoseidon2(api frontend.API) (hash.FieldHasher, error) { + f, err := gkr_poseidon2.NewGkrPermutations(api) + if err != nil { + return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) + } + return hash.NewMerkleDamgardHasher(api, f, 0), nil +} diff --git a/std/hash/poseidon2/poseidon2.go b/std/hash/poseidon2/poseidon2.go index 804740ff7c..a5b562fcb2 100644 --- a/std/hash/poseidon2/poseidon2.go +++ b/std/hash/poseidon2/poseidon2.go @@ -8,9 +8,9 @@ import ( "github.com/consensys/gnark/std/permutation/poseidon2" ) -// NewMerkleDamgardHasher returns a Poseidon2 hasher using the Merkle-Damgard +// NewPoseidon2 returns a Poseidon2 hasher using the Merkle-Damgard // construction with the default parameters. -func NewMerkleDamgardHasher(api frontend.API) (hash.FieldHasher, error) { +func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { f, err := poseidon2.NewPoseidon2(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) @@ -19,5 +19,5 @@ func NewMerkleDamgardHasher(api frontend.API) (hash.FieldHasher, error) { } func init() { - hash.Register(hash.POSEIDON2, NewMerkleDamgardHasher) + hash.Register(hash.POSEIDON2, NewPoseidon2) } diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index 1ce1d46fef..4a5374258c 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -1,11 +1,13 @@ -package poseidon2 +package poseidon2_test import ( "testing" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" + poseidonbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash/poseidon2" + gkr_poseidon2 "github.com/consensys/gnark/std/hash/poseidon2/gkr-poseidon2" "github.com/consensys/gnark/test" ) @@ -15,12 +17,18 @@ type Poseidon2Circuit struct { } func (c *Poseidon2Circuit) Define(api frontend.API) error { - hsh, err := NewMerkleDamgardHasher(api) + hsh, err := poseidon2.NewPoseidon2(api) + if err != nil { + return err + } + gkr, err := gkr_poseidon2.NewGkrPoseidon2(api) if err != nil { return err } hsh.Write(c.Input...) api.AssertIsEqual(hsh.Sum(), c.Expected) + gkr.Write(c.Input...) + api.AssertIsEqual(gkr.Sum(), c.Expected) return nil } @@ -29,7 +37,7 @@ func TestPoseidon2Hash(t *testing.T) { const nbInputs = 5 // prepare expected output - h := poseidon2.NewMerkleDamgardHasher() + h := poseidonbls12377.NewMerkleDamgardHasher() circInput := make([]frontend.Variable, nbInputs) for i := range nbInputs { _, err := h.Write([]byte{byte(i)}) From be8a07e29dc7019aca7370a9a480d7ec3d6c2ac8 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 8 Jun 2025 14:49:38 -0500 Subject: [PATCH 09/20] fix more renaming --- std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go | 4 ++-- std/hash/poseidon2/poseidon2.go | 6 +++--- std/hash/poseidon2/poseidon2_test.go | 4 ++-- std/internal/mimc/encrypt.go | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go index db1566bbc5..ffc59dee74 100644 --- a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go +++ b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -9,8 +9,8 @@ import ( gkr_poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2" ) -func NewGkrPoseidon2(api frontend.API) (hash.FieldHasher, error) { - f, err := gkr_poseidon2.NewGkrPermutations(api) +func New(api frontend.API) (hash.FieldHasher, error) { + f, err := gkr_poseidon2.NewGkrCompressor(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) } diff --git a/std/hash/poseidon2/poseidon2.go b/std/hash/poseidon2/poseidon2.go index a5b562fcb2..f53b8716f3 100644 --- a/std/hash/poseidon2/poseidon2.go +++ b/std/hash/poseidon2/poseidon2.go @@ -8,9 +8,9 @@ import ( "github.com/consensys/gnark/std/permutation/poseidon2" ) -// NewPoseidon2 returns a Poseidon2 hasher using the Merkle-Damgard +// New returns a Poseidon2 hasher using the Merkle-Damgard // construction with the default parameters. -func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { +func New(api frontend.API) (hash.FieldHasher, error) { f, err := poseidon2.NewPoseidon2(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) @@ -19,5 +19,5 @@ func NewPoseidon2(api frontend.API) (hash.FieldHasher, error) { } func init() { - hash.Register(hash.POSEIDON2, NewPoseidon2) + hash.Register(hash.POSEIDON2, New) } diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index 4a5374258c..c3998ccc5b 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -17,11 +17,11 @@ type Poseidon2Circuit struct { } func (c *Poseidon2Circuit) Define(api frontend.API) error { - hsh, err := poseidon2.NewPoseidon2(api) + hsh, err := poseidon2.New(api) if err != nil { return err } - gkr, err := gkr_poseidon2.NewGkrPoseidon2(api) + gkr, err := gkr_poseidon2.New(api) if err != nil { return err } diff --git a/std/internal/mimc/encrypt.go b/std/internal/mimc/encrypt.go index 0d45a81506..9c499be976 100644 --- a/std/internal/mimc/encrypt.go +++ b/std/internal/mimc/encrypt.go @@ -106,7 +106,7 @@ func newMimcBW633(api frontend.API) MiMC { } // ------------------------------------------------------------------------------------------------- -// encryptions functions +// encryption functions func pow5(api frontend.API, x frontend.Variable) frontend.Variable { r := api.Mul(x, x) From 1894cfafae5c97778c8eef7eb89637f0a8221049 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 9 Jun 2025 17:49:28 -0500 Subject: [PATCH 10/20] feat: gkrmimc for sbox degree 5 --- std/permutation/gkr-mimc/gkr-mimc.go | 143 ++++++++++++++++++ .../{gkr.go => gkr-poseidon2.go} | 0 .../{gkr_test.go => gkr-poseidon2_test.go} | 0 3 files changed, 143 insertions(+) create mode 100644 std/permutation/gkr-mimc/gkr-mimc.go rename std/permutation/poseidon2/gkr-poseidon2/{gkr.go => gkr-poseidon2.go} (100%) rename std/permutation/poseidon2/gkr-poseidon2/{gkr_test.go => gkr-poseidon2_test.go} (100%) diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go new file mode 100644 index 0000000000..f495f2b981 --- /dev/null +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -0,0 +1,143 @@ +package gkr_mimc + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" + bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" + bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" + bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" + bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" + bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" + "github.com/consensys/gnark/constraint/solver/gkrgates" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/std/gkrapi" + "github.com/consensys/gnark/std/gkrapi/gkr" +) + +// mimcCompressor implements a compression function by applying +// the Miyaguchi–Preneel transformation to the MiMC encryption function. +type mimcCompressor struct { + gkrCircuit *gkrapi.Circuit + in0, in1, out gkr.Variable +} + +func newGkrCompressor(api frontend.API) (*mimcCompressor, error) { + gkrApi := gkrapi.New() + + in0 := gkrApi.NewInput() + in1 := gkrApi.NewInput() + + y := in1 + + curve := utils.FieldToCurve(api.Compiler().Field()) + params, _, err := getParams(curve) // params is only used for its length + if err != nil { + return nil, err + } + if err = RegisterGates(curve); err != nil { + return nil, err + } + gateNamer := newGateNamer(curve) + + for i := range len(params) - 1 { + y = gkrApi.NamedGate(gateNamer.round(i), in0, y) + } + + y = gkrApi.NamedGate(gateNamer.round(len(params)-1), in0, y, in1) + + return &mimcCompressor{ + gkrCircuit: gkrApi.Compile(api, "poseidon2"), + in0: in0, + in1: in1, + out: y, + }, nil +} + +func RegisterGates(curves ...ecc.ID) error { + for _, curve := range curves { + constants, deg, err := getParams(curve) + if err != nil { + return err + } + gateNamer := newGateNamer(curve) + var lastLayerSBox, nonLastLayerSBox func(*big.Int) gkr.GateFunction + switch deg { + case 5: + lastLayerSBox = addPow5Add + nonLastLayerSBox = addPow5 + default: + return fmt.Errorf("s-Box of degree %d not supported", deg) + } + + for i := range len(constants) - 1 { + if _, err = gkrgates.Register(nonLastLayerSBox(&constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", i, curve, err) + } + } + + if _, err = gkrgates.Register(lastLayerSBox(&constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", len(constants)-1, curve, err) + } + } + return nil +} + +// getParams returns the parameters for the MiMC encryption function for the given curve. +// It also returns the degree of the s-Box +func getParams(curve ecc.ID) ([]big.Int, int, error) { + switch curve { + case ecc.BN254: + return bn254.GetConstants(), 5, nil + case ecc.BLS12_381: + return bls12381.GetConstants(), 5, nil + case ecc.BLS12_377: + return bls12377.GetConstants(), 17, nil + case ecc.BLS24_315: + return bls24315.GetConstants(), 5, nil + case ecc.BLS24_317: + return bls24317.GetConstants(), 7, nil + case ecc.BW6_633: + return bw6633.GetConstants(), 5, nil + case ecc.BW6_761: + return bw6761.GetConstants(), 5, nil + default: + return nil, -1, fmt.Errorf("unsupported curve ID: %s", curve) + } +} + +type gateNamer string + +func newGateNamer(o fmt.Stringer) gateNamer { + return gateNamer("MiMC-" + o.String() + "-round-") +} +func (n gateNamer) round(i int) gkr.GateName { + return gkr.GateName(fmt.Sprintf("%s%d", string(n), i)) +} + +func addPow5(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 2 { + panic("expected two input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) + return api.Mul(t, t, s) + } +} + +// addPow5Add: (in[0]+in[1]+key)⁵ + in[0] + in[2] +func addPow5Add(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 3 { + panic("expected three input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) + return api.Add(api.Mul(t, t, s), in[0], in[2]) + } +} diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go similarity index 100% rename from std/permutation/poseidon2/gkr-poseidon2/gkr.go rename to std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go similarity index 100% rename from std/permutation/poseidon2/gkr-poseidon2/gkr_test.go rename to std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go From 241d64533f491642261b5945029d4ac128ccf7da Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 9 Jun 2025 22:19:50 -0500 Subject: [PATCH 11/20] mimc length 1 works --- std/hash/mimc/gkr-mimc/gkr-mimc.go | 17 ++++++ std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 66 ++++++++++++++++++++++ std/permutation/gkr-mimc/gkr-mimc.go | 73 ++++++++++++++++++++++++- 3 files changed, 154 insertions(+), 2 deletions(-) create mode 100644 std/hash/mimc/gkr-mimc/gkr-mimc.go create mode 100644 std/hash/mimc/gkr-mimc/gkr-mimc_test.go diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc.go b/std/hash/mimc/gkr-mimc/gkr-mimc.go new file mode 100644 index 0000000000..26be877f41 --- /dev/null +++ b/std/hash/mimc/gkr-mimc/gkr-mimc.go @@ -0,0 +1,17 @@ +package gkr_mimc + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" + gkr_mimc "github.com/consensys/gnark/std/permutation/gkr-mimc" +) + +func New(api frontend.API) (hash.FieldHasher, error) { + f, err := gkr_mimc.NewCompressor(api) + if err != nil { + return nil, fmt.Errorf("could not create mimc hasher: %w", err) + } + return hash.NewMerkleDamgardHasher(api, f, 0), nil +} diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go new file mode 100644 index 0000000000..aa558dfd96 --- /dev/null +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -0,0 +1,66 @@ +package gkr_mimc + +import ( + "slices" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/consensys/gnark/test" +) + +func TestGkrMiMC(t *testing.T) { + lengths := []int{1, 2, 3} + vals := make([]frontend.Variable, len(lengths)*2) + for i := range vals { + vals[i] = i + 1 + } + + for _, length := range lengths[1:2] { + circuit := &testGkrMiMCCircuit{ + In: make([]frontend.Variable, length*2), + } + assignment := &testGkrMiMCCircuit{ + In: slices.Clone(vals[:length*2]), + } + + test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254)) + } +} + +type testGkrMiMCCircuit struct { + In []frontend.Variable + skipCheck bool +} + +func (c *testGkrMiMCCircuit) Define(api frontend.API) error { + gkrmimc, err := New(api) + if err != nil { + return err + } + + plainMiMC, err := mimc.New(api) + if err != nil { + return err + } + + // first check that empty input is handled correctly + api.AssertIsEqual(gkrmimc.Sum(), plainMiMC.Sum()) + + ins := [][]frontend.Variable{c.In[:len(c.In)/2], c.In[len(c.In)/2:]} + for _, in := range ins { + gkrmimc.Reset() + gkrmimc.Write(in...) + res := gkrmimc.Sum() + + if !c.skipCheck { + plainMiMC.Reset() + plainMiMC.Write(in...) + expected := plainMiMC.Sum() + api.AssertIsEqual(res, expected) + } + } + + return nil +} diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index f495f2b981..c1d98d67f1 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -17,6 +17,8 @@ import ( "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" + "github.com/consensys/gnark/std/hash" + _ "github.com/consensys/gnark/std/hash/all" ) // mimcCompressor implements a compression function by applying @@ -26,7 +28,15 @@ type mimcCompressor struct { in0, in1, out gkr.Variable } -func newGkrCompressor(api frontend.API) (*mimcCompressor, error) { +func (c *mimcCompressor) Compress(x frontend.Variable, y frontend.Variable) frontend.Variable { + res, err := c.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{c.in0: x, c.in1: y}) + if err != nil { + panic(err) + } + return res[c.out] +} + +func NewCompressor(api frontend.API) (hash.Compressor, error) { gkrApi := gkrapi.New() in0 := gkrApi.NewInput() @@ -51,7 +61,7 @@ func newGkrCompressor(api frontend.API) (*mimcCompressor, error) { y = gkrApi.NamedGate(gateNamer.round(len(params)-1), in0, y, in1) return &mimcCompressor{ - gkrCircuit: gkrApi.Compile(api, "poseidon2"), + gkrCircuit: gkrApi.Compile(api, "POSEIDON2"), in0: in0, in1: in1, out: y, @@ -70,6 +80,12 @@ func RegisterGates(curves ...ecc.ID) error { case 5: lastLayerSBox = addPow5Add nonLastLayerSBox = addPow5 + case 7: + lastLayerSBox = addPow7Add + nonLastLayerSBox = addPow7 + case 17: + lastLayerSBox = addPow17Add + nonLastLayerSBox = addPow17 default: return fmt.Errorf("s-Box of degree %d not supported", deg) } @@ -141,3 +157,56 @@ func addPow5Add(key *big.Int) gkr.GateFunction { return api.Add(api.Mul(t, t, s), in[0], in[2]) } } + +func addPow7(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 2 { + panic("expected two input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) + return api.Mul(t, t, t, s) // s⁶ × s + } +} + +// addPow7Add: (in[0]+in[1]+key)⁷ + in[0] + in[2] +func addPow7Add(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 3 { + panic("expected three input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) + return api.Add(api.Mul(t, t, t, s), in[0], in[2]) // s⁶ × s + in[0] + in[2] + } +} + +// addPow17: (in[0]+in[1]+key)¹⁷ +func addPow17(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 2 { + panic("expected two input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) // s² + t = api.Mul(t, t) // s⁴ + t = api.Mul(t, t) // s⁸ + t = api.Mul(t, t) // s¹⁶ + return api.Mul(t, s) // s¹⁶ × s + } +} + +// addPow17Add: (in[0]+in[1]+key)¹⁷ + in[0] + in[2] +func addPow17Add(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 3 { + panic("expected three input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) // s² + t = api.Mul(t, t) // s⁴ + t = api.Mul(t, t) // s⁸ + t = api.Mul(t, t) // s¹⁶ + return api.Add(api.Mul(t, s), in[0], in[2]) // s¹⁶ × s + in[0] + in[2] + } +} From bb3645966028a73f97e9f8a716697b8ffd47b1b6 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 10 Jun 2025 16:20:36 -0500 Subject: [PATCH 12/20] fix final layer --- internal/gkr/bn254/gkr.go | 4 ++-- internal/gkr/engine_hints.go | 2 +- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 17 +++++++++++++++-- std/permutation/gkr-mimc/gkr-mimc.go | 20 +++++++++++--------- test/engine.go | 2 +- 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 14269151b3..0174caa564 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 74b15c77ba..8c8bc1b797 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -187,7 +187,7 @@ func (g gateAPI) Println(a ...frontend.Variable) { for i := range a { if s, ok := a[i].(fmt.Stringer); ok { strings[i] = s.String() - } else { + } else if strings[i], ok = a[i].(string); !ok { bigInt := utils.FromInterface(a[i]) strings[i] = bigInt.String() } diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index aa558dfd96..559861af2d 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -1,13 +1,16 @@ package gkr_mimc import ( + "fmt" "slices" "testing" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" + "github.com/stretchr/testify/require" ) func TestGkrMiMC(t *testing.T) { @@ -25,7 +28,7 @@ func TestGkrMiMC(t *testing.T) { In: slices.Clone(vals[:length*2]), } - test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254)) + test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment)) } } @@ -58,9 +61,19 @@ func (c *testGkrMiMCCircuit) Define(api frontend.API) error { plainMiMC.Reset() plainMiMC.Write(in...) expected := plainMiMC.Sum() - api.AssertIsEqual(res, expected) + api.AssertIsEqual(expected, res) } } return nil } + +func TestGkrMiMCCompiles(t *testing.T) { + const n = 52000 + circuit := testGkrMiMCCircuit{ + In: make([]frontend.Variable, n), + } + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit, frontend.WithCapacity(27_000_000)) + require.NoError(t, err) + fmt.Println(cs.GetNbConstraints(), "constraints") +} diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index c1d98d67f1..df6517f7cf 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -146,7 +146,7 @@ func addPow5(key *big.Int) gkr.GateFunction { } } -// addPow5Add: (in[0]+in[1]+key)⁵ + in[0] + in[2] +// addPow5Add: (in[0]+in[1]+key)⁵ + 2*in[0] + in[2] func addPow5Add(key *big.Int) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { @@ -154,7 +154,9 @@ func addPow5Add(key *big.Int) gkr.GateFunction { } s := api.Add(in[0], in[1], key) t := api.Mul(s, s) - return api.Add(api.Mul(t, t, s), in[0], in[2]) + t = api.Mul(t, t, s) + + return api.Add(t, in[0], in[0], in[2]) } } @@ -169,7 +171,7 @@ func addPow7(key *big.Int) gkr.GateFunction { } } -// addPow7Add: (in[0]+in[1]+key)⁷ + in[0] + in[2] +// addPow7Add: (in[0]+in[1]+key)⁷ + 2*in[0] + in[2] func addPow7Add(key *big.Int) gkr.GateFunction { return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { if len(in) != 3 { @@ -177,7 +179,7 @@ func addPow7Add(key *big.Int) gkr.GateFunction { } s := api.Add(in[0], in[1], key) t := api.Mul(s, s) - return api.Add(api.Mul(t, t, t, s), in[0], in[2]) // s⁶ × s + in[0] + in[2] + return api.Add(api.Mul(t, t, t, s), in[0], in[0], in[2]) // s⁶ × s + 2*in[0] + in[2] } } @@ -203,10 +205,10 @@ func addPow17Add(key *big.Int) gkr.GateFunction { panic("expected three input") } s := api.Add(in[0], in[1], key) - t := api.Mul(s, s) // s² - t = api.Mul(t, t) // s⁴ - t = api.Mul(t, t) // s⁸ - t = api.Mul(t, t) // s¹⁶ - return api.Add(api.Mul(t, s), in[0], in[2]) // s¹⁶ × s + in[0] + in[2] + t := api.Mul(s, s) // s² + t = api.Mul(t, t) // s⁴ + t = api.Mul(t, t) // s⁸ + t = api.Mul(t, t) // s¹⁶ + return api.Add(api.Mul(t, s), in[0], in[0], in[2]) // s¹⁶ × s + 2*in[0] + in[2] } } diff --git a/test/engine.go b/test/engine.go index 79322af440..aaa63ac7ce 100644 --- a/test/engine.go +++ b/test/engine.go @@ -110,7 +110,7 @@ func IsSolved(circuit, witness frontend.Circuit, field *big.Int, opts ...TestEng defer func() { if r := recover(); r != nil { - err = fmt.Errorf("%v\n%s", r, string(debug.Stack())) + err = fmt.Errorf("%v\n%s", r, debug.Stack()) } }() From 311d9d9368d1fede362e4693cc87b59bebfa521c Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 10 Jun 2025 16:22:20 -0500 Subject: [PATCH 13/20] chore generify Println changes --- internal/generator/backend/template/gkr/gkr.go.tmpl | 4 ++-- internal/gkr/bls12-377/gkr.go | 4 ++-- internal/gkr/bls12-381/gkr.go | 4 ++-- internal/gkr/bls24-315/gkr.go | 4 ++-- internal/gkr/bls24-317/gkr.go | 4 ++-- internal/gkr/bw6-633/gkr.go | 4 ++-- internal/gkr/bw6-761/gkr.go | 4 ++-- internal/gkr/small_rational/gkr.go | 4 ++-- 8 files changed, 16 insertions(+), 16 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 3e3881d15f..16d5eb970b 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -771,13 +771,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index b92ac1249d..b8ef9ea973 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 82084049d9..8f72898737 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index f182c9176b..7aee277ba4 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index a284f14ae9..7c679216fc 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index ec1067f736..2c2bda2037 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index ad5197feef..099b015b02 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index cdf62359f2..d085c6305f 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) From 815adc76b383667b65c74c6f9c59b2eaa94db076 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 10 Jun 2025 17:54:31 -0500 Subject: [PATCH 14/20] feat: use kvstore for caching instances --- .../poseidon2/gkr-poseidon2/gkr-poseidon2.go | 2 +- std/permutation/gkr-mimc/gkr-mimc.go | 39 ++++++++++++++----- .../poseidon2/gkr-poseidon2/gkr-poseidon2.go | 31 ++++++++++++--- .../gkr-poseidon2/gkr-poseidon2_test.go | 2 +- 4 files changed, 57 insertions(+), 17 deletions(-) diff --git a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go index ffc59dee74..88c8baf260 100644 --- a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go +++ b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -10,7 +10,7 @@ import ( ) func New(api frontend.API) (hash.FieldHasher, error) { - f, err := gkr_poseidon2.NewGkrCompressor(api) + f, err := gkr_poseidon2.NewCompressor(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) } diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go index df6517f7cf..266ee00e67 100644 --- a/std/permutation/gkr-mimc/gkr-mimc.go +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -14,6 +14,7 @@ import ( bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" "github.com/consensys/gnark/constraint/solver/gkrgates" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" @@ -21,14 +22,14 @@ import ( _ "github.com/consensys/gnark/std/hash/all" ) -// mimcCompressor implements a compression function by applying +// compressor implements a compression function by applying // the Miyaguchi–Preneel transformation to the MiMC encryption function. -type mimcCompressor struct { +type compressor struct { gkrCircuit *gkrapi.Circuit in0, in1, out gkr.Variable } -func (c *mimcCompressor) Compress(x frontend.Variable, y frontend.Variable) frontend.Variable { +func (c *compressor) Compress(x frontend.Variable, y frontend.Variable) frontend.Variable { res, err := c.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{c.in0: x, c.in1: y}) if err != nil { panic(err) @@ -37,6 +38,20 @@ func (c *mimcCompressor) Compress(x frontend.Variable, y frontend.Variable) fron } func NewCompressor(api frontend.API) (hash.Compressor, error) { + + store, ok := api.(kvstore.Store) + if !ok { + return nil, fmt.Errorf("api of type %T does not implement kvstore.Store", api) + } + + cached := store.GetKeyValue(gkrMiMCKey{}) + if cached != nil { + if compressor, ok := cached.(*compressor); ok { + return compressor, nil + } + return nil, fmt.Errorf("cached value is of type %T, not a compressor", cached) + } + gkrApi := gkrapi.New() in0 := gkrApi.NewInput() @@ -60,12 +75,16 @@ func NewCompressor(api frontend.API) (hash.Compressor, error) { y = gkrApi.NamedGate(gateNamer.round(len(params)-1), in0, y, in1) - return &mimcCompressor{ - gkrCircuit: gkrApi.Compile(api, "POSEIDON2"), - in0: in0, - in1: in1, - out: y, - }, nil + res := + &compressor{ + gkrCircuit: gkrApi.Compile(api, "POSEIDON2"), + in0: in0, + in1: in1, + out: y, + } + + store.SetKeyValue(gkrMiMCKey{}, res) + return res, nil } func RegisterGates(curves ...ecc.ID) error { @@ -212,3 +231,5 @@ func addPow17Add(key *big.Int) gkr.GateFunction { return api.Add(api.Mul(t, s), in[0], in[0], in[2]) // s¹⁶ × s + 2*in[0] + in[2] } } + +type gkrMiMCKey struct{} diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go index 9fbd53246c..4bb0b7a6a7 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -5,9 +5,11 @@ import ( "fmt" "github.com/consensys/gnark/constraint/solver/gkrgates" + "github.com/consensys/gnark/internal/kvstore" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" + "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/permutation/poseidon2" "github.com/consensys/gnark-crypto/ecc" @@ -117,31 +119,46 @@ func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Add(api.Mul(x[0], 2), x[1], x[2]) } -type GkrCompressor struct { +type compressor struct { api frontend.API gkrCircuit *gkrapi.Circuit in1, in2, out gkr.Variable } -// NewGkrCompressor returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) +// NewCompressor returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) // which consists of a permutation along with the input fed forward. // The correctness of the compression functions is proven using GKR. // Note that the solver will need the function RegisterGates to be called with the desired curves -func NewGkrCompressor(api frontend.API) (*GkrCompressor, error) { +func NewCompressor(api frontend.API) (hash.Compressor, error) { + store, ok := api.(kvstore.Store) + if !ok { + return nil, fmt.Errorf("api of type %T does not implement kvstore.Store", api) + } + + cached := store.GetKeyValue(gkrPoseidon2Key{}) + if cached != nil { + if compressor, ok := cached.(*compressor); ok { + return compressor, nil + } + return nil, fmt.Errorf("cached value is of type %T, not a mimcCompressor", cached) + } + gkrCircuit, in1, in2, out, err := defineCircuit(api) if err != nil { return nil, fmt.Errorf("failed to define GKR circuit: %w", err) } - return &GkrCompressor{ + res := &compressor{ api: api, gkrCircuit: gkrCircuit, in1: in1, in2: in2, out: out, - }, nil + } + store.SetKeyValue(gkrPoseidon2Key{}, res) + return res, nil } -func (p *GkrCompressor) Compress(a, b frontend.Variable) frontend.Variable { +func (p *compressor) Compress(a, b frontend.Variable) frontend.Variable { outs, err := p.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{p.in1: a, p.in2: b}) if err != nil { panic(err) @@ -365,3 +382,5 @@ func (n roundGateNamer) linear(varIndex, round int) gkr.GateName { func (n roundGateNamer) integrated(varIndex, round int) gkr.GateName { return gkr.GateName(fmt.Sprintf("x%d-i-op-round=%d;%s", varIndex, round, n)) } + +type gkrPoseidon2Key struct{} diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go index c5214364b8..7c21281d66 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go @@ -41,7 +41,7 @@ type testGkrPermutationCircuit struct { func (c *testGkrPermutationCircuit) Define(api frontend.API) error { - gkr, err := NewGkrCompressor(api) + gkr, err := NewCompressor(api) if err != nil { return err } From 7ab8702d5c8ae79c579f6bca65b45fc93ba890be Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 11 Jun 2025 10:34:50 -0500 Subject: [PATCH 15/20] feat: merkledamgard hasher as statestorer --- std/hash/hash.go | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/std/hash/hash.go b/std/hash/hash.go index c077fd0d37..564f122a48 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -5,6 +5,8 @@ package hash import ( + "fmt" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" ) @@ -110,7 +112,7 @@ type merkleDamgardHasher struct { // NewMerkleDamgardHasher transforms a 2-1 one-way function into a hash // initialState is a value whose preimage is not known -func NewMerkleDamgardHasher(api frontend.API, f Compressor, initialState frontend.Variable) FieldHasher { +func NewMerkleDamgardHasher(api frontend.API, f Compressor, initialState frontend.Variable) StateStorer { return &merkleDamgardHasher{ state: initialState, iv: initialState, @@ -132,3 +134,18 @@ func (h *merkleDamgardHasher) Write(data ...frontend.Variable) { func (h *merkleDamgardHasher) Sum() frontend.Variable { return h.state } + +func (h *merkleDamgardHasher) State() []frontend.Variable { + return []frontend.Variable{h.state} +} + +func (h *merkleDamgardHasher) SetState(state []frontend.Variable) error { + if h.state != h.iv { + return fmt.Errorf("the hasher is not in an initial state; reset before attempting to set the state") + } + if len(state) != 1 { + return fmt.Errorf("expected one state variable, got %d", len(state)) + } + h.state = state[0] + return nil +} From 2fb7daa16a5dcff8806bc4a569621d734bccf321 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 11 Jun 2025 10:51:26 -0500 Subject: [PATCH 16/20] test: SetState --- std/hash/mimc/gkr-mimc/gkr-mimc.go | 2 +- .../poseidon2/gkr-poseidon2/gkr-poseidon2.go | 2 +- std/hash/poseidon2/poseidon2.go | 6 +- std/hash/poseidon2/poseidon2_test.go | 60 ++++++++++++++++++- 4 files changed, 63 insertions(+), 7 deletions(-) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc.go b/std/hash/mimc/gkr-mimc/gkr-mimc.go index 26be877f41..8e6a8766d8 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc.go @@ -8,7 +8,7 @@ import ( gkr_mimc "github.com/consensys/gnark/std/permutation/gkr-mimc" ) -func New(api frontend.API) (hash.FieldHasher, error) { +func New(api frontend.API) (hash.StateStorer, error) { f, err := gkr_mimc.NewCompressor(api) if err != nil { return nil, fmt.Errorf("could not create mimc hasher: %w", err) diff --git a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go index 88c8baf260..bbbef1f87c 100644 --- a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go +++ b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -9,7 +9,7 @@ import ( gkr_poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2" ) -func New(api frontend.API) (hash.FieldHasher, error) { +func New(api frontend.API) (hash.StateStorer, error) { f, err := gkr_poseidon2.NewCompressor(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) diff --git a/std/hash/poseidon2/poseidon2.go b/std/hash/poseidon2/poseidon2.go index f53b8716f3..e15c4ca587 100644 --- a/std/hash/poseidon2/poseidon2.go +++ b/std/hash/poseidon2/poseidon2.go @@ -10,7 +10,7 @@ import ( // New returns a Poseidon2 hasher using the Merkle-Damgard // construction with the default parameters. -func New(api frontend.API) (hash.FieldHasher, error) { +func New(api frontend.API) (hash.StateStorer, error) { f, err := poseidon2.NewPoseidon2(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) @@ -19,5 +19,7 @@ func New(api frontend.API) (hash.FieldHasher, error) { } func init() { - hash.Register(hash.POSEIDON2, New) + hash.Register(hash.POSEIDON2, func(api frontend.API) (hash.FieldHasher, error) { + return New(api) + }) } diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index c3998ccc5b..f6c57736df 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -11,12 +11,12 @@ import ( "github.com/consensys/gnark/test" ) -type Poseidon2Circuit struct { +type poseidon2Circuit struct { Input []frontend.Variable Expected frontend.Variable `gnark:",public"` } -func (c *Poseidon2Circuit) Define(api frontend.API) error { +func (c *poseidon2Circuit) Define(api frontend.API) error { hsh, err := poseidon2.New(api) if err != nil { return err @@ -45,5 +45,59 @@ func TestPoseidon2Hash(t *testing.T) { circInput[i] = i } res := h.Sum(nil) - assert.CheckCircuit(&Poseidon2Circuit{Input: make([]frontend.Variable, nbInputs)}, test.WithValidAssignment(&Poseidon2Circuit{Input: circInput, Expected: res}), test.WithCurves(ecc.BLS12_377)) // we have parametrized currently only for BLS12-377 + assert.CheckCircuit(&poseidon2Circuit{Input: make([]frontend.Variable, nbInputs)}, test.WithValidAssignment(&poseidon2Circuit{Input: circInput, Expected: res}), test.WithCurves(ecc.BLS12_377)) // we have parametrized currently only for BLS12-377 +} + +func TestStateStorer(t *testing.T) { + assignment := testStateStorerCircuit{ + Input: [][]frontend.Variable{ + {0, 1, 2, 3, 4}, + }, + } + + circuit := testStateStorerCircuit{ + Input: make([][]frontend.Variable, len(assignment.Input)), + } + for i := range assignment.Input { + circuit.Input[i] = make([]frontend.Variable, len(assignment.Input[i])) + } + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} + +type testStateStorerCircuit struct { + Input [][]frontend.Variable +} + +func (c *testStateStorerCircuit) Define(api frontend.API) error { + // hashes the whole input in one go + hshFull, err := poseidon2.New(api) + if err != nil { + return err + } + + // hashes the input in two parts + hshPartial, err := poseidon2.New(api) + if err != nil { + return err + } + + for _, input := range c.Input { + // compute desired output + hshFull.Reset() + hshFull.Write(input...) + digest := hshFull.Sum() + + hshPartial.Reset() + hshPartial.Write(input[:len(input)/2]...) + state := hshPartial.State() + hshPartial.Reset() + api.AssertIsEqual(hshPartial.State()[0], 0) + if err = hshPartial.SetState(state); err != nil { + return err + } + hshPartial.Write(input[len(input)/2:]...) + api.AssertIsEqual(hshPartial.Sum(), digest) + } + return nil } From de55a270eb9c81ca37cfefddb0ca52eca2bbe4ce Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 11 Jun 2025 11:01:41 -0500 Subject: [PATCH 17/20] feat: mimc.New to return StateStorer --- std/hash/mimc/mimc.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/std/hash/mimc/mimc.go b/std/hash/mimc/mimc.go index 9d8a98e306..db72f54429 100644 --- a/std/hash/mimc/mimc.go +++ b/std/hash/mimc/mimc.go @@ -34,7 +34,7 @@ func NewMiMC(api frontend.API) (MiMC, error) { // NB! See the package documentation for length extension attack consideration. // // [gnark-crypto]: https://pkg.go.dev/github.com/consensys/gnark-crypto/hash -func New(api frontend.API) (hash.FieldHasher, error) { +func New(api frontend.API) (hash.StateStorer, error) { h, err := NewMiMC(api) if err != nil { return nil, err @@ -43,5 +43,7 @@ func New(api frontend.API) (hash.FieldHasher, error) { } func init() { - hash.Register(hash.MIMC, New) + hash.Register(hash.MIMC, func(api frontend.API) (hash.FieldHasher, error) { + return New(api) + }) } From 8c37435765a54be6acdb4048a78f0b1c388955ba Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 17 Sep 2025 13:27:37 -0500 Subject: [PATCH 18/20] fix: bad merge --- .../poseidon2/gkr-poseidon2/gkr-poseidon2_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go index f3af71983f..b224bf1414 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go @@ -76,11 +76,6 @@ func BenchmarkGkrCompressions(b *testing.B) { witness, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) require.NoError(b, err) - // cpu profile - defer func() { - require.NoError(b, f.Close()) - }() - _, err = cs.Solve(witness) require.NoError(b, err) } From e6a64bc46455e0925013574a3d1109fdfa66377a Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 17 Sep 2025 19:15:13 +0000 Subject: [PATCH 19/20] bench: gkr mimc permutations --- std/hash/mimc/gkr-mimc/gkr-mimc_test.go | 86 +++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go index 559861af2d..8a9381bfd7 100644 --- a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -1,11 +1,15 @@ package gkr_mimc import ( + "errors" "fmt" + "os" "slices" "testing" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/hash/mimc" @@ -77,3 +81,85 @@ func TestGkrMiMCCompiles(t *testing.T) { require.NoError(t, err) fmt.Println(cs.GetNbConstraints(), "constraints") } + +type hashTreeCircuit struct { + Leaves []frontend.Variable +} + +func (c hashTreeCircuit) Define(api frontend.API) error { + if len(c.Leaves) == 0 { + return errors.New("no hashing to do") + } + + hsh, err := New(api) + if err != nil { + return err + } + + layer := slices.Clone(c.Leaves) + + for len(layer) > 1 { + if len(layer)%2 == 1 { + layer = append(layer, 0) // pad with zero + } + + for i := range len(layer) / 2 { + hsh.Reset() + hsh.Write(layer[2*i], layer[2*i+1]) + layer[i] = hsh.Sum() + } + + layer = layer[:len(layer)/2] + } + + api.AssertIsDifferent(layer[0], 0) + return nil +} + +func loadCs(t require.TestingT, filename string, circuit frontend.Circuit) constraint.ConstraintSystem { + f, err := os.Open(filename) + + if os.IsNotExist(err) { + // actually compile + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, circuit) + require.NoError(t, err) + f, err = os.Create(filename) + require.NoError(t, err) + defer f.Close() + _, err = cs.WriteTo(f) + require.NoError(t, err) + return cs + } + + defer f.Close() + require.NoError(t, err) + + cs := plonk.NewCS(ecc.BLS12_377) + + _, err = cs.ReadFrom(f) + require.NoError(t, err) + + return cs +} + +func BenchmarkHashTree(b *testing.B) { + const size = 1 << 15 // about 2 ^ 16 total hashes + + circuit := hashTreeCircuit{ + Leaves: make([]frontend.Variable, size), + } + assignment := hashTreeCircuit{ + Leaves: make([]frontend.Variable, size), + } + + for i := range assignment.Leaves { + assignment.Leaves[i] = i + } + + cs := loadCs(b, "gkrmimc_hashtree.cs", &circuit) + + w, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) + require.NoError(b, err) + + require.NoError(b, cs.IsSolved(w)) +} From d55dbf4238c216b39f3ca19a38c96cce684c6c2f Mon Sep 17 00:00:00 2001 From: Arya Tabaie Date: Wed, 17 Sep 2025 19:18:43 +0000 Subject: [PATCH 20/20] bench: gkr-mimc permutations --- std/permutation/gkr-mimc/gkr-mimc_test.go | 70 +++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 std/permutation/gkr-mimc/gkr-mimc_test.go diff --git a/std/permutation/gkr-mimc/gkr-mimc_test.go b/std/permutation/gkr-mimc/gkr-mimc_test.go new file mode 100644 index 0000000000..93143b1279 --- /dev/null +++ b/std/permutation/gkr-mimc/gkr-mimc_test.go @@ -0,0 +1,70 @@ +package gkr_mimc + +import ( + "errors" + "slices" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/stretchr/testify/require" +) + +type hashTreeCircuit struct { + Leaves []frontend.Variable +} + +func (c hashTreeCircuit) Define(api frontend.API) error { + if len(c.Leaves) == 0 { + return errors.New("no hashing to do") + } + + hsh, err := NewCompressor(api) + if err != nil { + return err + } + + layer := slices.Clone(c.Leaves) + + for len(layer) > 1 { + if len(layer)%2 == 1 { + layer = append(layer, 0) // pad with zero + } + + for i := range len(layer) / 2 { + layer[i] = hsh.Compress(layer[2*i], layer[2*i+1]) + } + + layer = layer[:len(layer)/2] + } + + api.AssertIsDifferent(layer[0], 0) + return nil +} + +func BenchmarkGkrPermutations(b *testing.B) { + circuit, assignment := hashTreeCircuits(50000) + + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) + require.NoError(b, err) + + witness, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) + require.NoError(b, err) + + _, err = cs.Solve(witness) + require.NoError(b, err) +} + +func hashTreeCircuits(n int) (circuit, assignment hashTreeCircuit) { + leaves := make([]frontend.Variable, n) + for i := range n { + leaves[i] = i + } + + return hashTreeCircuit{ + Leaves: make([]frontend.Variable, len(leaves)), + }, hashTreeCircuit{ + Leaves: leaves, + } +}