diff --git a/constraint/bls12-377/solver.go b/constraint/bls12-377/solver.go index 6ead45ac84..ac7bd11252 100644 --- a/constraint/bls12-377/solver.go +++ b/constraint/bls12-377/solver.go @@ -21,6 +21,7 @@ import ( csolver "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/constraint/solver/gkrgates" gkr "github.com/consensys/gnark/internal/gkr/bls12-377" + "github.com/consensys/gnark/internal/gkr/gkrhints" "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/rs/zerolog" @@ -51,14 +52,16 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) + var gkrHints *gkrhints.TestEngineHints opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(csolver.GetHintID(gkrHints.GetAssignment), gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Solve), gkr.SolveHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Prove), gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bls12-381/solver.go b/constraint/bls12-381/solver.go index 2f84c1c095..0aa3655def 100644 --- a/constraint/bls12-381/solver.go +++ b/constraint/bls12-381/solver.go @@ -21,6 +21,7 @@ import ( csolver "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/constraint/solver/gkrgates" gkr "github.com/consensys/gnark/internal/gkr/bls12-381" + "github.com/consensys/gnark/internal/gkr/gkrhints" "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/rs/zerolog" @@ -51,14 +52,16 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) + var gkrHints *gkrhints.TestEngineHints opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(csolver.GetHintID(gkrHints.GetAssignment), gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Solve), gkr.SolveHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Prove), gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bls24-315/solver.go b/constraint/bls24-315/solver.go index a6d5eeda62..063bef050a 100644 --- a/constraint/bls24-315/solver.go +++ b/constraint/bls24-315/solver.go @@ -21,6 +21,7 @@ import ( csolver "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/constraint/solver/gkrgates" gkr "github.com/consensys/gnark/internal/gkr/bls24-315" + "github.com/consensys/gnark/internal/gkr/gkrhints" "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/rs/zerolog" @@ -51,14 +52,16 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) + var gkrHints *gkrhints.TestEngineHints opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(csolver.GetHintID(gkrHints.GetAssignment), gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Solve), gkr.SolveHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Prove), gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bls24-317/solver.go b/constraint/bls24-317/solver.go index ddee44f3c5..865aefd962 100644 --- a/constraint/bls24-317/solver.go +++ b/constraint/bls24-317/solver.go @@ -21,6 +21,7 @@ import ( csolver "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/constraint/solver/gkrgates" gkr "github.com/consensys/gnark/internal/gkr/bls24-317" + "github.com/consensys/gnark/internal/gkr/gkrhints" "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/rs/zerolog" @@ -51,14 +52,16 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) + var gkrHints *gkrhints.TestEngineHints opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(csolver.GetHintID(gkrHints.GetAssignment), gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Solve), gkr.SolveHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Prove), gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bn254/solver.go b/constraint/bn254/solver.go index fb8c30b93f..9380f8e17e 100644 --- a/constraint/bn254/solver.go +++ b/constraint/bn254/solver.go @@ -21,6 +21,7 @@ import ( csolver "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/constraint/solver/gkrgates" gkr "github.com/consensys/gnark/internal/gkr/bn254" + "github.com/consensys/gnark/internal/gkr/gkrhints" "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/rs/zerolog" @@ -51,14 +52,16 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) + var gkrHints *gkrhints.TestEngineHints opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(csolver.GetHintID(gkrHints.GetAssignment), gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Solve), gkr.SolveHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Prove), gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bw6-633/solver.go b/constraint/bw6-633/solver.go index a0c4ce0d43..bc23cd2356 100644 --- a/constraint/bw6-633/solver.go +++ b/constraint/bw6-633/solver.go @@ -21,6 +21,7 @@ import ( csolver "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/constraint/solver/gkrgates" gkr "github.com/consensys/gnark/internal/gkr/bw6-633" + "github.com/consensys/gnark/internal/gkr/gkrhints" "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/rs/zerolog" @@ -51,14 +52,16 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) + var gkrHints *gkrhints.TestEngineHints opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(csolver.GetHintID(gkrHints.GetAssignment), gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Solve), gkr.SolveHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Prove), gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/bw6-761/solver.go b/constraint/bw6-761/solver.go index 5eb3ea861e..6bf6d30f30 100644 --- a/constraint/bw6-761/solver.go +++ b/constraint/bw6-761/solver.go @@ -21,6 +21,7 @@ import ( csolver "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/constraint/solver/gkrgates" gkr "github.com/consensys/gnark/internal/gkr/bw6-761" + "github.com/consensys/gnark/internal/gkr/gkrhints" "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/rs/zerolog" @@ -51,14 +52,16 @@ type solver struct { func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, error) { // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) + var gkrHints *gkrhints.TestEngineHints opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(csolver.GetHintID(gkrHints.GetAssignment), gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Solve), gkr.SolveHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Prove), gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } // parse options opt, err := csolver.NewConfig(opts...) diff --git a/constraint/solver/gkrgates/registry.go b/constraint/solver/gkrgates/registry.go index 49610a1789..e5d13bbfe9 100644 --- a/constraint/solver/gkrgates/registry.go +++ b/constraint/solver/gkrgates/registry.go @@ -2,11 +2,13 @@ package gkrgates import ( + "errors" "fmt" "reflect" "runtime" "sync" + "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" bls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" @@ -35,53 +37,89 @@ type registerSettings struct { curves []ecc.ID } -type registerOption func(*registerSettings) +type RegisterOption func(*registerSettings) error // WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. // RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) registerOption { - return func(settings *registerSettings) { +func WithSolvableVar(solvableVar int) RegisterOption { + return func(settings *registerSettings) error { + if settings.solvableVar != -1 { + return fmt.Errorf("solvable variable already set to %d", settings.solvableVar) + } + if settings.noSolvableVarVerification { + return errors.New("solvable variable already set to NONE") + } settings.solvableVar = solvableVar + return nil } } // WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. // RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) registerOption { - return func(settings *registerSettings) { +func WithUnverifiedSolvableVar(solvableVar int) RegisterOption { + return func(settings *registerSettings) error { + if settings.solvableVar != -1 { + return fmt.Errorf("solvable variable already set to %d", settings.solvableVar) + } + if settings.noSolvableVarVerification { + return errors.New("solvable variable already set to NONE") + } settings.noSolvableVarVerification = true settings.solvableVar = solvableVar + return nil } } // WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. // RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() registerOption { - return func(settings *registerSettings) { +func WithNoSolvableVar() RegisterOption { + return func(settings *registerSettings) error { + if settings.solvableVar != -1 { + return fmt.Errorf("solvable variable already set to %d", settings.solvableVar) + } + if settings.noSolvableVarVerification { + return errors.New("solvable variable already set to NONE") + } settings.solvableVar = -1 settings.noSolvableVarVerification = true + return nil } } // WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) registerOption { - return func(settings *registerSettings) { +func WithUnverifiedDegree(degree int) RegisterOption { + return func(settings *registerSettings) error { + if settings.degree != -1 { + return fmt.Errorf("gate degree already set to %d", settings.degree) + } settings.noDegreeVerification = true settings.degree = degree + return nil } } // WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) registerOption { - return func(settings *registerSettings) { +func WithDegree(degree int) RegisterOption { + return func(settings *registerSettings) error { + if settings.degree != -1 { + return fmt.Errorf("gate degree already set to %d", settings.degree) + } settings.degree = degree + return nil } } // WithName can be used to set a human-readable name for the gate. -func WithName(name gkr.GateName) registerOption { - return func(settings *registerSettings) { +func WithName(name gkr.GateName) RegisterOption { + return func(settings *registerSettings) error { + if name == "" { + return errors.New("gate name must not be empty") + } + if settings.name != "" { + return fmt.Errorf("gate name already set to \"%s\"", settings.name) + } settings.name = name + return nil } } @@ -89,9 +127,13 @@ func WithName(name gkr.GateName) registerOption { // The default is to validate on BN254. // This works for most gates, unless the leading coefficient is divided by // the curve's order, in which case the degree will be computed incorrectly. -func WithCurves(curves ...ecc.ID) registerOption { - return func(settings *registerSettings) { +func WithCurves(curves ...ecc.ID) RegisterOption { + return func(settings *registerSettings) error { + if settings.curves != nil { + return errors.New("gate curves already set") + } settings.curves = curves + return nil } } @@ -99,14 +141,51 @@ func WithCurves(curves ...ecc.ID) registerOption { // - name is a human-readable name for the gate. // - f is the polynomial function defining the gate. // - nbIn is the number of inputs to the gate. -func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { - s := registerSettings{degree: -1, solvableVar: -1, name: GetDefaultGateName(f), curves: []ecc.ID{ecc.BN254}} +// +// If the gate is already registered, it will return false and no error. +func Register(f gkr.GateFunction, nbIn int, options ...RegisterOption) error { + s := registerSettings{degree: -1, solvableVar: -1} for _, option := range options { - option(&s) + if err := option(&s); err != nil { + return err + } + } + if s.name == "" { + s.name = GetDefaultGateName(f) + } + + curvesForTesting := s.curves + allowedCurves := s.curves + if len(curvesForTesting) == 0 { + // no restriction on curves, but only test on BN254 + curvesForTesting = []ecc.ID{ecc.BN254} + allowedCurves = gnark.Curves() + } + + gatesLock.Lock() + defer gatesLock.Unlock() + + if g, ok := gates[s.name]; ok { + // gate already registered + if g.NbIn() != nbIn { + return fmt.Errorf("gate \"%s\" already registered with a different number of inputs (%d != %d)", s.name, g.NbIn(), nbIn) + } + + for _, curve := range curvesForTesting { + gateVer, err := newGateVerifier(curve) + if err != nil { + return err + } + if !gateVer.equal(f, g.Evaluate, nbIn) { + return fmt.Errorf("mismatch with already registered gate \"%s\" (degree %d) over curve %s", s.name, g.Degree(), curve) + } + } + + return nil // gate already registered } - for _, curve := range s.curves { - gateVer, err := NewGateVerifier(curve) + for _, curve := range curvesForTesting { + gateVer, err := newGateVerifier(curve) if err != nil { return err } @@ -116,14 +195,13 @@ func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { panic("invalid settings") } const maxAutoDegreeBound = 32 - var err error if s.degree, err = gateVer.findDegree(f, maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", s.name, err) + return fmt.Errorf("for gate \"%s\": %v", s.name, err) } } else { if !s.noDegreeVerification { // check that the given degree is correct - if err = gateVer.verifyDegree(f, s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", s.name, err) + if err = gateVer.verifyGateFunctionDegree(f, s.degree, nbIn); err != nil { + return fmt.Errorf("for gate \"%s\": %v", s.name, err) } } } @@ -135,18 +213,19 @@ func Register(f gkr.GateFunction, nbIn int, options ...registerOption) error { } else { // solvable variable given if !s.noSolvableVarVerification && !gateVer.isVarSolvable(f, s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, s.name) + return fmt.Errorf("cannot verify the solvability of variable %d in gate \"%s\"", s.solvableVar, s.name) } } } - gatesLock.Lock() - defer gatesLock.Unlock() - gates[s.name] = gkrtypes.NewGate(f, nbIn, s.degree, s.solvableVar) + gates[s.name] = gkrtypes.NewGate(f, nbIn, s.degree, s.solvableVar, allowedCurves) return nil } +// Get returns a registered gate of the given name. +// If not found, it will panic. +// Gates can be added to the registry through Register. func Get(name gkr.GateName) *gkrtypes.Gate { gatesLock.Lock() defer gatesLock.Unlock() @@ -156,13 +235,16 @@ func Get(name gkr.GateName) *gkrtypes.Gate { panic(fmt.Sprintf("gate \"%s\" not found", name)) } +// gateVerifier handles finding/verifying of gate degrees . +// Some of the work is done on a per-curve basis. type gateVerifier struct { - isAdditive func(f gkr.GateFunction, i int, nbIn int) bool - findDegree func(f gkr.GateFunction, max, nbIn int) (int, error) - verifyDegree func(f gkr.GateFunction, claimedDegree, nbIn int) error + isAdditive func(f gkr.GateFunction, i int, nbIn int) bool + findDegree func(f gkr.GateFunction, max, nbIn int) (int, error) + verifyGateFunctionDegree func(f gkr.GateFunction, claimedDegree, nbIn int) error + equal func(f1, f2 gkr.GateFunction, nbIn int) bool } -func NewGateVerifier(curve ecc.ID) (*gateVerifier, error) { +func newGateVerifier(curve ecc.ID) (*gateVerifier, error) { var ( o gateVerifier err error @@ -171,31 +253,38 @@ func NewGateVerifier(curve ecc.ID) (*gateVerifier, error) { case ecc.BLS12_377: o.isAdditive = bls12377.IsGateFunctionAdditive o.findDegree = bls12377.FindGateFunctionDegree - o.verifyDegree = bls12377.VerifyGateFunctionDegree + o.verifyGateFunctionDegree = bls12377.VerifyGateFunctionDegree + o.equal = bls12377.EqualGateFunction case ecc.BLS12_381: o.isAdditive = bls12381.IsGateFunctionAdditive o.findDegree = bls12381.FindGateFunctionDegree - o.verifyDegree = bls12381.VerifyGateFunctionDegree + o.verifyGateFunctionDegree = bls12381.VerifyGateFunctionDegree + o.equal = bls12381.EqualGateFunction case ecc.BLS24_315: o.isAdditive = bls24315.IsGateFunctionAdditive o.findDegree = bls24315.FindGateFunctionDegree - o.verifyDegree = bls24315.VerifyGateFunctionDegree + o.verifyGateFunctionDegree = bls24315.VerifyGateFunctionDegree + o.equal = bls24315.EqualGateFunction case ecc.BLS24_317: o.isAdditive = bls24317.IsGateFunctionAdditive o.findDegree = bls24317.FindGateFunctionDegree - o.verifyDegree = bls24317.VerifyGateFunctionDegree + o.verifyGateFunctionDegree = bls24317.VerifyGateFunctionDegree + o.equal = bls24317.EqualGateFunction case ecc.BN254: o.isAdditive = bn254.IsGateFunctionAdditive o.findDegree = bn254.FindGateFunctionDegree - o.verifyDegree = bn254.VerifyGateFunctionDegree + o.verifyGateFunctionDegree = bn254.VerifyGateFunctionDegree + o.equal = bn254.EqualGateFunction case ecc.BW6_633: o.isAdditive = bw6633.IsGateFunctionAdditive o.findDegree = bw6633.FindGateFunctionDegree - o.verifyDegree = bw6633.VerifyGateFunctionDegree + o.verifyGateFunctionDegree = bw6633.VerifyGateFunctionDegree + o.equal = bw6633.EqualGateFunction case ecc.BW6_761: o.isAdditive = bw6761.IsGateFunctionAdditive o.findDegree = bw6761.FindGateFunctionDegree - o.verifyDegree = bw6761.VerifyGateFunctionDegree + o.verifyGateFunctionDegree = bw6761.VerifyGateFunctionDegree + o.equal = bw6761.EqualGateFunction default: err = fmt.Errorf("unsupported curve %s", curve) } @@ -209,7 +298,7 @@ func GetDefaultGateName(fn gkr.GateFunction) gkr.GateName { return gkr.GateName(runtime.FuncForPC(fnptr).Name()) } -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// findSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. // It returns -1 if it fails to find one. // nbIn is the number of inputs to the gate func (v *gateVerifier) findSolvableVar(f gkr.GateFunction, nbIn int) int { @@ -221,15 +310,17 @@ func (v *gateVerifier) findSolvableVar(f gkr.GateFunction, nbIn int) int { return -1 } -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// isVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. // It returns false if it fails to verify this claim. // nbIn is the number of inputs to the gate. func (v *gateVerifier) isVarSolvable(f gkr.GateFunction, claimedSolvableVar, nbIn int) bool { return v.isAdditive(f, claimedSolvableVar, nbIn) } -func (v *gateVerifier) VerifyDegree(g *gkrtypes.Gate) error { - if err := v.verifyDegree(g.Evaluate, g.Degree(), g.NbIn()); err != nil { +// verifyDegree checks that the declared total degree of the gate polynomial +// is correct. +func (v *gateVerifier) verifyDegree(g *gkrtypes.Gate) error { + if err := v.verifyGateFunctionDegree(g.Evaluate, g.Degree(), g.NbIn()); err != nil { deg, errFind := v.findDegree(g.Evaluate, g.Degree(), g.NbIn()) if errFind != nil { return fmt.Errorf("could not find gate degree: %w\n\tdegree verification error: %w", errFind, errFind) @@ -239,7 +330,9 @@ func (v *gateVerifier) VerifyDegree(g *gkrtypes.Gate) error { return nil } -func (v *gateVerifier) VerifySolvability(g *gkrtypes.Gate) error { +// verifySolvability checks that the variable declared as "solvable" +// in fact occurs with degree exactly 1. +func (v *gateVerifier) verifySolvability(g *gkrtypes.Gate) error { if g.SolvableVar() == -1 { return nil } diff --git a/constraint/solver/gkrgates/registry_test.go b/constraint/solver/gkrgates/registry_test.go index ec41888ef3..fbca20ffba 100644 --- a/constraint/solver/gkrgates/registry_test.go +++ b/constraint/solver/gkrgates/registry_test.go @@ -1,7 +1,6 @@ package gkrgates import ( - "fmt" "testing" "github.com/consensys/gnark/frontend" @@ -11,20 +10,37 @@ import ( "github.com/stretchr/testify/assert" ) -func TestRegisterDegreeDetection(t *testing.T) { +func TestRegister(t *testing.T) { testGate := func(name gkr.GateName, f gkr.GateFunction, nbIn, degree int) { t.Run(string(name), func(t *testing.T) { name = name + "-register-gate-test" - assert.NoError(t, Register(f, nbIn, WithDegree(degree), WithName(name)), "given degree must be accepted") + assert.NoError(t, + Register(f, nbIn, WithDegree(degree), WithName(name+"_given")), + "given degree must be accepted", + ) - assert.Error(t, Register(f, nbIn, WithDegree(degree-1), WithName(name)), "lower degree must be rejected") + assert.Error(t, + Register(f, nbIn, WithDegree(degree-1), WithName(name+"_lower")), + "error must be returned for lower degree", + ) - assert.Error(t, Register(f, nbIn, WithDegree(degree+1), WithName(name)), "higher degree must be rejected") + assert.Error(t, + Register(f, nbIn, WithDegree(degree+1), WithName(name+"_higher")), + "error must be returned for higher degree", + ) - assert.NoError(t, Register(f, nbIn), "no degree must be accepted") + assert.NoError(t, + Register(f, nbIn, WithName(name+"_no_degree")), + "no error must be returned when no degree is specified", + ) - assert.Equal(t, degree, Get(name).Degree(), "degree must be detected correctly") + assert.Equal(t, degree, Get(name+"_no_degree").Degree(), "degree must be detected correctly") + + err := Register(func(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + return api.Add(f(api, x...), 1) + }, nbIn, WithDegree(degree), WithName(name+"_given")) + assert.Error(t, err, "registering another function under the same name must fail") }) } @@ -47,15 +63,22 @@ func TestRegisterDegreeDetection(t *testing.T) { ) }, 2, 1) - // zero polynomial must not be accepted t.Run("zero", func(t *testing.T) { const gateName gkr.GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, gkrtypes.ErrZeroFunction) zeroGate := func(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Sub(x[0], x[0]) } - assert.Equal(t, expectedError, Register(zeroGate, 1, WithName(gateName))) - assert.Equal(t, expectedError, Register(zeroGate, 1, WithName(gateName), WithDegree(2))) + // Attempt to register the zero gate without specifying a degree + assert.Error(t, + Register(zeroGate, 1, WithName(gateName)), + "error must be returned for zero polynomial", + ) + + // Attempt to register the zero gate with a specified degree + assert.Error(t, + Register(zeroGate, 1, WithName(gateName), WithDegree(2)), + "error must be returned for zero polynomial with degree", + ) }) } diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index d07f64a827..7305d7e00d 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -299,7 +299,6 @@ func generateGkrBackend(cfg gkrConfig) error { // gkr backend entries := []bavard.Entry{ {File: filepath.Join(packageDir, "gkr.go"), Templates: []string{"gkr.go.tmpl"}}, - {File: filepath.Join(packageDir, "gate_testing.go"), Templates: []string{"gate_testing.go.tmpl"}}, {File: filepath.Join(packageDir, "sumcheck.go"), Templates: []string{"sumcheck.go.tmpl"}}, {File: filepath.Join(packageDir, "sumcheck_test.go"), Templates: []string{"sumcheck.test.go.tmpl", "sumcheck.test.defs.go.tmpl"}}, {File: filepath.Join(packageDir, testVectorUtilsFileName), Templates: []string{"test_vector_utils.go.tmpl"}}, @@ -317,9 +316,10 @@ func generateGkrBackend(cfg gkrConfig) error { {File: filepath.Join(packageDir, "sumcheck_test_vector_gen.go"), Templates: []string{"sumcheck.test.vectors.gen.go.tmpl", "sumcheck.test.defs.go.tmpl"}}, }...) } else { - entries = append(entries, bavard.Entry{ - File: filepath.Join(packageDir, "solver_hints.go"), Templates: []string{"solver_hints.go.tmpl"}, - }) + entries = append(entries, []bavard.Entry{ + {File: filepath.Join(packageDir, "solver_hints.go"), Templates: []string{"solver_hints.go.tmpl"}}, + {File: filepath.Join(packageDir, "gate_testing.go"), Templates: []string{"gate_testing.go.tmpl"}}, + }...) } if err := bgen.Generate(cfg, "gkr", "./template/gkr/", entries...); err != nil { diff --git a/internal/generator/backend/template/gkr/gate_testing.go.tmpl b/internal/generator/backend/template/gkr/gate_testing.go.tmpl index 8c78af347c..a782015cfa 100644 --- a/internal/generator/backend/template/gkr/gate_testing.go.tmpl +++ b/internal/generator/backend/template/gkr/gate_testing.go.tmpl @@ -155,6 +155,17 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error return nil } +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make({{.FieldPackageName}}.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} + {{- if not .CanUseFFT }} // interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) // Note that the runtime is O(len(X)³) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 5105b0a33d..3e3881d15f 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -735,7 +735,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod {{ .ElementType }} - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index e1d41e8cb8..80a908c720 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -2,8 +2,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -14,113 +14,134 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type NewSolvingDataOption func(*newSolvingDataSettings) + +// WithAssignment re-use already computed wire assignments. +func WithAssignment(assignment gkrtypes.WireAssignment) NewSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +// NewSolvingData converts gkrtypes.SolvingInfo into a concrete SolvingData object: +// - The gates are loaded in accordance with their names. +// - The instances/assignments are padded into a power of 2, suitable for the multilinear extensions used +// in the GKR prover. +func NewSolvingData(info gkrtypes.SolvingInfo, options ...NewSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Sprintf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Sprintf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Sprintf("provided assignment for wire %d instance %d is not a valid field element: %v", i, j, err)) } } + // inline equivalent of repeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +// GetAssignmentHint generates a hint that returns the value of a particular wire at a particular instance. +// It is intended for use in the debugging function gkrapi.API.GetValue. +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +// SolveHint generate a hint that computes the assignments for all wires in a circuit instance. +// It is intended for use in gkrapi.API.AddInstance. +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } +// ProveHint generates a hint that produces the GKR proof using the computed assignments contained in data. +// It is meant for use in gkrapi.Circuit.finalize. func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.repeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -128,7 +149,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_{{.FieldID}}") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -136,4 +157,14 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { return proof.SerializeToBigInts(outs) } +} + +// repeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. +{{ print "// e.g. {{1, 2, 3}, {4, 5, 6}}.repeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}}"}} +func (a WireAssignment) repeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } } \ No newline at end of file diff --git a/internal/generator/backend/template/representations/solver.go.tmpl b/internal/generator/backend/template/representations/solver.go.tmpl index 1b8dd42fe4..5f6fa68c70 100644 --- a/internal/generator/backend/template/representations/solver.go.tmpl +++ b/internal/generator/backend/template/representations/solver.go.tmpl @@ -12,8 +12,11 @@ import ( "github.com/rs/zerolog" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/field/pool" + {{- if not .NoGKR }} "github.com/consensys/gnark/constraint/solver/gkrgates" gkr "github.com/consensys/gnark/internal/gkr/{{ toLower .Curve }}" + "github.com/consensys/gnark/internal/gkr/gkrhints" + {{- end }} "github.com/consensys/gnark/internal/gkr/gkrtypes" {{ template "import_fr" . }} ) @@ -43,14 +46,16 @@ func newSolver(cs *system, witness fr.Vector, opts ...csolver.Option) (*solver, {{ if not .NoGKR -}} // add GKR options to overwrite the placeholder if cs.GkrInfo.Is() { - var gkrData gkr.SolvingData solvingInfo, err := gkrtypes.StoringToSolvingInfo(cs.GkrInfo, gkrgates.Get) if err != nil { return nil, err } + gkrData := gkr.NewSolvingData(solvingInfo) + var gkrHints *gkrhints.TestEngineHints opts = append(opts, - csolver.OverrideHint(cs.GkrInfo.SolveHintID, gkr.SolveHint(solvingInfo, &gkrData)), - csolver.OverrideHint(cs.GkrInfo.ProveHintID, gkr.ProveHint(cs.GkrInfo.HashName, &gkrData))) + csolver.OverrideHint(csolver.GetHintID(gkrHints.GetAssignment), gkr.GetAssignmentHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Solve), gkr.SolveHint(gkrData)), + csolver.OverrideHint(csolver.GetHintID(gkrHints.Prove), gkr.ProveHint(cs.GkrInfo.HashName, gkrData))) } {{ end -}} diff --git a/internal/gkr/bls12-377/gate_testing.go b/internal/gkr/bls12-377/gate_testing.go index 415a5ff5b3..9e5a3868f3 100644 --- a/internal/gkr/bls12-377/gate_testing.go +++ b/internal/gkr/bls12-377/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index f5dfad020e..b92ac1249d 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 39547cff29..7b3d086213 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,113 +21,134 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type NewSolvingDataOption func(*newSolvingDataSettings) + +// WithAssignment re-use already computed wire assignments. +func WithAssignment(assignment gkrtypes.WireAssignment) NewSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +// NewSolvingData converts gkrtypes.SolvingInfo into a concrete SolvingData object: +// - The gates are loaded in accordance with their names. +// - The instances/assignments are padded into a power of 2, suitable for the multilinear extensions used +// in the GKR prover. +func NewSolvingData(info gkrtypes.SolvingInfo, options ...NewSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Sprintf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Sprintf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Sprintf("provided assignment for wire %d instance %d is not a valid field element: %v", i, j, err)) } } + // inline equivalent of repeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +// GetAssignmentHint generates a hint that returns the value of a particular wire at a particular instance. +// It is intended for use in the debugging function gkrapi.API.GetValue. +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +// SolveHint generate a hint that computes the assignments for all wires in a circuit instance. +// It is intended for use in gkrapi.API.AddInstance. +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } +// ProveHint generates a hint that produces the GKR proof using the computed assignments contained in data. +// It is meant for use in gkrapi.Circuit.finalize. func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.repeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +156,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BLS12_377") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +165,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// repeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.repeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) repeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bls12-381/gate_testing.go b/internal/gkr/bls12-381/gate_testing.go index ef7694dc18..5b281fd634 100644 --- a/internal/gkr/bls12-381/gate_testing.go +++ b/internal/gkr/bls12-381/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index f5617a59d4..82084049d9 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index cb498c78b7..d6c1dae79b 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,113 +21,134 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type NewSolvingDataOption func(*newSolvingDataSettings) + +// WithAssignment re-use already computed wire assignments. +func WithAssignment(assignment gkrtypes.WireAssignment) NewSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +// NewSolvingData converts gkrtypes.SolvingInfo into a concrete SolvingData object: +// - The gates are loaded in accordance with their names. +// - The instances/assignments are padded into a power of 2, suitable for the multilinear extensions used +// in the GKR prover. +func NewSolvingData(info gkrtypes.SolvingInfo, options ...NewSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Sprintf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Sprintf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Sprintf("provided assignment for wire %d instance %d is not a valid field element: %v", i, j, err)) } } + // inline equivalent of repeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +// GetAssignmentHint generates a hint that returns the value of a particular wire at a particular instance. +// It is intended for use in the debugging function gkrapi.API.GetValue. +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +// SolveHint generate a hint that computes the assignments for all wires in a circuit instance. +// It is intended for use in gkrapi.API.AddInstance. +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } +// ProveHint generates a hint that produces the GKR proof using the computed assignments contained in data. +// It is meant for use in gkrapi.Circuit.finalize. func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.repeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +156,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BLS12_381") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +165,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// repeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.repeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) repeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bls24-315/gate_testing.go b/internal/gkr/bls24-315/gate_testing.go index 1682d24771..058b53cc06 100644 --- a/internal/gkr/bls24-315/gate_testing.go +++ b/internal/gkr/bls24-315/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7d89baf7ef..f182c9176b 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index 914c8a9d61..1c2b9b235e 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,113 +21,134 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type NewSolvingDataOption func(*newSolvingDataSettings) + +// WithAssignment re-use already computed wire assignments. +func WithAssignment(assignment gkrtypes.WireAssignment) NewSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +// NewSolvingData converts gkrtypes.SolvingInfo into a concrete SolvingData object: +// - The gates are loaded in accordance with their names. +// - The instances/assignments are padded into a power of 2, suitable for the multilinear extensions used +// in the GKR prover. +func NewSolvingData(info gkrtypes.SolvingInfo, options ...NewSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Sprintf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Sprintf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Sprintf("provided assignment for wire %d instance %d is not a valid field element: %v", i, j, err)) } } + // inline equivalent of repeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +// GetAssignmentHint generates a hint that returns the value of a particular wire at a particular instance. +// It is intended for use in the debugging function gkrapi.API.GetValue. +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +// SolveHint generate a hint that computes the assignments for all wires in a circuit instance. +// It is intended for use in gkrapi.API.AddInstance. +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } +// ProveHint generates a hint that produces the GKR proof using the computed assignments contained in data. +// It is meant for use in gkrapi.Circuit.finalize. func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.repeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +156,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BLS24_315") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +165,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// repeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.repeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) repeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bls24-317/gate_testing.go b/internal/gkr/bls24-317/gate_testing.go index 1bffab29e3..ed418ff1b0 100644 --- a/internal/gkr/bls24-317/gate_testing.go +++ b/internal/gkr/bls24-317/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index fc9908b918..a284f14ae9 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index f6e1ad993d..c844ba6452 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,113 +21,134 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type NewSolvingDataOption func(*newSolvingDataSettings) + +// WithAssignment re-use already computed wire assignments. +func WithAssignment(assignment gkrtypes.WireAssignment) NewSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +// NewSolvingData converts gkrtypes.SolvingInfo into a concrete SolvingData object: +// - The gates are loaded in accordance with their names. +// - The instances/assignments are padded into a power of 2, suitable for the multilinear extensions used +// in the GKR prover. +func NewSolvingData(info gkrtypes.SolvingInfo, options ...NewSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Sprintf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Sprintf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Sprintf("provided assignment for wire %d instance %d is not a valid field element: %v", i, j, err)) } } + // inline equivalent of repeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +// GetAssignmentHint generates a hint that returns the value of a particular wire at a particular instance. +// It is intended for use in the debugging function gkrapi.API.GetValue. +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +// SolveHint generate a hint that computes the assignments for all wires in a circuit instance. +// It is intended for use in gkrapi.API.AddInstance. +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } +// ProveHint generates a hint that produces the GKR proof using the computed assignments contained in data. +// It is meant for use in gkrapi.Circuit.finalize. func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.repeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +156,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BLS24_317") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +165,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// repeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.repeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) repeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bn254/gate_testing.go b/internal/gkr/bn254/gate_testing.go index 716ba3891b..e9311a3ea5 100644 --- a/internal/gkr/bn254/gate_testing.go +++ b/internal/gkr/bn254/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 04cf3512af..14269151b3 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 7bc3782932..61d5bc7ed3 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,113 +21,134 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type NewSolvingDataOption func(*newSolvingDataSettings) + +// WithAssignment re-use already computed wire assignments. +func WithAssignment(assignment gkrtypes.WireAssignment) NewSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +// NewSolvingData converts gkrtypes.SolvingInfo into a concrete SolvingData object: +// - The gates are loaded in accordance with their names. +// - The instances/assignments are padded into a power of 2, suitable for the multilinear extensions used +// in the GKR prover. +func NewSolvingData(info gkrtypes.SolvingInfo, options ...NewSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Sprintf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Sprintf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Sprintf("provided assignment for wire %d instance %d is not a valid field element: %v", i, j, err)) } } + // inline equivalent of repeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +// GetAssignmentHint generates a hint that returns the value of a particular wire at a particular instance. +// It is intended for use in the debugging function gkrapi.API.GetValue. +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +// SolveHint generate a hint that computes the assignments for all wires in a circuit instance. +// It is intended for use in gkrapi.API.AddInstance. +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } +// ProveHint generates a hint that produces the GKR proof using the computed assignments contained in data. +// It is meant for use in gkrapi.Circuit.finalize. func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.repeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +156,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BN254") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +165,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// repeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.repeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) repeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bw6-633/gate_testing.go b/internal/gkr/bw6-633/gate_testing.go index 0fafa45a0d..8074b9621c 100644 --- a/internal/gkr/bw6-633/gate_testing.go +++ b/internal/gkr/bw6-633/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index cc1245e726..ec1067f736 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 57343d291f..2f0254c237 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,113 +21,134 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type NewSolvingDataOption func(*newSolvingDataSettings) + +// WithAssignment re-use already computed wire assignments. +func WithAssignment(assignment gkrtypes.WireAssignment) NewSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +// NewSolvingData converts gkrtypes.SolvingInfo into a concrete SolvingData object: +// - The gates are loaded in accordance with their names. +// - The instances/assignments are padded into a power of 2, suitable for the multilinear extensions used +// in the GKR prover. +func NewSolvingData(info gkrtypes.SolvingInfo, options ...NewSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Sprintf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Sprintf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Sprintf("provided assignment for wire %d instance %d is not a valid field element: %v", i, j, err)) } } + // inline equivalent of repeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +// GetAssignmentHint generates a hint that returns the value of a particular wire at a particular instance. +// It is intended for use in the debugging function gkrapi.API.GetValue. +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +// SolveHint generate a hint that computes the assignments for all wires in a circuit instance. +// It is intended for use in gkrapi.API.AddInstance. +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } +// ProveHint generates a hint that produces the GKR proof using the computed assignments contained in data. +// It is meant for use in gkrapi.Circuit.finalize. func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.repeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +156,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BW6_633") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +165,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// repeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.repeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) repeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/bw6-761/gate_testing.go b/internal/gkr/bw6-761/gate_testing.go index 6eda2ebe73..0bae6258dc 100644 --- a/internal/gkr/bw6-761/gate_testing.go +++ b/internal/gkr/bw6-761/gate_testing.go @@ -142,3 +142,14 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error } return nil } + +// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point. +func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool { + x := make(fr.Vector, nbIn) + x.MustSetRandom() + fFr := api.convertFunc(f) + gFr := api.convertFunc(g) + fAt := fFr(x...) + gAt := gFr(x...) + return fAt.Equal(gAt) +} diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index f90f28114b..ad5197feef 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod fr.Element - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 606f13ec23..1f47c9a578 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -9,8 +9,8 @@ import ( "fmt" "math/big" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark-crypto/utils" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -21,113 +21,134 @@ import ( ) type SolvingData struct { - assignment WireAssignment - circuit gkrtypes.Circuit - workers *utils.WorkerPool + assignment WireAssignment // assignment is indexed wire-first, instance-second. The number of instances is padded to a power of 2. + circuit gkrtypes.Circuit + maxNbIn int // maximum number of inputs for a gate in the circuit + nbInstances int } -func (d *SolvingData) init(info gkrtypes.SolvingInfo) { - d.workers = utils.NewWorkerPool() - d.circuit = info.Circuit - d.circuit.SetNbUniqueOutputs() +type newSolvingDataSettings struct { + assignment gkrtypes.WireAssignment +} - d.assignment = make(WireAssignment, len(d.circuit)) - for i := range d.assignment { - d.assignment[i] = make([]fr.Element, info.NbInstances) +type NewSolvingDataOption func(*newSolvingDataSettings) + +// WithAssignment re-use already computed wire assignments. +func WithAssignment(assignment gkrtypes.WireAssignment) NewSolvingDataOption { + return func(s *newSolvingDataSettings) { + s.assignment = assignment } } -// this module assumes that wire and instance indexes respect dependencies +// NewSolvingData converts gkrtypes.SolvingInfo into a concrete SolvingData object: +// - The gates are loaded in accordance with their names. +// - The instances/assignments are padded into a power of 2, suitable for the multilinear extensions used +// in the GKR prover. +func NewSolvingData(info gkrtypes.SolvingInfo, options ...NewSolvingDataOption) *SolvingData { + var s newSolvingDataSettings + for _, opt := range options { + opt(&s) + } -func setOuts(a WireAssignment, circuit gkrtypes.Circuit, outs []*big.Int) { - outsI := 0 - for i := range circuit { - if circuit[i].IsOutput() { - for j := range a[i] { - a[i][j].BigInt(outs[outsI]) - outsI++ - } - } + d := SolvingData{ + circuit: info.Circuit, + assignment: make(WireAssignment, len(info.Circuit)), + nbInstances: info.NbInstances, } - // Check if outsI == len(outs)? -} -func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { - return func(_ *big.Int, ins, outs []*big.Int) error { - // assumes assignmentVector is arranged wire first, instance second in order of solution - offsets := info.AssignmentOffsets() - data.init(info) - maxNIn := data.circuit.MaxGateNbIn() - - chunks := info.Chunks() - - solveTask := func(chunkOffset int) utils.Task { - return func(startInChunk, endInChunk int) { - start := startInChunk + chunkOffset - end := endInChunk + chunkOffset - inputs := make([]frontend.Variable, maxNIn) - dependencyHeads := make([]int, len(data.circuit)) // for each wire, which of its dependencies we would look at next - for wI := range data.circuit { // skip instances that are not relevant (related to instances before the current task) - deps := info.Dependencies[wI] - dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { - return deps[i].InputInstance - }, len(deps), start) - } + d.maxNbIn = d.circuit.MaxGateNbIn() - for instanceI := start; instanceI < end; instanceI++ { - for wireI := range data.circuit { - wire := &data.circuit[wireI] - deps := info.Dependencies[wireI] - if wire.IsInput() { - if dependencyHeads[wireI] < len(deps) && instanceI == deps[dependencyHeads[wireI]].InputInstance { - dep := deps[dependencyHeads[wireI]] - data.assignment[wireI][instanceI].Set(&data.assignment[dep.OutputWire][dep.OutputInstance]) - dependencyHeads[wireI]++ - } else { - data.assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) - } - } else { - // assemble the inputs - inputIndexes := info.Circuit[wireI].Inputs - for i, inputI := range inputIndexes { - inputs[i] = &data.assignment[inputI][instanceI] - } - gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) - } - } + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(info.NbInstances))) + for i := range d.assignment { + d.assignment[i] = make([]fr.Element, nbPaddedInstances) + } + + if s.assignment != nil { + if len(s.assignment) != len(d.assignment) { + panic(fmt.Sprintf("provided assignment has %d wires, expected %d", len(s.assignment), len(d.assignment))) + } + for i := range d.assignment { + if len(s.assignment[i]) != info.NbInstances { + panic(fmt.Sprintf("provided assignment for wire %d has %d instances, expected %d", i, len(s.assignment[i]), info.NbInstances)) + } + for j := range s.assignment[i] { + if _, err := d.assignment[i][j].SetInterface(s.assignment[i][j]); err != nil { + panic(fmt.Sprintf("provided assignment for wire %d instance %d is not a valid field element: %v", i, j, err)) } } + // inline equivalent of repeatUntilEnd + for j := len(s.assignment[i]); j < nbPaddedInstances; j++ { + d.assignment[i][j] = d.assignment[i][j-1] // pad with the last value + } } + } + + return &d +} - start := 0 - for _, end := range chunks { - data.workers.Submit(end-start, solveTask(start), 1024).Wait() - start = end +// this module assumes that wire and instance indexes respect dependencies + +// GetAssignmentHint generates a hint that returns the value of a particular wire at a particular instance. +// It is intended for use in the debugging function gkrapi.API.GetValue. +func GetAssignmentHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + if len(ins) != 3 { + return fmt.Errorf("GetAssignmentHint expects 3 inputs: wire index, instance index, and dummy dependency enforcer") } + wireI := ins[0].Uint64() + instanceI := ins[1].Uint64() + + data.assignment[wireI][instanceI].BigInt(outs[0]) + + return nil + } +} - for _, p := range info.Prints { - serializable := make([]any, len(p.Values)) - for i, v := range p.Values { - if p.IsGkrVar[i] { // serializer stores uint32 in slices as uint64 - serializable[i] = data.assignment[algo_utils.ForceUint32(v)][p.Instance].String() - } else { - serializable[i] = v +// SolveHint generate a hint that computes the assignments for all wires in a circuit instance. +// It is intended for use in gkrapi.API.AddInstance. +func SolveHint(data *SolvingData) hint.Hint { + return func(_ *big.Int, ins, outs []*big.Int) error { + instanceI := ins[0].Uint64() + + gateIns := make([]frontend.Variable, data.maxNbIn) + outsI := 0 + insI := 1 // skip the first input, which is the instance index + for wI := range data.circuit { + w := &data.circuit[wI] + if w.IsInput() { // read from provided input + data.assignment[wI][instanceI].SetBigInt(ins[insI]) + insI++ + } else { + + // assemble input for gate + for i, inWI := range w.Inputs { + gateIns[i] = &data.assignment[inWI][instanceI] } + + data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element)) + } + if w.IsOutput() { + data.assignment[wI][instanceI].BigInt(outs[outsI]) + outsI++ } - fmt.Println(serializable...) } - setOuts(data.assignment, info.Circuit, outs) - return nil } } +// ProveHint generates a hint that produces the GKR proof using the computed assignments contained in data. +// It is meant for use in gkrapi.Circuit.finalize. func ProveHint(hashName string, data *SolvingData) hint.Hint { return func(_ *big.Int, ins, outs []*big.Int) error { - insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called + + data.assignment.repeatUntilEnd(data.nbInstances) + + // The first input is dummy, just to ensure the solver's work is done before the prover is called. + // The rest constitute the initial fiat shamir challenge + insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { + b := make([]byte, fr.Bytes) i.FillBytes(b) return b[:] @@ -135,7 +156,7 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { hsh := hash.NewHash(hashName + "_BW6_761") - proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...), WithWorkers(data.workers)) + proof, err := Prove(data.circuit, data.assignment, fiatshamir.WithHash(hsh, insBytes...)) if err != nil { return err } @@ -144,3 +165,13 @@ func ProveHint(hashName string, data *SolvingData) hint.Hint { } } + +// repeatUntilEnd for each wire, sets all the values starting from n > 0 to its predecessor. +// e.g. {{1, 2, 3}, {4, 5, 6}}.repeatUntilEnd(2) -> {{1, 2, 2}, {4, 5, 5}} +func (a WireAssignment) repeatUntilEnd(n int) { + for i := range a { + for j := n; j < len(a[i]); j++ { + a[i][j] = a[i][j-1] + } + } +} diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index 955ad8a354..7cfe315a32 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -341,7 +341,10 @@ func (p Proof) Serialize() []frontend.Variable { return res } -func computeLogNbInstances(wires []*gkrtypes.Wire, serializedProofLen int) int { +// ComputeLogNbInstances derives n such that the number of instances is 2ⁿ +// from the size of the proof and the circuit structure. +// It is used in proof deserialization. +func ComputeLogNbInstances(wires []*gkrtypes.Wire, serializedProofLen int) int { partialEvalElemsPerVar := 0 for _, w := range wires { if !w.NoProof() { @@ -366,7 +369,7 @@ func (r *variablesReader) hasNextN(n int) bool { func DeserializeProof(sorted []*gkrtypes.Wire, serializedProof []frontend.Variable) (Proof, error) { proof := make(Proof, len(sorted)) - logNbInstances := computeLogNbInstances(sorted, len(serializedProof)) + logNbInstances := ComputeLogNbInstances(sorted, len(serializedProof)) reader := variablesReader(serializedProof) for i, wI := range sorted { diff --git a/internal/gkr/gkr_test.go b/internal/gkr/gkr_test.go index faf8eadc95..02c6d6cac2 100644 --- a/internal/gkr/gkr_test.go +++ b/internal/gkr/gkr_test.go @@ -245,7 +245,7 @@ func TestLogNbInstances(t *testing.T) { assert.NoError(t, err) wires := testCase.Circuit.TopologicalSort() serializedProof := testCase.Proof.Serialize() - logNbInstances := computeLogNbInstances(wires, len(serializedProof)) + logNbInstances := ComputeLogNbInstances(wires, len(serializedProof)) assert.Equal(t, 1, logNbInstances) } } diff --git a/internal/gkr/gkrhints/engine_hints.go b/internal/gkr/gkrhints/engine_hints.go new file mode 100644 index 0000000000..3b74cdf4a0 --- /dev/null +++ b/internal/gkr/gkrhints/engine_hints.go @@ -0,0 +1,194 @@ +package gkrhints + +import ( + "errors" + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/constraint/solver/gkrgates" + "github.com/consensys/gnark/frontend" + bls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" + bls12381 "github.com/consensys/gnark/internal/gkr/bls12-381" + bls24315 "github.com/consensys/gnark/internal/gkr/bls24-315" + bls24317 "github.com/consensys/gnark/internal/gkr/bls24-317" + bn254 "github.com/consensys/gnark/internal/gkr/bn254" + bw6633 "github.com/consensys/gnark/internal/gkr/bw6-633" + bw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" + "github.com/consensys/gnark/internal/gkr/gkrinfo" + "github.com/consensys/gnark/internal/gkr/gkrtypes" + "github.com/consensys/gnark/internal/utils" +) + +type TestEngineHints struct { + assignment gkrtypes.WireAssignment + info *gkrinfo.StoringInfo // we retain a reference to the solving info to allow the caller to modify it between calls to Solve and Prove + circuit gkrtypes.Circuit + gateIns []frontend.Variable +} + +func NewTestEngineHints(info *gkrinfo.StoringInfo) (*TestEngineHints, error) { + circuit, err := gkrtypes.CircuitInfoToCircuit(info.Circuit, gkrgates.Get) + if err != nil { + return nil, err + } + + return &TestEngineHints{ + info: info, + circuit: circuit, + gateIns: make([]frontend.Variable, circuit.MaxGateNbIn()), + assignment: make(gkrtypes.WireAssignment, len(circuit)), + }, + err +} + +// Solve solves one instance of a GKR circuit. +// The first input is the index of the instance. The rest are the inputs of the circuit, in their nominal order. +func (h *TestEngineHints) Solve(mod *big.Int, ins []*big.Int, outs []*big.Int) error { + + instanceI := len(h.assignment[0]) + if in0 := ins[0].Uint64(); !ins[0].IsUint64() || in0 > 0xffffffff { + return errors.New("first input must be a uint32 instance index") + } else if in0 != uint64(instanceI) || h.info.NbInstances != instanceI { + return errors.New("first input must equal the number of instances, and calls to Solve must be done in order of instance index") + } + + api := gateAPI{mod} + + inI := 1 + outI := 0 + for wI := range h.circuit { + w := &h.circuit[wI] + var val frontend.Variable + if w.IsInput() { + val = utils.FromInterface(ins[inI]) + inI++ + } else { + for gateInI, inWI := range w.Inputs { + h.gateIns[gateInI] = h.assignment[inWI][instanceI] + } + val = w.Gate.Evaluate(api, h.gateIns[:len(w.Inputs)]...) + } + if w.IsOutput() { + *outs[outI] = utils.FromInterface(val) + outI++ + } + h.assignment[wI] = append(h.assignment[wI], val) + } + return nil +} + +func (h *TestEngineHints) Prove(mod *big.Int, ins, outs []*big.Int) error { + + info, err := gkrtypes.StoringToSolvingInfo(*h.info, gkrgates.Get) + if err != nil { + return fmt.Errorf("failed to convert storing info to solving info: %w", err) + } + + if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { + data := bls12377.NewSolvingData(info, bls12377.WithAssignment(h.assignment)) + return bls12377.ProveHint(info.HashName, data)(mod, ins, outs) + } + if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { + data := bls12381.NewSolvingData(info, bls12381.WithAssignment(h.assignment)) + return bls12381.ProveHint(info.HashName, data)(mod, ins, outs) + } + if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { + data := bls24315.NewSolvingData(info, bls24315.WithAssignment(h.assignment)) + return bls24315.ProveHint(info.HashName, data)(mod, ins, outs) + } + if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { + data := bls24317.NewSolvingData(info, bls24317.WithAssignment(h.assignment)) + return bls24317.ProveHint(info.HashName, data)(mod, ins, outs) + } + if mod.Cmp(ecc.BN254.ScalarField()) == 0 { + data := bn254.NewSolvingData(info, bn254.WithAssignment(h.assignment)) + return bn254.ProveHint(info.HashName, data)(mod, ins, outs) + } + if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { + data := bw6633.NewSolvingData(info, bw6633.WithAssignment(h.assignment)) + return bw6633.ProveHint(info.HashName, data)(mod, ins, outs) + } + if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { + data := bw6761.NewSolvingData(info, bw6761.WithAssignment(h.assignment)) + return bw6761.ProveHint(info.HashName, data)(mod, ins, outs) + } + + return errors.New("unsupported modulus") +} + +// GetAssignment returns the assignment for a particular wire and instance. +func (h *TestEngineHints) GetAssignment(_ *big.Int, ins []*big.Int, outs []*big.Int) error { + if len(ins) != 3 || !ins[0].IsUint64() || !ins[1].IsUint64() { + return errors.New("expected 3 inputs: wire index, instance index, and dummy output from the same instance") + } + if len(outs) != 1 { + return errors.New("expected 1 output: the value of the wire at the given instance") + } + *outs[0] = utils.FromInterface(h.assignment[ins[0].Uint64()][ins[1].Uint64()]) + return nil +} + +type gateAPI struct{ *big.Int } + +func (g gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + in1 := utils.FromInterface(i1) + in2 := utils.FromInterface(i2) + + in1.Add(&in1, &in2) + for _, v := range in { + inV := utils.FromInterface(v) + in1.Add(&in1, &inV) + } + return &in1 +} + +func (g gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + x, y := utils.FromInterface(b), utils.FromInterface(c) + x.Mul(&x, &y) + x.Mod(&x, g.Int) // reduce + y = utils.FromInterface(a) + x.Add(&x, &y) + return &x +} + +func (g gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + x := utils.FromInterface(i1) + x.Neg(&x) + return &x +} + +func (g gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + x := utils.FromInterface(i1) + y := utils.FromInterface(i2) + x.Sub(&x, &y) + for _, v := range in { + y = utils.FromInterface(v) + x.Sub(&x, &y) + } + return &x +} + +func (g gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + x := utils.FromInterface(i1) + y := utils.FromInterface(i2) + x.Mul(&x, &y) + for _, v := range in { + y = utils.FromInterface(v) + x.Mul(&x, &y) + } + x.Mod(&x, g.Int) // reduce + return &x +} + +func (g gateAPI) Println(a ...frontend.Variable) { + strings := make([]string, len(a)) + for i := range a { + if s, ok := a[i].(fmt.Stringer); ok { + strings[i] = s.String() + } else { + bigInt := utils.FromInterface(a[i]) + strings[i] = bigInt.String() + } + } +} diff --git a/internal/gkr/gkrinfo/info.go b/internal/gkr/gkrinfo/info.go index de9a845e8d..9d4d72e175 100644 --- a/internal/gkr/gkrinfo/info.go +++ b/internal/gkr/gkrinfo/info.go @@ -1,14 +1,6 @@ // Package gkrinfo contains serializable information capable of being saved in a SNARK circuit CS object. package gkrinfo -import ( - "fmt" - "sort" - - "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark/internal/utils" -) - type ( InputDependency struct { OutputWire int @@ -17,26 +9,16 @@ type ( } Wire struct { - Gate string - Inputs []int - NbUniqueOutputs int + Gate string + Inputs []int } Circuit []Wire - PrintInfo struct { - Values []any - Instance uint32 - IsGkrVar []bool - } StoringInfo struct { - Circuit Circuit - Dependencies [][]InputDependency // nil for input wires - NbInstances int - HashName string - SolveHintID solver.HintID - ProveHintID solver.HintID - Prints []PrintInfo + Circuit Circuit + NbInstances int + HashName string } Permutations struct { @@ -51,88 +33,12 @@ func (w Wire) IsInput() bool { return len(w.Inputs) == 0 } -func (w Wire) IsOutput() bool { - return w.NbUniqueOutputs == 0 -} - func (d *StoringInfo) NewInputVariable() int { i := len(d.Circuit) d.Circuit = append(d.Circuit, Wire{}) - d.Dependencies = append(d.Dependencies, nil) return i } -// Compile sorts the Circuit wires, their dependencies and the instances -func (d *StoringInfo) Compile(nbInstances int) (Permutations, error) { - - var p Permutations - d.NbInstances = nbInstances - // sort the instances to decide the order in which they are to be solved - instanceDeps := make([][]int, nbInstances) - for i := range d.Circuit { - for _, dep := range d.Dependencies[i] { - instanceDeps[dep.InputInstance] = append(instanceDeps[dep.InputInstance], dep.OutputInstance) - } - } - - p.SortedInstances, _ = utils.TopologicalSort(instanceDeps) - p.InstancesPermutation = utils.InvertPermutation(p.SortedInstances) - - // this whole circuit sorting is a bit of a charade. if things are built using an api, there's no way it could NOT already be topologically sorted - // worth keeping for future-proofing? - - inputs := utils.Map(d.Circuit, func(w Wire) []int { - return w.Inputs - }) - - var uniqueOuts [][]int - p.SortedWires, uniqueOuts = utils.TopologicalSort(inputs) - p.WiresPermutation = utils.InvertPermutation(p.SortedWires) - wirePermutationAt := utils.SliceAt(p.WiresPermutation) - sorted := make([]Wire, len(d.Circuit)) // TODO: Directly manipulate d.circuit instead - sortedDeps := make([][]InputDependency, len(d.Circuit)) - - // go through the wires in the sorted order and fix the input and dependency indices according to the permutations - for newI, oldI := range p.SortedWires { - oldW := d.Circuit[oldI] - - for depI := range d.Dependencies[oldI] { - dep := &d.Dependencies[oldI][depI] - dep.OutputWire = p.WiresPermutation[dep.OutputWire] - dep.InputInstance = p.InstancesPermutation[dep.InputInstance] - dep.OutputInstance = p.InstancesPermutation[dep.OutputInstance] - } - sort.Slice(d.Dependencies[oldI], func(i, j int) bool { - return d.Dependencies[oldI][i].InputInstance < d.Dependencies[oldI][j].InputInstance - }) - for i := 1; i < len(d.Dependencies[oldI]); i++ { - if d.Dependencies[oldI][i].InputInstance == d.Dependencies[oldI][i-1].InputInstance { - return p, fmt.Errorf("an input wire can only have one dependency per instance") - } - } // TODO: Check that dependencies and explicit assignments cover all instances - - sortedDeps[newI] = d.Dependencies[oldI] - sorted[newI] = Wire{ - Gate: oldW.Gate, - Inputs: utils.Map(oldW.Inputs, wirePermutationAt), - NbUniqueOutputs: len(uniqueOuts[oldI]), - } - } - - // re-arrange the prints - for i := range d.Prints { - for j, isVar := range d.Prints[i].IsGkrVar { - if isVar { - d.Prints[i].Values[j] = uint32(p.WiresPermutation[d.Prints[i].Values[j].(uint32)]) - } - } - } - - d.Circuit, d.Dependencies = sorted, sortedDeps - - return p, nil -} - func (d *StoringInfo) Is() bool { return d.Circuit != nil } diff --git a/internal/gkr/gkrtesting/gkrtesting.go b/internal/gkr/gkrtesting/gkrtesting.go index ce9ba88942..4c901f04a2 100644 --- a/internal/gkr/gkrtesting/gkrtesting.go +++ b/internal/gkr/gkrtesting/gkrtesting.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" + "github.com/consensys/gnark" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" @@ -35,10 +36,10 @@ func NewCache() *Cache { res = api.Mul(res, sum) // sum^7 return res - }, 2, 7, -1) + }, 2, 7, -1, gnark.Curves()) gates["select-input-3"] = gkrtypes.NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return in[2] - }, 3, 1, 0) + }, 3, 1, 0, gnark.Curves()) return &Cache{ circuits: make(map[string]gkrtypes.Circuit), diff --git a/internal/gkr/gkrtypes/types.go b/internal/gkr/gkrtypes/types.go index 7aed5ccd27..a76059ac73 100644 --- a/internal/gkr/gkrtypes/types.go +++ b/internal/gkr/gkrtypes/types.go @@ -3,7 +3,10 @@ package gkrtypes import ( "errors" "fmt" + "slices" + "github.com/consensys/gnark" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/utils" @@ -17,17 +20,24 @@ type Gate struct { nbIn int // number of inputs degree int // total degree of the polynomial solvableVar int // if there is a variable whose value can be uniquely determined from the value of the gate and the other inputs, its index, -1 otherwise + curves []ecc.ID // curves that the gate is allowed to be used over } -func NewGate(f gkr.GateFunction, nbIn int, degree int, solvableVar int) *Gate { +func NewGate(f gkr.GateFunction, nbIn int, degree int, solvableVar int, curves []ecc.ID) *Gate { + return &Gate{ evaluate: f, nbIn: nbIn, degree: degree, solvableVar: solvableVar, + curves: curves, } } +func (g *Gate) SupportsCurve(curve ecc.ID) bool { + return slices.Contains(g.curves, curve) +} + func (g *Gate) Evaluate(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return g.evaluate(api, in...) } @@ -133,49 +143,9 @@ func (c Circuit) MemoryRequirements(nbInstances int) []int { } type SolvingInfo struct { - Circuit Circuit - Dependencies [][]gkrinfo.InputDependency - NbInstances int - HashName string - Prints []gkrinfo.PrintInfo -} - -// Chunks returns intervals of instances that are independent of each other and can be solved in parallel -func (info *SolvingInfo) Chunks() []int { - res := make([]int, 0, 1) - lastSeenDependencyI := make([]int, len(info.Circuit)) - - for start, end := 0, 0; start != info.NbInstances; start = end { - end = info.NbInstances - endWireI := -1 - for wI := range info.Circuit { - deps := info.Dependencies[wI] - if wDepI := lastSeenDependencyI[wI]; wDepI < len(deps) && deps[wDepI].InputInstance < end { - end = deps[wDepI].InputInstance - endWireI = wI - } - } - if endWireI != -1 { - lastSeenDependencyI[endWireI]++ - } - res = append(res, end) - } - return res -} - -// AssignmentOffsets describes the input layout of the Solve hint, by returning -// for each wire, the index of the first hint input element corresponding to it. -func (info *SolvingInfo) AssignmentOffsets() []int { - c := info.Circuit - res := make([]int, len(c)+1) - for i := range c { - nbExplicitAssignments := 0 - if c[i].IsInput() { - nbExplicitAssignments = info.NbInstances - len(info.Dependencies[i]) - } - res[i+1] = res[i] + nbExplicitAssignments - } - return res + Circuit Circuit + NbInstances int + HashName string } // OutputsList for each wire, returns the set of indexes of wires it is input to. @@ -206,7 +176,7 @@ func (c Circuit) OutputsList() [][]int { return res } -func (c Circuit) SetNbUniqueOutputs() { +func (c Circuit) setNbUniqueOutputs() { for i := range c { c[i].NbUniqueOutputs = 0 @@ -254,6 +224,7 @@ func CircuitInfoToCircuit(info gkrinfo.Circuit, gateGetter func(name gkr.GateNam resCircuit := make(Circuit, len(info)) for i := range info { if info[i].Gate == "" && len(info[i].Inputs) == 0 { + resCircuit[i].Gate = Identity() // input wire continue } resCircuit[i].Inputs = info[i].Inputs @@ -262,17 +233,16 @@ func CircuitInfoToCircuit(info gkrinfo.Circuit, gateGetter func(name gkr.GateNam return nil, fmt.Errorf("gate \"%s\" not found", info[i].Gate) } } + resCircuit.setNbUniqueOutputs() return resCircuit, nil } func StoringToSolvingInfo(info gkrinfo.StoringInfo, gateGetter func(name gkr.GateName) *Gate) (SolvingInfo, error) { circuit, err := CircuitInfoToCircuit(info.Circuit, gateGetter) return SolvingInfo{ - Circuit: circuit, - NbInstances: info.NbInstances, - HashName: info.HashName, - Dependencies: info.Dependencies, - Prints: info.Prints, + Circuit: circuit, + NbInstances: info.NbInstances, + HashName: info.HashName, }, err } @@ -388,33 +358,33 @@ var ErrZeroFunction = errors.New("detected a zero function") func Identity() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return in[0] - }, 1, 1, 0) + }, 1, 1, 0, gnark.Curves()) } // Add2 gate: (x, y) -> x + y func Add2() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return api.Add(in[0], in[1]) - }, 2, 1, 0) + }, 2, 1, 0, gnark.Curves()) } // Sub2 gate: (x, y) -> x - y func Sub2() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return api.Sub(in[0], in[1]) - }, 2, 1, 0) + }, 2, 1, 0, gnark.Curves()) } // Neg gate: x -> -x func Neg() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return api.Neg(in[0]) - }, 1, 1, 0) + }, 1, 1, 0, gnark.Curves()) } // Mul2 gate: (x, y) -> x * y func Mul2() *Gate { return NewGate(func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { return api.Mul(in[0], in[1]) - }, 2, 2, -1) + }, 2, 2, -1, gnark.Curves()) } diff --git a/internal/gkr/small_rational/gate_testing.go b/internal/gkr/small_rational/gate_testing.go deleted file mode 100644 index 1817cfbf6f..0000000000 --- a/internal/gkr/small_rational/gate_testing.go +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package gkr - -import ( - "fmt" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/internal/gkr/gkrtypes" - - "errors" - "slices" - - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "github.com/consensys/gnark/std/gkrapi/gkr" -) - -// IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { - fWrapped := api.convertFunc(f) - - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make(small_rational.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := fWrapped(in...) - - x[i] = x0 - copy(in, x) - y1 := fWrapped(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := fWrapped(in...) - - y2.Sub(y2, y1) - y1.Sub(y1, y0) - - if !y2.Equal(y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = fWrapped(in...) - - x[i] = x0 - copy(in, x) - y1 = fWrapped(in...) - - y1.Sub(y1, y0) - - return y1.Equal(y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]small_rational.SmallRational, nbIn) - consts := make(small_rational.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - x := make(small_rational.Vector, degreeBound) - x.MustSetRandom() - for i := range x { - fIn[0] = x[i] - for j := range consts { - fIn[j+1].Mul(&x[i], &consts[j]) - } - p[i].Set(f(fIn...)) - } - - // obtain p's coefficients - p, err := interpolate(x, p) - if err != nil { - panic(err) - } - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -// FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { - fFr := api.convertFunc(f) - bound := uint64(max) + 1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := fFr.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, gkrtypes.ErrZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { - fFr := api.convertFunc(f) - if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return gkrtypes.ErrZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) -// Note that the runtime is O(len(X)³) -func interpolate(X, Y []small_rational.SmallRational) (polynomial.Polynomial, error) { - if len(X) != len(Y) { - return nil, errors.New("same length expected for X and Y") - } - - // solve the system of equations by Gaussian elimination - augmentedRows := make([][]small_rational.SmallRational, len(X)) // the last column is the Y values - for i := range augmentedRows { - augmentedRows[i] = make([]small_rational.SmallRational, len(X)+1) - augmentedRows[i][0].SetOne() - augmentedRows[i][1].Set(&X[i]) - for j := 2; j < len(augmentedRows[i])-1; j++ { - augmentedRows[i][j].Mul(&augmentedRows[i][j-1], &X[i]) - } - augmentedRows[i][len(augmentedRows[i])-1].Set(&Y[i]) - } - - // make the upper triangle - for i := range len(augmentedRows) - 1 { - // use row i to eliminate the ith element in all rows below - var negInv small_rational.SmallRational - if augmentedRows[i][i].IsZero() { - return nil, errors.New("singular matrix") - } - negInv.Inverse(&augmentedRows[i][i]) - negInv.Neg(&negInv) - for j := i + 1; j < len(augmentedRows); j++ { - var c small_rational.SmallRational - c.Mul(&augmentedRows[j][i], &negInv) - // augmentedRows[j][i].SetZero() omitted - for k := i + 1; k < len(augmentedRows[i]); k++ { - var t small_rational.SmallRational - t.Mul(&augmentedRows[i][k], &c) - augmentedRows[j][k].Add(&augmentedRows[j][k], &t) - } - } - } - - // back substitution - res := make(polynomial.Polynomial, len(X)) - for i := len(augmentedRows) - 1; i >= 0; i-- { - res[i] = augmentedRows[i][len(augmentedRows[i])-1] - for j := i + 1; j < len(augmentedRows[i])-1; j++ { - var t small_rational.SmallRational - t.Mul(&res[j], &augmentedRows[i][j]) - res[i].Sub(&res[i], &t) - } - res[i].Div(&res[i], &augmentedRows[i][i]) - } - - return res, nil -} diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index e8e78f4b96..cdf62359f2 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -739,7 +739,7 @@ func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.V func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { var prod small_rational.SmallRational - prod.Add(cast(b), cast(c)) + prod.Mul(cast(b), cast(c)) res := cast(a) res.Add(res, &prod) return &res diff --git a/internal/utils/algo_utils.go b/internal/utils/algo_utils.go index f836625370..4bee19443e 100644 --- a/internal/utils/algo_utils.go +++ b/internal/utils/algo_utils.go @@ -24,6 +24,7 @@ func Permute[T any](slice []T, permutation []int) { } } +// Map returns [f(in[0]), f(in[1]), ..., f(in[len(in)-1])] func Map[T, S any](in []T, f func(T) S) []S { out := make([]S, len(in)) for i, t := range in { @@ -32,41 +33,6 @@ func Map[T, S any](in []T, f func(T) S) []S { return out } -func MapRange[S any](begin, end int, f func(int) S) []S { - out := make([]S, end-begin) - for i := begin; i < end; i++ { - out[i] = f(i) - } - return out -} - -func SliceAt[T any](slice []T) func(int) T { - return func(i int) T { - return slice[i] - } -} - -func SlicePtrAt[T any](slice []T) func(int) *T { - return func(i int) *T { - return &slice[i] - } -} - -func MapAt[K comparable, V any](mp map[K]V) func(K) V { - return func(k K) V { - return mp[k] - } -} - -// InvertPermutation input permutation must contain exactly 0, ..., len(permutation)-1 -func InvertPermutation(permutation []int) []int { - res := make([]int, len(permutation)) - for i := range permutation { - res[permutation[i]] = i - } - return res -} - // TODO: Move this to gnark-crypto and use it for gkr there as well // TopologicalSort takes a list of lists of dependencies and proposes a sorting of the lists in order of dependence. Such that for any wire, any one it depends on @@ -143,33 +109,11 @@ func (d *topSortData) markDone(i int) { } } -// BinarySearch looks for toFind in a sorted slice, and returns the index at which it either is or would be were it to be inserted. -func BinarySearch(slice []int, toFind int) int { - var start int - for end := len(slice); start != end; { - mid := (start + end) / 2 - if toFind >= slice[mid] { - start = mid - } - if toFind <= slice[mid] { - end = mid - } +// SliceOfRefs returns [&slice[0], &slice[1], ..., &slice[len(slice)-1]] +func SliceOfRefs[T any](slice []T) []*T { + res := make([]*T, len(slice)) + for i := range slice { + res[i] = &slice[i] } - return start -} - -// BinarySearchFunc looks for toFind in an increasing function of domain 0 ... (end-1), and returns the index at which it either is or would be were it to be inserted. -func BinarySearchFunc(eval func(int) int, end int, toFind int) int { - var start int - for start != end { - mid := (start + end) / 2 - val := eval(mid) - if toFind >= val { - start = mid - } - if toFind <= val { - end = mid - } - } - return start + return res } diff --git a/internal/utils/slices.go b/internal/utils/slices.go index dd2e2db31f..bdd86119fa 100644 --- a/internal/utils/slices.go +++ b/internal/utils/slices.go @@ -16,3 +16,16 @@ func References[T any](v []T) []*T { } return res } + +// ExtendRepeatLast extends a non-empty slice s by repeating the last element until it reaches the length n. +func ExtendRepeatLast[T any](s []T, n int) []T { + if n <= len(s) { + return s[:n] + } + res := make([]T, n) + copy(res, s) + for i := len(s); i < n; i++ { + res[i] = res[i-1] + } + return res +} diff --git a/internal/utils/slices_test.go b/internal/utils/slices_test.go new file mode 100644 index 0000000000..f61ec18fed --- /dev/null +++ b/internal/utils/slices_test.go @@ -0,0 +1,25 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtendRepeatLast(t *testing.T) { + // normal case + s := []int{1, 2, 3} + u := ExtendRepeatLast(s, 5) + assert.Equal(t, []int{1, 2, 3, 3, 3}, u) + + // don't overwrite super-slice + s = []int{1, 2, 3} + u = ExtendRepeatLast(s[:1], 2) + assert.Equal(t, []int{1, 1}, u) + assert.Equal(t, []int{1, 2, 3}, s) + + // trim if n < len(s) + s = []int{1, 2, 3} + u = ExtendRepeatLast(s, 2) + assert.Equal(t, []int{1, 2}, u) +} diff --git a/std/gkrapi/api.go b/std/gkrapi/api.go index 771613ce0d..5afc886e5f 100644 --- a/std/gkrapi/api.go +++ b/std/gkrapi/api.go @@ -23,7 +23,6 @@ func (api *API) NamedGate(gate gkr.GateName, in ...gkr.Variable) gkr.Variable { Inputs: utils.Map(in, frontendVarToInt), }) api.assignments = append(api.assignments, nil) - api.toStore.Dependencies = append(api.toStore.Dependencies, nil) // formality. Dependencies are only defined for input vars. return gkr.Variable(len(api.toStore.Circuit) - 1) } @@ -59,25 +58,3 @@ func (api *API) Sub(i1, i2 gkr.Variable) gkr.Variable { func (api *API) Mul(i1, i2 gkr.Variable) gkr.Variable { return api.namedGate2PlusIn(gkr.Mul2, i1, i2) } - -// Println writes to the standard output. -// instance determines which values are chosen for gkr.Variable input. -func (api *API) Println(instance int, a ...any) { - isVar := make([]bool, len(a)) - vals := make([]any, len(a)) - for i := range a { - v, ok := a[i].(gkr.Variable) - isVar[i] = ok - if ok { - vals[i] = uint32(v) - } else { - vals[i] = a[i] - } - } - - api.toStore.Prints = append(api.toStore.Prints, gkrinfo.PrintInfo{ - Values: vals, - Instance: uint32(instance), - IsGkrVar: isVar, - }) -} diff --git a/std/gkrapi/api_test.go b/std/gkrapi/api_test.go index 5823c687dd..5895c66597 100644 --- a/std/gkrapi/api_test.go +++ b/std/gkrapi/api_test.go @@ -1,7 +1,6 @@ package gkrapi import ( - "bytes" "fmt" "hash" "math/big" @@ -16,10 +15,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" gcHash "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark/backend/groth16" - "github.com/consensys/gnark/constraint/solver/gkrgates" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/internal/gkr/gkrinfo" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/gkrapi/gkr" stdHash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/test" @@ -27,7 +25,7 @@ import ( "github.com/stretchr/testify/require" ) -// compressThreshold --> if linear expressions are larger than this, the frontend will introduce +// compressThreshold → if linear expressions are larger than this, the frontend will introduce // intermediate constraints. The lower this number is, the faster compile time should be (to a point) // but resulting circuit will have more constraints (slower proving time). const compressThreshold = 1000 @@ -39,23 +37,24 @@ type doubleNoDependencyCircuit struct { func (c *doubleNoDependencyCircuit) Define(api frontend.API) error { gkrApi := New() - var x gkr.Variable - var err error - if x, err = gkrApi.Import(c.X); err != nil { - return err - } + x := gkrApi.NewInput() z := gkrApi.Add(x, x) - var solution Solution - if solution, err = gkrApi.Solve(api); err != nil { + + gkrCircuit, err := gkrApi.Compile(api, c.hashName) + if err != nil { return err } - Z := solution.Export(z) - for i := range Z { - api.AssertIsEqual(Z[i], api.Mul(2, c.X[i])) + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.X { + instanceIn[x] = c.X[i] + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + api.AssertIsEqual(instanceOut[z], api.Mul(2, c.X[i])) } - - return solution.Verify(c.hashName) + return nil } func TestDoubleNoDependencyCircuit(t *testing.T) { @@ -87,23 +86,24 @@ type sqNoDependencyCircuit struct { func (c *sqNoDependencyCircuit) Define(api frontend.API) error { gkrApi := New() - var x gkr.Variable - var err error - if x, err = gkrApi.Import(c.X); err != nil { - return err - } + x := gkrApi.NewInput() z := gkrApi.Mul(x, x) - var solution Solution - if solution, err = gkrApi.Solve(api); err != nil { + + gkrCircuit, err := gkrApi.Compile(api, c.hashName) + if err != nil { return err } - Z := solution.Export(z) - for i := range Z { - api.AssertIsEqual(Z[i], api.Mul(c.X[i], c.X[i])) + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.X { + instanceIn[x] = c.X[i] + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + api.AssertIsEqual(instanceOut[z], api.Mul(c.X[i], c.X[i])) } - - return solution.Verify(c.hashName) + return nil } func TestSqNoDependencyCircuit(t *testing.T) { @@ -134,29 +134,26 @@ type mulNoDependencyCircuit struct { func (c *mulNoDependencyCircuit) Define(api frontend.API) error { gkrApi := New() - var x, y gkr.Variable - var err error - if x, err = gkrApi.Import(c.X); err != nil { - return err - } - if y, err = gkrApi.Import(c.Y); err != nil { - return err - } - gkrApi.Println(0, "values of x and y in instance number", 0, x, y) - + x := gkrApi.NewInput() + y := gkrApi.NewInput() z := gkrApi.Mul(x, y) - gkrApi.Println(1, "value of z in instance number", 1, z) - var solution Solution - if solution, err = gkrApi.Solve(api); err != nil { + + gkrCircuit, err := gkrApi.Compile(api, c.hashName) + if err != nil { return err } - Z := solution.Export(z) + instanceIn := make(map[gkr.Variable]frontend.Variable) for i := range c.X { - api.AssertIsEqual(Z[i], api.Mul(c.X[i], c.Y[i])) + instanceIn[x] = c.X[i] + instanceIn[y] = c.Y[i] + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + api.AssertIsEqual(instanceOut[z], api.Mul(c.Y[i], c.X[i])) } - - return solution.Verify(c.hashName) + return nil } func TestMulNoDependency(t *testing.T) { @@ -190,91 +187,71 @@ func TestMulNoDependency(t *testing.T) { } type mulWithDependencyCircuit struct { - XLast frontend.Variable + XFirst frontend.Variable Y []frontend.Variable hashName string } func (c *mulWithDependencyCircuit) Define(api frontend.API) error { gkrApi := New() - var x, y gkr.Variable - var err error - X := make([]frontend.Variable, len(c.Y)) - X[len(c.Y)-1] = c.XLast - if x, err = gkrApi.Import(X); err != nil { - return err - } - if y, err = gkrApi.Import(c.Y); err != nil { + x := gkrApi.NewInput() // x is the state variable + y := gkrApi.NewInput() + z := gkrApi.Mul(x, y) + + gkrCircuit, err := gkrApi.Compile(api, c.hashName) + if err != nil { return err } - z := gkrApi.Mul(x, y) + state := c.XFirst + instanceIn := make(map[gkr.Variable]frontend.Variable) - for i := len(X) - 1; i > 0; i-- { - gkrApi.Series(x, z, i-1, i) - } + for i := range c.Y { + instanceIn[x] = state + instanceIn[y] = c.Y[i] - var solution Solution - if solution, err = gkrApi.Solve(api); err != nil { - return err - } - X = solution.Export(x) - Y := solution.Export(y) - Z := solution.Export(z) + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } - lastI := len(X) - 1 - api.AssertIsEqual(Z[lastI], api.Mul(c.XLast, Y[lastI])) - for i := 0; i < lastI; i++ { - api.AssertIsEqual(Z[i], api.Mul(Z[i+1], Y[i])) + api.AssertIsEqual(instanceOut[z], api.Mul(state, c.Y[i])) + state = instanceOut[z] // update state for the next iteration } - return solution.Verify(c.hashName) + return nil } func TestSolveMulWithDependency(t *testing.T) { assert := test.NewAssert(t) assignment := mulWithDependencyCircuit{ - XLast: 1, - Y: []frontend.Variable{3, 2}, + XFirst: 1, + Y: []frontend.Variable{3, 2}, } circuit := mulWithDependencyCircuit{Y: make([]frontend.Variable, len(assignment.Y)), hashName: "-20"} assert.CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BN254)) } func TestApiMul(t *testing.T) { - var ( - x gkr.Variable - y gkr.Variable - z gkr.Variable - err error - ) api := New() - x, err = api.Import([]frontend.Variable{nil, nil}) - require.NoError(t, err) - y, err = api.Import([]frontend.Variable{nil, nil}) - require.NoError(t, err) - z = api.Mul(x, y) + x := api.NewInput() + y := api.NewInput() + z := api.Mul(x, y) assertSliceEqual(t, api.toStore.Circuit[z].Inputs, []int{int(x), int(y)}) // TODO: Find out why assert.Equal gives false positives ( []*Wire{x,x} as second argument passes when it shouldn't ) } func BenchmarkMiMCMerkleTree(b *testing.B) { - depth := 14 - bottom := make([]frontend.Variable, 1<= 0; d-- { - for i := 0; i < 1< 1 { + nextLayer := curLayer[:len(curLayer)/2] -func init() { - registerMiMCGate() + for i := range nextLayer { + instanceIn[x] = curLayer[2*i] + instanceIn[y] = curLayer[2*i+1] + + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + nextLayer[i] = instanceOut[z] // store the result of the hash + } + + curLayer = nextLayer + } + return nil } -func registerMiMCGate() { - // register mimc gate - panicIfError(gkrgates.Register(func(api gkr.GateAPI, input ...frontend.Variable) frontend.Variable { - mimcSnarkTotalCalls++ +func mimcGate(api gkr.GateAPI, input ...frontend.Variable) frontend.Variable { + mimcSnarkTotalCalls++ - if len(input) != 2 { - panic("mimc has fan-in 2") - } - sum := api.Add(input[0], input[1] /*, m.Ark*/) + if len(input) != 2 { + panic("mimc has fan-in 2") + } + sum := api.Add(input[0], input[1] /*, m.Ark*/) - sumCubed := api.Mul(sum, sum, sum) // sum^3 - return api.Mul(sumCubed, sumCubed, sum) - }, 2, gkrgates.WithDegree(7), gkrgates.WithName("MIMC"))) + sumCubed := api.Mul(sum, sum, sum) // sum³ + return api.Mul(sumCubed, sumCubed, sum) } type constPseudoHash int @@ -406,8 +377,6 @@ func (c constPseudoHash) Write(...frontend.Variable) {} func (c constPseudoHash) Reset() {} -var mimcFrTotalCalls = 0 - type mimcNoGkrCircuit struct { X []frontend.Variable Y []frontend.Variable @@ -456,26 +425,36 @@ type mimcNoDepCircuit struct { } func (c *mimcNoDepCircuit) Define(api frontend.API) error { - _gkr := New() - x, err := _gkr.Import(c.X) - if err != nil { - return err + // define the circuit + gkrApi := New() + x := gkrApi.NewInput() + y := gkrApi.NewInput() + + if c.mimcDepth < 1 { + return fmt.Errorf("mimcDepth must be at least 1, got %d", c.mimcDepth) } - var ( - y gkr.Variable - solution Solution - ) - if y, err = _gkr.Import(c.Y); err != nil { + + z := y + for range c.mimcDepth { + z = gkrApi.Gate(mimcGate, x, z) + } + + gkrCircuit, err := gkrApi.Compile(api, c.hashName) + if err != nil { return err } - z := _gkr.NamedGate("MIMC", x, y) + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.X { + instanceIn[x] = c.X[i] + instanceIn[y] = c.Y[i] - if solution, err = _gkr.Solve(api); err != nil { - return err + _, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } } - Z := solution.Export(z) - return solution.Verify(c.hashName, Z...) + return nil } func mimcNoDepCircuits(mimcDepth, nbInstances int, hashName string) (circuit, assignment frontend.Circuit) { @@ -557,64 +536,6 @@ func mimcNoGkrCircuits(mimcDepth, nbInstances int) (circuit, assignment frontend return } -func TestSolveInTestEngine(t *testing.T) { - assignment := testSolveInTestEngineCircuit{ - X: []frontend.Variable{2, 3, 4, 5, 6, 7, 8, 9}, - } - circuit := testSolveInTestEngineCircuit{ - X: make([]frontend.Variable, len(assignment.X)), - } - - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BN254.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS24_315.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS12_381.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS24_317.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BW6_633.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BW6_761.ScalarField())) - require.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS12_377.ScalarField())) -} - -type testSolveInTestEngineCircuit struct { - X []frontend.Variable -} - -func (c *testSolveInTestEngineCircuit) Define(api frontend.API) error { - gkrBn254 := New() - x, err := gkrBn254.Import(c.X) - if err != nil { - return err - } - Y := make([]frontend.Variable, len(c.X)) - Y[0] = 1 - y, err := gkrBn254.Import(Y) - if err != nil { - return err - } - - z := gkrBn254.Mul(x, y) - - for i := range len(c.X) - 1 { - gkrBn254.Series(y, z, i+1, i) - } - - assignments := gkrBn254.SolveInTestEngine(api) - - product := frontend.Variable(1) - for i := range c.X { - api.AssertIsEqual(assignments[y][i], product) - product = api.Mul(product, c.X[i]) - api.AssertIsEqual(assignments[z][i], product) - } - - return nil -} - -func panicIfError(err error) { - if err != nil { - panic(err) - } -} - func assertSliceEqual[T comparable](t *testing.T, expected, seen []T) { assert.Equal(t, len(expected), len(seen)) for i := range seen { @@ -636,7 +557,7 @@ func (m MiMCCipherGate) Evaluate(api frontend.API, input ...frontend.Variable) f } sum := api.Add(input[0], input[1], m.Ark) - sumCubed := api.Mul(sum, sum, sum) // sum^3 + sumCubed := api.Mul(sum, sum, sum) // sum³ return api.Mul(sumCubed, sumCubed, sum) } @@ -689,46 +610,105 @@ func init() { } } -func ExamplePrintln() { +// pow3Circuit computes x⁴ and also checks the correctness of intermediate value x². +// This is to demonstrate the use of [Circuit.GetValue] and should not be done +// in production code, as it negates the performance benefits of using GKR in the first place. +type pow4Circuit struct { + X []frontend.Variable +} - circuit := &mulNoDependencyCircuit{ - X: make([]frontend.Variable, 2), - Y: make([]frontend.Variable, 2), - hashName: "MIMC", +func (c *pow4Circuit) Define(api frontend.API) error { + gkrApi := New() + x := gkrApi.NewInput() + x2 := gkrApi.Mul(x, x) // x² + x4 := gkrApi.Mul(x2, x2) // x⁴ + + gkrCircuit, err := gkrApi.Compile(api, "MIMC") + if err != nil { + return err + } + + for i := range c.X { + instanceIn := make(map[gkr.Variable]frontend.Variable) + instanceIn[x] = c.X[i] + + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return fmt.Errorf("failed to add instance: %w", err) + } + + api.AssertIsEqual(gkrCircuit.GetValue(x, i), c.X[i]) // x + + v := api.Mul(c.X[i], c.X[i]) // x² + api.AssertIsEqual(gkrCircuit.GetValue(x2, i), v) // x² + + v = api.Mul(v, v) // x⁴ + api.AssertIsEqual(gkrCircuit.GetValue(x4, i), v) // x⁴ + api.AssertIsEqual(instanceOut[x4], v) // x⁴ } - assignment := &mulNoDependencyCircuit{ - X: []frontend.Variable{10, 11}, - Y: []frontend.Variable{12, 13}, + return nil +} + +func TestPow4Circuit_GetValue(t *testing.T) { + assignment := pow4Circuit{ + X: []frontend.Variable{1, 2, 3, 4, 5}, } - field := ecc.BN254.ScalarField() + circuit := pow4Circuit{ + X: make([]frontend.Variable, len(assignment.X)), + } - // with test engine - err := test.IsSolved(circuit, assignment, field) - panicIfError(err) + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} - // with groth16 / serialized CS - firstCs, err := frontend.Compile(field, r1cs.NewBuilder, circuit) - panicIfError(err) +func TestWitnessExtend(t *testing.T) { + circuit := doubleNoDependencyCircuit{X: make([]frontend.Variable, 3), hashName: "-1"} + assignment := doubleNoDependencyCircuit{X: []frontend.Variable{0, 0, 1}} - var bb bytes.Buffer - _, err = firstCs.WriteTo(&bb) - panicIfError(err) - cs := groth16.NewCS(ecc.BN254) - _, err = cs.ReadFrom(&bb) - panicIfError(err) + cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuit) + require.NoError(t, err) - pk, _, err := groth16.Setup(cs) - panicIfError(err) - w, err := frontend.NewWitness(assignment, field) - panicIfError(err) - _, err = groth16.Prove(cs, pk, w) - panicIfError(err) - - // Output: - // values of x and y in instance number 0 10 12 - // value of z in instance number 1 143 - // values of x and y in instance number 0 10 12 - // value of z in instance number 1 143 + witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + require.NoError(t, err) + + _, err = cs.Solve(witness) + require.NoError(t, err) +} + +func TestSingleInstance(t *testing.T) { + circuit := mimcNoDepCircuit{ + X: make([]frontend.Variable, 1), + Y: make([]frontend.Variable, 1), + mimcDepth: 2, + hashName: "MIMC", + } + assignment := mimcNoDepCircuit{ + X: []frontend.Variable{10}, + Y: []frontend.Variable{2}, + } + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} + +func TestNoInstance(t *testing.T) { + var circuit testNoInstanceCircuit + assignment := testNoInstanceCircuit{0} + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} + +type testNoInstanceCircuit struct { + Dummy frontend.Variable // Plonk prover would fail on an empty witness +} + +func (c *testNoInstanceCircuit) Define(api frontend.API) error { + gkrApi := New() + x := gkrApi.NewInput() + y := gkrApi.Mul(x, x) + gkrApi.Mul(x, y) + + gkrApi.Compile(api, "MIMC") + + return nil } diff --git a/std/gkrapi/compile.go b/std/gkrapi/compile.go index 1e7784fb97..6c221e3ce5 100644 --- a/std/gkrapi/compile.go +++ b/std/gkrapi/compile.go @@ -5,16 +5,18 @@ import ( "fmt" "math/bits" - "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint/solver/gkrgates" "github.com/consensys/gnark/frontend" gadget "github.com/consensys/gnark/internal/gkr" + "github.com/consensys/gnark/internal/gkr/gkrhints" "github.com/consensys/gnark/internal/gkr/gkrinfo" "github.com/consensys/gnark/internal/gkr/gkrtypes" "github.com/consensys/gnark/internal/utils" fiatshamir "github.com/consensys/gnark/std/fiat-shamir" "github.com/consensys/gnark/std/gkrapi/gkr" "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/multicommit" ) type circuitDataForSnark struct { @@ -22,18 +24,17 @@ type circuitDataForSnark struct { assignments gkrtypes.WireAssignment } -type Solution struct { - toStore gkrinfo.StoringInfo - assignments gkrtypes.WireAssignment - parentApi frontend.API - permutations gkrinfo.Permutations -} - -func (api *API) nbInstances() int { - if len(api.assignments) == 0 { - return -1 - } - return api.assignments.NbInstances() +type InitialChallengeGetter func() []frontend.Variable + +// Circuit represents a GKR circuit. +type Circuit struct { + toStore gkrinfo.StoringInfo + assignments gkrtypes.WireAssignment + getInitialChallenges InitialChallengeGetter // optional getter for the initial Fiat-Shamir challenge + ins []gkr.Variable + outs []gkr.Variable + api frontend.API // the parent API used for hints + hints *gkrhints.TestEngineHints // hints for the GKR circuit, used for testing purposes } // New creates a new GKR API @@ -41,200 +42,237 @@ func New() *API { return &API{} } -// log2 returns -1 if x is not a power of 2 -func log2(x uint) int { - if bits.OnesCount(x) != 1 { - return -1 - } - return bits.TrailingZeros(x) +// NewInput creates a new input variable. +func (api *API) NewInput() gkr.Variable { + return gkr.Variable(api.toStore.NewInputVariable()) } -// Series like in an electric circuit, binds an input of an instance to an output of another -func (api *API) Series(input, output gkr.Variable, inputInstance, outputInstance int) *API { - if api.assignments[input][inputInstance] != nil { - panic("dependency attempting to override explicit value assignment") - } - api.toStore.Dependencies[input] = - append(api.toStore.Dependencies[input], gkrinfo.InputDependency{ - OutputWire: int(output), - OutputInstance: outputInstance, - InputInstance: inputInstance, - }) - return api -} +type CompileOption func(*Circuit) -// Import creates a new input variable, whose values across all instances are given by assignment. -// If the value in an instance depends on an output of another instance, leave the corresponding index in assignment nil and use Series to specify the dependency. -func (api *API) Import(assignment []frontend.Variable) (gkr.Variable, error) { - nbInstances := len(assignment) - logNbInstances := log2(uint(nbInstances)) - if logNbInstances == -1 { - return -1, errors.New("number of assignments must be a power of 2") +// WithInitialChallenge provides a getter for the initial Fiat-Shamir challenge. +// If not provided, the initial challenge will be a commitment to all the input and output values of the circuit. +func WithInitialChallenge(getInitialChallenge InitialChallengeGetter) CompileOption { + return func(c *Circuit) { + c.getInitialChallenges = getInitialChallenge } - - if currentNbInstances := api.nbInstances(); currentNbInstances != -1 && currentNbInstances != nbInstances { - return -1, errors.New("number of assignments must be consistent across all variables") - } - api.assignments = append(api.assignments, assignment) - return gkr.Variable(api.toStore.NewInputVariable()), nil } -// appendNonNil filters out nil values from src and appends the non-nil values to dst. -// i.e. dst = [0,1], src = [nil, 2, nil, 3] => dst = [0,1,2,3]. -func appendNonNil(dst *[]frontend.Variable, src []frontend.Variable) { - for i := range src { - if src[i] != nil { - *dst = append(*dst, src[i]) - } +// Compile finalizes the GKR circuit. +// From this point on, the circuit cannot be modified. +// But instances can be added to the circuit. +func (api *API) Compile(parentApi frontend.API, fiatshamirHashName string, options ...CompileOption) (*Circuit, error) { + res := Circuit{ + toStore: api.toStore, + assignments: make(gkrtypes.WireAssignment, len(api.toStore.Circuit)), + api: parentApi, } -} -// Solve finalizes the GKR circuit and returns the output variables in the order created -func (api *API) Solve(parentApi frontend.API) (Solution, error) { + res.toStore.HashName = fiatshamirHashName - var p gkrinfo.Permutations var err error - if p, err = api.toStore.Compile(api.assignments.NbInstances()); err != nil { - return Solution{}, err + res.hints, err = gkrhints.NewTestEngineHints(&res.toStore) + if err != nil { + return nil, fmt.Errorf("failed to call GKR hints: %w", err) } - api.assignments.Permute(p) - - nbInstances := api.toStore.NbInstances - circuit := api.toStore.Circuit - solveHintNIn := 0 - solveHintNOut := 0 + for _, opt := range options { + opt(&res) + } - for i := range circuit { - v := &circuit[i] - in, out := v.IsInput(), v.IsOutput() - if in && out { - return Solution{}, fmt.Errorf("unused input (variable #%d)", i) + notOut := make([]bool, len(res.toStore.Circuit)) + for i := range res.toStore.Circuit { + if res.toStore.Circuit[i].IsInput() { + res.ins = append(res.ins, gkr.Variable(i)) + } + for _, inWI := range res.toStore.Circuit[i].Inputs { + notOut[inWI] = true } + } + + if len(res.ins) == len(res.toStore.Circuit) { + return nil, errors.New("circuit has no non-input wires") + } - if in { - solveHintNIn += nbInstances - len(api.toStore.Dependencies[i]) - } else if out { - solveHintNOut += nbInstances + for i := range res.toStore.Circuit { + if !notOut[i] { + res.outs = append(res.outs, gkr.Variable(i)) } } - // arrange inputs wire first, then in the order solved - ins := make([]frontend.Variable, 0, solveHintNIn) - for i := range circuit { - if circuit[i].IsInput() { - appendNonNil(&ins, api.assignments[i]) + parentApi.Compiler().Defer(res.finalize) + + return &res, nil +} + +// AddInstance adds a new instance to the GKR circuit, returning the values of output variables for the instance. +func (c *Circuit) AddInstance(input map[gkr.Variable]frontend.Variable) (map[gkr.Variable]frontend.Variable, error) { + if len(input) != len(c.ins) { + for k := range input { + if k >= gkr.Variable(len(c.toStore.Circuit)) { + return nil, fmt.Errorf("variable %d is out of bounds (max %d)", k, len(c.toStore.Circuit)-1) + } + if !c.toStore.Circuit[k].IsInput() { + return nil, fmt.Errorf("value provided for non-input variable %d", k) + } + } + } + hintIn := make([]frontend.Variable, 1+len(c.ins)) // first input denotes the instance number + hintIn[0] = c.toStore.NbInstances + for hintInI, wI := range c.ins { + if inV, ok := input[wI]; !ok { + return nil, fmt.Errorf("missing entry for input variable %d", wI) + } else { + hintIn[hintInI+1] = inV + c.assignments[wI] = append(c.assignments[wI], inV) } } - solveHintPlaceholder := SolveHintPlaceholder(api.toStore) - outsSerialized, err := parentApi.Compiler().NewHint(solveHintPlaceholder, solveHintNOut, ins...) - api.toStore.SolveHintID = solver.GetHintID(solveHintPlaceholder) + outsSerialized, err := c.api.Compiler().NewHint(c.hints.Solve, len(c.outs), hintIn...) if err != nil { - return Solution{}, err + return nil, fmt.Errorf("failed to call solve hint: %w", err) + } + c.toStore.NbInstances++ + res := make(map[gkr.Variable]frontend.Variable, len(c.outs)) + for i, v := range c.outs { + res[v] = outsSerialized[i] + c.assignments[v] = append(c.assignments[v], outsSerialized[i]) } - for i := range circuit { - if circuit[i].IsOutput() { - api.assignments[i] = outsSerialized[:nbInstances] - outsSerialized = outsSerialized[nbInstances:] + return res, nil +} + +// finalize encodes the verification circuitry for the GKR circuit. +func (c *Circuit) finalize(api frontend.API) error { + if api != c.api { + panic("api mismatch") + } + + // if the circuit is empty or with no instances, there is nothing to do. + if len(c.outs) == 0 || len(c.assignments[0]) == 0 { // wire 0 is always an input wire + return nil + } + + // pad instances to the next power of 2 + nbPaddedInstances := int(ecc.NextPowerOfTwo(uint64(c.toStore.NbInstances))) + // pad instances to the next power of 2 by repeating the last instance + if c.toStore.NbInstances < nbPaddedInstances && c.toStore.NbInstances > 0 { + for _, wI := range c.ins { + c.assignments[wI] = utils.ExtendRepeatLast(c.assignments[wI], nbPaddedInstances) + } + for _, wI := range c.outs { + c.assignments[wI] = utils.ExtendRepeatLast(c.assignments[wI], nbPaddedInstances) } } - for i := range circuit { - for _, dep := range api.toStore.Dependencies[i] { - api.assignments[i][dep.InputInstance] = api.assignments[dep.OutputWire][dep.OutputInstance] + if err := api.(gkrinfo.ConstraintSystem).SetGkrInfo(c.toStore); err != nil { + return err + } + + // if the circuit consists of only one instance, directly solve the circuit + if len(c.assignments[c.ins[0]]) == 1 { + circuit, err := gkrtypes.CircuitInfoToCircuit(c.toStore.Circuit, gkrgates.Get) + if err != nil { + return fmt.Errorf("failed to convert GKR info to circuit: %w", err) } + gateIn := make([]frontend.Variable, circuit.MaxGateNbIn()) + for wI, w := range circuit { + if w.IsInput() { + continue + } + for inI, inWI := range w.Inputs { + gateIn[inI] = c.assignments[inWI][0] // take the first (only) instance + } + res := w.Gate.Evaluate(api, gateIn[:len(w.Inputs)]...) + if w.IsOutput() { + api.AssertIsEqual(res, c.assignments[wI][0]) + } else { + c.assignments[wI] = append(c.assignments[wI], res) + } + } + return nil } - return Solution{ - toStore: api.toStore, - assignments: api.assignments, - parentApi: parentApi, - permutations: p, - }, nil -} + if c.getInitialChallenges != nil { + return c.verify(api, c.getInitialChallenges()) + } -// Export returns the values of an output variable across all instances -func (s Solution) Export(v gkr.Variable) []frontend.Variable { - return utils.Map(s.permutations.SortedInstances, utils.SliceAt(s.assignments[v])) + // default initial challenge is a commitment to all input and output values + insOuts := make([]frontend.Variable, 0, (len(c.ins)+len(c.outs))*len(c.assignments[c.ins[0]])) + for _, in := range c.ins { + insOuts = append(insOuts, c.assignments[in]...) + } + for _, out := range c.outs { + insOuts = append(insOuts, c.assignments[out]...) + } + + multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { + return c.verify(api, []frontend.Variable{commitment}) + }, insOuts...) + + return nil } -// Verify encodes the verification circuitry for the GKR circuit -func (s Solution) Verify(hashName string, initialChallenges ...frontend.Variable) error { +func (c *Circuit) verify(api frontend.API, initialChallenges []frontend.Variable) error { + forSnark, err := newCircuitDataForSnark(utils.FieldToCurve(api.Compiler().Field()), c.toStore, c.assignments) + if err != nil { + return fmt.Errorf("failed to create circuit data for snark: %w", err) + } + + hintIns := make([]frontend.Variable, len(initialChallenges)+1) // hack: adding one of the outputs of the solve hint to ensure "prove" is called after "solve" + firstOutputAssignment := c.assignments[c.outs[0]] + hintIns[0] = firstOutputAssignment[len(firstOutputAssignment)-1] // take the last output of the first output wire + + copy(hintIns[1:], initialChallenges) + var ( - err error proofSerialized []frontend.Variable proof gadget.Proof ) - forSnark := newCircuitDataForSnark(s.toStore, s.assignments) - logNbInstances := log2(uint(s.assignments.NbInstances())) - - hintIns := make([]frontend.Variable, len(initialChallenges)+1) // hack: adding one of the outputs of the solve hint to ensure "prove" is called after "solve" - for i, w := range s.toStore.Circuit { - if w.IsOutput() { - hintIns[0] = s.assignments[i][len(s.assignments[i])-1] - break - } - } - copy(hintIns[1:], initialChallenges) - - proveHintPlaceholder := ProveHintPlaceholder(hashName) - if proofSerialized, err = s.parentApi.Compiler().NewHint( - proveHintPlaceholder, gadget.ProofSize(forSnark.circuit, logNbInstances), hintIns...); err != nil { + if proofSerialized, err = api.Compiler().NewHint( + c.hints.Prove, gadget.ProofSize(forSnark.circuit, bits.TrailingZeros(uint(len(c.assignments[0])))), hintIns...); err != nil { return err } - s.toStore.ProveHintID = solver.GetHintID(proveHintPlaceholder) - forSnarkSorted := utils.MapRange(0, len(s.toStore.Circuit), slicePtrAt(forSnark.circuit)) + forSnarkSorted := utils.SliceOfRefs(forSnark.circuit) if proof, err = gadget.DeserializeProof(forSnarkSorted, proofSerialized); err != nil { return err } var hsh hash.FieldHasher - if hsh, err = hash.GetFieldHasher(hashName, s.parentApi); err != nil { + if hsh, err = hash.GetFieldHasher(c.toStore.HashName, api); err != nil { return err } - s.toStore.HashName = hashName - - err = gadget.Verify(s.parentApi, forSnark.circuit, forSnark.assignments, proof, fiatshamir.WithHash(hsh, initialChallenges...), gadget.WithSortedCircuit(forSnarkSorted)) - if err != nil { - return err - } - - return s.parentApi.(gkrinfo.ConstraintSystem).SetGkrInfo(s.toStore) -} -func slicePtrAt[T any](slice []T) func(int) *T { - return func(i int) *T { - return &slice[i] - } + return gadget.Verify(api, forSnark.circuit, forSnark.assignments, proof, fiatshamir.WithHash(hsh, initialChallenges...), gadget.WithSortedCircuit(forSnarkSorted)) } -func ite[T any](condition bool, ifNot, IfSo T) T { - if condition { - return IfSo +func newCircuitDataForSnark(curve ecc.ID, info gkrinfo.StoringInfo, assignment gkrtypes.WireAssignment) (circuitDataForSnark, error) { + circuit, err := gkrtypes.CircuitInfoToCircuit(info.Circuit, gkrgates.Get) + if err != nil { + return circuitDataForSnark{}, fmt.Errorf("failed to convert GKR info to circuit: %w", err) } - return ifNot -} - -func newCircuitDataForSnark(info gkrinfo.StoringInfo, assignment gkrtypes.WireAssignment) circuitDataForSnark { - circuit := make(gkrtypes.Circuit, len(info.Circuit)) - snarkAssignment := make(gkrtypes.WireAssignment, len(info.Circuit)) for i := range circuit { - w := info.Circuit[i] - circuit[i] = gkrtypes.Wire{ - Gate: gkrgates.Get(ite(w.IsInput(), gkr.GateName(w.Gate), gkr.Identity)), - Inputs: w.Inputs, - NbUniqueOutputs: w.NbUniqueOutputs, + if !circuit[i].Gate.SupportsCurve(curve) { + return circuitDataForSnark{}, fmt.Errorf("gate \"%s\" not usable over curve \"%s\"", info.Circuit[i].Gate, curve) } - snarkAssignment[i] = assignment[i] } + return circuitDataForSnark{ circuit: circuit, - assignments: snarkAssignment, + assignments: assignment, + }, nil +} + +// GetValue is a debugging utility returning the value of variable v at instance i. +// While v can be an input or output variable, GetValue is most useful for querying intermediate values in the circuit. +func (c *Circuit) GetValue(v gkr.Variable, i int) frontend.Variable { + // last input to ensure the solver's work is done before GetAssignment is called + res, err := c.api.Compiler().NewHint(c.hints.GetAssignment, 1, int(v), i, c.assignments[c.outs[0]][i]) + if err != nil { + panic(err) } + return res[0] } diff --git a/std/gkrapi/compile_test.go b/std/gkrapi/compile_test.go deleted file mode 100644 index a0ca992ed4..0000000000 --- a/std/gkrapi/compile_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package gkrapi - -import ( - "testing" - - "github.com/consensys/gnark/internal/gkr/gkrinfo" - "github.com/stretchr/testify/assert" -) - -func TestCompile2Cycles(t *testing.T) { - var d = gkrinfo.StoringInfo{ - Dependencies: [][]gkrinfo.InputDependency{ - nil, - { - { - OutputWire: 0, - OutputInstance: 1, - InputInstance: 0, - }, - }, - }, - Circuit: gkrinfo.Circuit{ - { - Inputs: []int{1}, - }, - { - Inputs: []int{}, - }, - }, - } - - expectedCompiled := gkrinfo.StoringInfo{ - Dependencies: [][]gkrinfo.InputDependency{ - {{ - OutputWire: 1, - OutputInstance: 0, - InputInstance: 1, - }}, - nil, - }, - Circuit: gkrinfo.Circuit{ - { - Inputs: []int{}, - NbUniqueOutputs: 1, - }, - { - Inputs: []int{0}, - }}, - NbInstances: 2, - } - - expectedPermutations := gkrinfo.Permutations{ - SortedInstances: []int{1, 0}, - SortedWires: []int{1, 0}, - InstancesPermutation: []int{1, 0}, - WiresPermutation: []int{1, 0}, - } - - p, err := d.Compile(2) - assert.NoError(t, err) - assert.Equal(t, expectedPermutations, p) - assert.Equal(t, expectedCompiled, d) -} - -func TestCompile3Cycles(t *testing.T) { - var d = gkrinfo.StoringInfo{ - Dependencies: [][]gkrinfo.InputDependency{ - nil, - { - { - OutputWire: 0, - OutputInstance: 2, - InputInstance: 0, - }, - { - OutputWire: 0, - OutputInstance: 1, - InputInstance: 2, - }, - }, - nil, - }, - Circuit: gkrinfo.Circuit{ - { - Inputs: []int{2}, - }, - { - Inputs: []int{}, - }, - { - Inputs: []int{1}, - }, - }, - } - - expectedCompiled := gkrinfo.StoringInfo{ - Dependencies: [][]gkrinfo.InputDependency{ - {{ - OutputWire: 2, - OutputInstance: 0, - InputInstance: 1, - }, { - OutputWire: 2, - OutputInstance: 1, - InputInstance: 2, - }}, - - nil, - nil, - }, - Circuit: gkrinfo.Circuit{ - { - Inputs: []int{}, - NbUniqueOutputs: 1, - }, - { - Inputs: []int{0}, - NbUniqueOutputs: 1, - }, - { - Inputs: []int{1}, - NbUniqueOutputs: 0, - }, - }, - NbInstances: 3, // not allowed if we were actually performing gkr - } - - expectedPermutations := gkrinfo.Permutations{ - SortedInstances: []int{1, 2, 0}, - SortedWires: []int{1, 2, 0}, - InstancesPermutation: []int{2, 0, 1}, - WiresPermutation: []int{2, 0, 1}, - } - - p, err := d.Compile(3) - assert.NoError(t, err) - assert.Equal(t, expectedPermutations, p) - assert.Equal(t, expectedCompiled, d) -} diff --git a/std/gkrapi/example_test.go b/std/gkrapi/example_test.go index 29244e6aab..c3acd473c6 100644 --- a/std/gkrapi/example_test.go +++ b/std/gkrapi/example_test.go @@ -16,10 +16,9 @@ import ( func Example() { // This example computes the double of multiple BLS12-377 G1 points, which can be computed natively over BW6-761. - // This means that the imported fr and fp packages are the same, being from BW6-761 and BLS12-377 respectively. TODO @Tabaie delete if no longer have fp imported - // It is based on the function DoubleAssign() of type G1Jac in gnark-crypto v0.17.0. + // The two curves form a "cycle", meaning the scalar field of one is the base field of the other. + // The implementation is based on the function DoubleAssign() of type G1Jac in gnark-crypto v0.17.0. // github.com/consensys/gnark-crypto/ecc/bls12-377 - const fsHashName = "MIMC" // register the gates: Doing so is not needed here because // the proof is being computed in the same session as the @@ -63,13 +62,12 @@ func Example() { } circuit := exampleCircuit{ - X: make([]frontend.Variable, nbInstances), - Y: make([]frontend.Variable, nbInstances), - Z: make([]frontend.Variable, nbInstances), - XOut: make([]frontend.Variable, nbInstances), - YOut: make([]frontend.Variable, nbInstances), - ZOut: make([]frontend.Variable, nbInstances), - fsHashName: fsHashName, + X: make([]frontend.Variable, nbInstances), + Y: make([]frontend.Variable, nbInstances), + Z: make([]frontend.Variable, nbInstances), + XOut: make([]frontend.Variable, nbInstances), + YOut: make([]frontend.Variable, nbInstances), + ZOut: make([]frontend.Variable, nbInstances), } assertNoError(test.IsSolved(&circuit, &assignment, ecc.BW6_761.ScalarField())) @@ -80,7 +78,6 @@ func Example() { type exampleCircuit struct { X, Y, Z []frontend.Variable // Jacobian coordinates for each point (input) XOut, YOut, ZOut []frontend.Variable // Jacobian coordinates for the double of each point (expected output) - fsHashName string // name of the hash function used for Fiat-Shamir in the GKR verifier } func (c *exampleCircuit) Define(api frontend.API) error { @@ -90,21 +87,10 @@ func (c *exampleCircuit) Define(api frontend.API) error { gkrApi := gkrapi.New() - // create GKR circuit variables based on the given assignments - X, err := gkrApi.Import(c.X) - if err != nil { - return err - } - - Y, err := gkrApi.Import(c.Y) - if err != nil { - return err - } - - Z, err := gkrApi.Import(c.Z) - if err != nil { - return err - } + // create the GKR circuit + X := gkrApi.NewInput() + Y := gkrApi.NewInput() + Z := gkrApi.NewInput() XX := gkrApi.Gate(squareGate, X) // 405: XX.Square(&p.X) YY := gkrApi.Gate(squareGate, Y) // 406: YY.Square(&p.Y) @@ -116,45 +102,34 @@ func (c *exampleCircuit) Define(api frontend.API) error { // 414: M.Double(&XX).Add(&M, &XX) // Note (but don't explicitly compute) that M = 3XX - Z = gkrApi.Gate(zGate, Z, Y, YY, ZZ) // 415 - 418 - X = gkrApi.Gate(xGate, XX, S) // 419-422 - Y = gkrApi.Gate(yGate, S, X, XX, YYYY) // 423 - 426 + ZOut := gkrApi.Gate(zGate, Z, Y, YY, ZZ) // 415 - 418 + XOut := gkrApi.Gate(xGate, XX, S) // 419-422 + YOut := gkrApi.Gate(yGate, S, XOut, XX, YYYY) // 423 - 426 - // have to duplicate X for it to be considered an output variable - X = gkrApi.NamedGate(gkr.Identity, X) + // have to duplicate X for it to be considered an output variable; this is an implementation detail and will be fixed in the future [https://github.com/Consensys/gnark/issues/1452] + XOut = gkrApi.NamedGate(gkr.Identity, XOut) - // solve and prove the circuit - solution, err := gkrApi.Solve(api) + gkrCircuit, err := gkrApi.Compile(api, "MIMC") if err != nil { return err } - // check the output - - XOut := solution.Export(X) - YOut := solution.Export(Y) - ZOut := solution.Export(Z) - for i := range XOut { - api.AssertIsEqual(XOut[i], c.XOut[i]) - api.AssertIsEqual(YOut[i], c.YOut[i]) - api.AssertIsEqual(ZOut[i], c.ZOut[i]) - } - - challenges := make([]frontend.Variable, 0, len(c.X)*6) - challenges = append(challenges, XOut...) - challenges = append(challenges, YOut...) - challenges = append(challenges, ZOut...) - challenges = append(challenges, c.X...) - challenges = append(challenges, c.Y...) - challenges = append(challenges, c.Z...) - - challenge, err := api.(frontend.Committer).Commit(challenges...) - if err != nil { - return err + // add input and check output for correctness + instanceIn := make(map[gkr.Variable]frontend.Variable) + for i := range c.X { + instanceIn[X] = c.X[i] + instanceIn[Y] = c.Y[i] + instanceIn[Z] = c.Z[i] + + instanceOut, err := gkrCircuit.AddInstance(instanceIn) + if err != nil { + return err + } + api.AssertIsEqual(instanceOut[XOut], c.XOut[i]) + api.AssertIsEqual(instanceOut[YOut], c.YOut[i]) + api.AssertIsEqual(instanceOut[ZOut], c.ZOut[i]) } - - // verify the proof - return solution.Verify(c.fsHashName, challenge) + return nil } // custom gates diff --git a/std/gkrapi/hints.go b/std/gkrapi/hints.go deleted file mode 100644 index 577a4d6ed8..0000000000 --- a/std/gkrapi/hints.go +++ /dev/null @@ -1,137 +0,0 @@ -package gkrapi - -import ( - "errors" - "fmt" - "math/big" - "strings" - - "github.com/consensys/gnark-crypto/ecc" - gcHash "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark/constraint/solver/gkrgates" - bls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" - bls12381 "github.com/consensys/gnark/internal/gkr/bls12-381" - bls24315 "github.com/consensys/gnark/internal/gkr/bls24-315" - bls24317 "github.com/consensys/gnark/internal/gkr/bls24-317" - bn254 "github.com/consensys/gnark/internal/gkr/bn254" - bw6633 "github.com/consensys/gnark/internal/gkr/bw6-633" - bw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" - "github.com/consensys/gnark/internal/gkr/gkrinfo" - "github.com/consensys/gnark/internal/gkr/gkrtypes" - "github.com/consensys/gnark/internal/utils" -) - -var testEngineGkrSolvingData = make(map[string]any) - -func modKey(mod *big.Int) string { - return mod.Text(32) -} - -func SolveHintPlaceholder(gkrInfo gkrinfo.StoringInfo) solver.Hint { - return func(mod *big.Int, ins []*big.Int, outs []*big.Int) error { - - solvingInfo, err := gkrtypes.StoringToSolvingInfo(gkrInfo, gkrgates.Get) - if err != nil { - return err - } - - // TODO @Tabaie autogenerate this or decide not to - if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { - var data bls12377.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bls12377.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { - var data bls12381.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bls12381.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { - var data bls24315.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bls24315.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { - var data bls24317.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bls24317.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BN254.ScalarField()) == 0 { - var data bn254.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bn254.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { - var data bw6633.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bw6633.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { - var data bw6761.SolvingData - testEngineGkrSolvingData[modKey(mod)] = &data - return bw6761.SolveHint(solvingInfo, &data)(mod, ins, outs) - } - - return errors.New("unsupported modulus") - } -} - -func ProveHintPlaceholder(hashName string) solver.Hint { - return func(mod *big.Int, ins, outs []*big.Int) error { - k := modKey(mod) - data, ok := testEngineGkrSolvingData[k] - if !ok { - return errors.New("solving data not found") - } - delete(testEngineGkrSolvingData, k) - - // TODO @Tabaie autogenerate this or decide not to - if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { - return bls12377.ProveHint(hashName, data.(*bls12377.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { - return bls12381.ProveHint(hashName, data.(*bls12381.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { - return bls24315.ProveHint(hashName, data.(*bls24315.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { - return bls24317.ProveHint(hashName, data.(*bls24317.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BN254.ScalarField()) == 0 { - return bn254.ProveHint(hashName, data.(*bn254.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { - return bw6633.ProveHint(hashName, data.(*bw6633.SolvingData))(mod, ins, outs) - } - if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { - return bw6761.ProveHint(hashName, data.(*bw6761.SolvingData))(mod, ins, outs) - } - - return errors.New("unsupported modulus") - } -} - -func CheckHashHint(hashName string) solver.Hint { - return func(mod *big.Int, ins, outs []*big.Int) error { - if len(ins) != 2 || len(outs) != 1 { - return errors.New("invalid number of inputs/outputs") - } - - toHash := ins[0].Bytes() - expectedHash := ins[1] - - hsh := gcHash.NewHash(fmt.Sprintf("%s_%s", hashName, strings.ToUpper(utils.FieldToCurve(mod).String()))) - hsh.Write(toHash) - hashed := hsh.Sum(nil) - - if hashed := new(big.Int).SetBytes(hashed); hashed.Cmp(expectedHash) != 0 { - return fmt.Errorf("hash mismatch: expected %s, got %s", expectedHash.String(), hashed.String()) - } - - outs[0].SetBytes(hashed) - - return nil - } -} diff --git a/std/gkrapi/testing.go b/std/gkrapi/testing.go deleted file mode 100644 index 17163c0b5a..0000000000 --- a/std/gkrapi/testing.go +++ /dev/null @@ -1,120 +0,0 @@ -package gkrapi - -import ( - "errors" - "fmt" - "sync" - - "github.com/consensys/gnark/constraint/solver/gkrgates" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/utils" - "github.com/consensys/gnark/std/gkrapi/gkr" - stdHash "github.com/consensys/gnark/std/hash" -) - -type solveInTestEngineSettings struct { - hashName string -} - -type SolveInTestEngineOption func(*solveInTestEngineSettings) - -func WithHashName(name string) SolveInTestEngineOption { - return func(s *solveInTestEngineSettings) { - s.hashName = name - } -} - -// SolveInTestEngine solves the defined circuit directly inside the SNARK circuit. This means that the method does not compute the GKR proof of the circuit and does not embed the GKR proof verifier inside a SNARK. -// The output is the values of all variables, across all instances; i.e. indexed variable-first, instance-second. -// This method only works under the test engine and should only be called to debug a GKR circuit, as the GKR prover's errors can be obscure. -func (api *API) SolveInTestEngine(parentApi frontend.API, options ...SolveInTestEngineOption) [][]frontend.Variable { - gateVer, err := gkrgates.NewGateVerifier(utils.FieldToCurve(parentApi.Compiler().Field())) - if err != nil { - panic(err) - } - - var s solveInTestEngineSettings - for _, o := range options { - o(&s) - } - if s.hashName != "" { - // hash something and make sure it gives the same answer both on prover and verifier sides - // TODO @Tabaie If indeed cheap, move this feature to Verify so that it is always run - h, err := stdHash.GetFieldHasher(s.hashName, parentApi) - if err != nil { - panic(err) - } - nbBytes := (parentApi.Compiler().FieldBitLen() + 7) / 8 - toHash := frontend.Variable(0) - for i := range nbBytes { - toHash = parentApi.Add(parentApi.Mul(toHash, 256), i%256) - } - h.Reset() - h.Write(toHash) - hashed := h.Sum() - - hintOut, err := parentApi.Compiler().NewHint(CheckHashHint(s.hashName), 1, toHash, hashed) - if err != nil { - panic(err) - } - parentApi.AssertIsEqual(hintOut[0], hashed) // the hint already checks this - } - - res := make([][]frontend.Variable, len(api.toStore.Circuit)) - var verifiedGates sync.Map - for i, w := range api.toStore.Circuit { - res[i] = make([]frontend.Variable, api.nbInstances()) - copy(res[i], api.assignments[i]) - if len(w.Inputs) == 0 { - continue - } - } - for instanceI := range api.nbInstances() { - for wireI, w := range api.toStore.Circuit { - deps := api.toStore.Dependencies[wireI] - if len(deps) != 0 && len(w.Inputs) != 0 { - panic(fmt.Errorf("non-input wire %d should not have dependencies", wireI)) - } - for _, dep := range deps { - if dep.InputInstance == instanceI { - if dep.OutputInstance >= instanceI { - panic(fmt.Errorf("out of order dependency not yet supported in SolveInTestEngine; (wire %d, instance %d) depends on (wire %d, instance %d)", wireI, instanceI, dep.OutputWire, dep.OutputInstance)) - } - if res[wireI][instanceI] != nil { - panic(fmt.Errorf("dependency (wire %d, instance %d) <- (wire %d, instance %d) attempting to override existing value assignment", wireI, instanceI, dep.OutputWire, dep.OutputInstance)) - } - res[wireI][instanceI] = res[dep.OutputWire][dep.OutputInstance] - } - } - - if res[wireI][instanceI] == nil { // no assignment or dependency - if len(w.Inputs) == 0 { - panic(fmt.Errorf("input wire %d, instance %d has no dependency or explicit assignment", wireI, instanceI)) - } - ins := make([]frontend.Variable, len(w.Inputs)) - for i, in := range w.Inputs { - ins[i] = res[in][instanceI] - } - gate := gkrgates.Get(gkr.GateName(w.Gate)) - if gate == nil && !w.IsInput() { - panic(fmt.Errorf("gate %s not found", w.Gate)) - } - if _, ok := verifiedGates.Load(w.Gate); !ok { - verifiedGates.Store(w.Gate, struct{}{}) - - err = errors.Join( - gateVer.VerifyDegree(gate), - gateVer.VerifySolvability(gate), - ) - if err != nil { - panic(fmt.Errorf("gate %s: %w", w.Gate, err)) - } - } - if gate != nil { - res[wireI][instanceI] = gate.Evaluate(parentApi, ins...) - } - } - } - } - return res -} diff --git a/std/lookup/logderivlookup/logderivlookup.go b/std/lookup/logderivlookup/logderivlookup.go index 63f2bc694d..dbeb042762 100644 --- a/std/lookup/logderivlookup/logderivlookup.go +++ b/std/lookup/logderivlookup/logderivlookup.go @@ -1,4 +1,4 @@ -// Package logderiv implements append-only lookups using log-derivative +// Package logderivlookup implements append-only lookups using log-derivative // argument. // // The lookup is based on log-derivative argument as described in [logderivarg]. diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go index 218d252eba..bf72ab68e7 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -1,23 +1,16 @@ package gkr_poseidon2 import ( - "errors" "fmt" - "math/big" "sync" - "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/constraint/solver/gkrgates" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" - "github.com/consensys/gnark/std/hash" - _ "github.com/consensys/gnark/std/hash/mimc" // to ensure mimc is registered "github.com/consensys/gnark-crypto/ecc" - frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" poseidon2Bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" + "github.com/consensys/gnark/frontend" ) // extKeyGate applies the external matrix mul, then adds the round key @@ -45,7 +38,7 @@ func pow4Gate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { // pow4TimesGate computes a, b -> a⁴ * b func pow4TimesGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { - panic("expected 1 input") + panic("expected 2 input") } y := api.Mul(x[0], x[0]) y = api.Mul(y, y) @@ -115,63 +108,72 @@ func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Add(api.Mul(x[0], 2), x[1], x[2]) } -type GkrCompressions struct { - api frontend.API - ins1 []frontend.Variable - ins2 []frontend.Variable - outs []frontend.Variable +type GkrCompressor struct { + api frontend.API + gkrCircuit *gkrapi.Circuit + in1, in2, out gkr.Variable } -// NewGkrCompressions returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) +// NewGkrCompressor returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) // which consists of a permutation along with the input fed forward. // The correctness of the compression functions is proven using GKR. -// Note that the solver will need the function RegisterGkrSolverOptions to be called with the desired curves -func NewGkrCompressions(api frontend.API) *GkrCompressions { - res := GkrCompressions{ - api: api, +// Note that the solver will need the function RegisterGkrGates to be called with the desired curves +func NewGkrCompressor(api frontend.API) *GkrCompressor { + if api.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) != 0 { + panic("currently only BLS12-377 is supported") + } + gkrApi, in1, in2, out, err := defineCircuitBls12377() + if err != nil { + panic(fmt.Sprintf("failed to define GKR circuit: %v", err)) + } + gkrCircuit, err := gkrApi.Compile(api, "MIMC") + if err != nil { + panic(fmt.Sprintf("failed to compile GKR circuit: %v", err)) + } + return &GkrCompressor{ + api: api, + gkrCircuit: gkrCircuit, + in1: in1, + in2: in2, + out: out, } - api.Compiler().Defer(res.finalize) - return &res } -func (p *GkrCompressions) Compress(a, b frontend.Variable) frontend.Variable { - s, err := p.api.Compiler().NewHint(permuteHint, 1, a, b) +func (p *GkrCompressor) Compress(a, b frontend.Variable) frontend.Variable { + outs, err := p.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{p.in1: a, p.in2: b}) if err != nil { panic(err) } - p.ins1 = append(p.ins1, a) - p.ins2 = append(p.ins2, b) - p.outs = append(p.outs, s[0]) - return s[0] + + return outs[p.out] } -// defineCircuit defines the GKR circuit for the Poseidon2 permutation over BLS12-377 +// defineCircuitBls12377 defines the GKR circuit for the Poseidon2 permutation over BLS12-377 // insLeft and insRight are the inputs to the permutation // they must be padded to a power of 2 -func defineCircuit(insLeft, insRight []frontend.Variable) (*gkrapi.API, gkr.Variable, error) { +func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, err error) { // variable indexes const ( xI = iota yI ) + if err = registerGatesBls12377(); err != nil { + return + } + // poseidon2 parameters gateNamer := newRoundGateNamer(poseidon2Bls12377.GetDefaultParameters()) rF := poseidon2Bls12377.GetDefaultParameters().NbFullRounds rP := poseidon2Bls12377.GetDefaultParameters().NbPartialRounds halfRf := rF / 2 - gkrApi := gkrapi.New() + gkrApi = gkrapi.New() - x, err := gkrApi.Import(insLeft) - if err != nil { - return nil, -1, err - } - y, err := gkrApi.Import(insRight) - y0 := y // save to feed forward at the end - if err != nil { - return nil, -1, err - } + x := gkrApi.NewInput() + y := gkrApi.NewInput() + + in1, in2 = x, y // save to feed forward at the end // *** helper functions to register and apply gates *** @@ -240,80 +242,9 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkrapi.API, gkr.Vari } // apply the external matrix one last time to obtain the final value of y - y = gkrApi.NamedGate(gateNamer.linear(yI, rP+rF), y, x, y0) - - return gkrApi, y, nil -} - -func (p *GkrCompressions) finalize(api frontend.API) error { - if p.api != api { - panic("unexpected API") - } - - // register gates - registerGkrSolverOptions(api) - - // pad instances into a power of 2 - // TODO @Tabaie the GKR API to do this automatically? - ins1Padded := make([]frontend.Variable, ecc.NextPowerOfTwo(uint64(len(p.ins1)))) - ins2Padded := make([]frontend.Variable, len(ins1Padded)) - copy(ins1Padded, p.ins1) - copy(ins2Padded, p.ins2) - for i := len(p.ins1); i < len(ins1Padded); i++ { - ins1Padded[i] = 0 - ins2Padded[i] = 0 - } + out = gkrApi.NamedGate(gateNamer.linear(yI, rP+rF), y, x, in2) - gkrApi, y, err := defineCircuit(ins1Padded, ins2Padded) - if err != nil { - return err - } - - // connect to output - // TODO can we save 1 constraint per instance by giving the desired outputs to the gkr api? - solution, err := gkrApi.Solve(api) - if err != nil { - return err - } - yVals := solution.Export(y) - for i := range p.outs { - api.AssertIsEqual(yVals[i], p.outs[i]) - } - - // verify GKR proof - allVals := make([]frontend.Variable, 0, 3*len(p.ins1)) - allVals = append(allVals, p.ins1...) - allVals = append(allVals, p.ins2...) - allVals = append(allVals, p.outs...) - challenge, err := p.api.(frontend.Committer).Commit(allVals...) - if err != nil { - return err - } - return solution.Verify(hash.MIMC.String(), challenge) -} - -// registerGkrSolverOptions is a wrapper for RegisterGkrSolverOptions -// that performs the registration for the curve associated with api. -func registerGkrSolverOptions(api frontend.API) { - RegisterGkrSolverOptions(utils.FieldToCurve(api.Compiler().Field())) -} - -func permuteHint(m *big.Int, ins, outs []*big.Int) error { - if m.Cmp(ecc.BLS12_377.ScalarField()) != 0 { - return errors.New("only bls12-377 supported") - } - if len(ins) != 2 || len(outs) != 1 { - return errors.New("expected 2 inputs and 1 output") - } - var x [2]frBls12377.Element - x[0].SetBigInt(ins[0]) - x[1].SetBigInt(ins[1]) - y0 := x[1] - - err := bls12377Permutation().Permutation(x[:]) - x[1].Add(&x[1], &y0) // feed forward - x[1].BigInt(outs[0]) - return err + return } var bls12377Permutation = sync.OnceValue(func() *poseidon2Bls12377.Permutation { @@ -321,16 +252,15 @@ var bls12377Permutation = sync.OnceValue(func() *poseidon2Bls12377.Permutation { return poseidon2Bls12377.NewPermutation(2, params.NbFullRounds, params.NbPartialRounds) // TODO @Tabaie add NewDefaultPermutation to gnark-crypto }) -// RegisterGkrSolverOptions registers the GKR gates corresponding to the given curves for the solver -func RegisterGkrSolverOptions(curves ...ecc.ID) { +// RegisterGkrGates registers the GKR gates corresponding to the given curves for the solver +func RegisterGkrGates(curves ...ecc.ID) { if len(curves) == 0 { panic("expected at least one curve") } - solver.RegisterHint(permuteHint) for _, curve := range curves { switch curve { case ecc.BLS12_377: - if err := registerGkrGatesBls12377(); err != nil { + if err := registerGatesBls12377(); err != nil { panic(err) } default: @@ -339,7 +269,7 @@ func RegisterGkrSolverOptions(curves ...ecc.ID) { } } -func registerGkrGatesBls12377() error { +func registerGatesBls12377() error { const ( x = iota y @@ -349,29 +279,31 @@ func registerGkrGatesBls12377() error { halfRf := p.NbFullRounds / 2 gateNames := newRoundGateNamer(p) - if err := gkrgates.Register(pow2Gate, 1, gkrgates.WithUnverifiedDegree(2), gkrgates.WithNoSolvableVar()); err != nil { + if err := gkrgates.Register(pow2Gate, 1, gkrgates.WithUnverifiedDegree(2), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } - if err := gkrgates.Register(pow4Gate, 1, gkrgates.WithUnverifiedDegree(4), gkrgates.WithNoSolvableVar()); err != nil { + if err := gkrgates.Register(pow4Gate, 1, gkrgates.WithUnverifiedDegree(4), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } - if err := gkrgates.Register(pow2TimesGate, 2, gkrgates.WithUnverifiedDegree(3), gkrgates.WithNoSolvableVar()); err != nil { + if err := gkrgates.Register(pow2TimesGate, 2, gkrgates.WithUnverifiedDegree(3), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } - if err := gkrgates.Register(pow4TimesGate, 2, gkrgates.WithUnverifiedDegree(5), gkrgates.WithNoSolvableVar()); err != nil { + if err := gkrgates.Register(pow4TimesGate, 2, gkrgates.WithUnverifiedDegree(5), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } - if err := gkrgates.Register(intGate2, 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0)); err != nil { + if err := gkrgates.Register(intGate2, 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { return err } extKeySBox := func(round int, varIndex int) error { - return gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round))) + err := gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(ecc.BLS12_377)) + return err } intKeySBox2 := func(round int) error { - return gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round))) + err := gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round)), gkrgates.WithCurves(ecc.BLS12_377)) + return err } fullRound := func(i int) error { @@ -415,7 +347,8 @@ func registerGkrGatesBls12377() error { } } - return gkrgates.Register(extAddGate, 3, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, p.NbPartialRounds+p.NbFullRounds))) + err := gkrgates.Register(extAddGate, 3, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, p.NbPartialRounds+p.NbFullRounds)), gkrgates.WithCurves(ecc.BLS12_377)) + return err } type roundGateNamer string diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 1503054a59..0a230c4381 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -8,12 +8,12 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" + _ "github.com/consensys/gnark/std/hash/all" "github.com/consensys/gnark/test" "github.com/stretchr/testify/require" ) -func TestGkrCompression(t *testing.T) { - const n = 2 +func gkrCompressionCircuits(t require.TestingT, n int) (circuit, assignment testGkrCompressionCircuit) { var k int64 ins := make([][2]frontend.Variable, n) outs := make([]frontend.Variable, n) @@ -32,24 +32,29 @@ func TestGkrCompression(t *testing.T) { k += 2 } - circuit := testGkrPermutationCircuit{ - Ins: ins, - Outs: outs, - } + return testGkrCompressionCircuit{ + Ins: make([][2]frontend.Variable, len(ins)), + Outs: make([]frontend.Variable, len(outs)), + }, testGkrCompressionCircuit{ + Ins: ins, + Outs: outs, + } +} - RegisterGkrSolverOptions(ecc.BLS12_377) +func TestGkrCompression(t *testing.T) { + circuit, assignment := gkrCompressionCircuits(t, 2) - test.NewAssert(t).CheckCircuit(&testGkrPermutationCircuit{Ins: make([][2]frontend.Variable, len(ins)), Outs: make([]frontend.Variable, len(outs))}, test.WithValidAssignment(&circuit), test.WithCurves(ecc.BLS12_377)) + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BLS12_377)) } -type testGkrPermutationCircuit struct { +type testGkrCompressionCircuit struct { Ins [][2]frontend.Variable Outs []frontend.Variable } -func (c *testGkrPermutationCircuit) Define(api frontend.API) error { +func (c *testGkrCompressionCircuit) Define(api frontend.API) error { - pos2 := NewGkrCompressions(api) + pos2 := NewGkrCompressor(api) api.AssertIsEqual(len(c.Ins), len(c.Outs)) for i := range c.Ins { api.AssertIsEqual(c.Outs[i], pos2.Compress(c.Ins[i][0], c.Ins[i][1])) @@ -58,9 +63,9 @@ func (c *testGkrPermutationCircuit) Define(api frontend.API) error { return nil } -func TestGkrPermutationCompiles(t *testing.T) { +func TestGkrCompressionCompiles(t *testing.T) { // just measure the number of constraints - cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &testGkrPermutationCircuit{ + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &testGkrCompressionCircuit{ Ins: make([][2]frontend.Variable, 52000), Outs: make([]frontend.Variable, 52000), })