Skip to content

Commit 33adbaf

Browse files
liyue201ivokub
andauthored
Feat: implement FixedLengthSum function for sha3 (#1379)
Co-authored-by: Ivo Kubjas <ivo.kubjas@consensys.net>
1 parent 095d87e commit 33adbaf

File tree

4 files changed

+153
-7
lines changed

4 files changed

+153
-7
lines changed

std/hash/sha3/hashes.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ import (
99
// New256 creates a new SHA3-256 hash.
1010
// Its generic security strength is 256 bits against preimage attacks,
1111
// and 128 bits against collision attacks.
12-
func New256(api frontend.API) (hash.BinaryHasher, error) {
12+
func New256(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
1313
uapi, err := uints.New[uints.U64](api)
1414
if err != nil {
1515
return nil, err
1616
}
1717
return &digest{
18+
api: api,
1819
uapi: uapi,
1920
state: newState(),
2021
dsbyte: 0x06,
@@ -26,12 +27,13 @@ func New256(api frontend.API) (hash.BinaryHasher, error) {
2627
// New384 creates a new SHA3-384 hash.
2728
// Its generic security strength is 384 bits against preimage attacks,
2829
// and 192 bits against collision attacks.
29-
func New384(api frontend.API) (hash.BinaryHasher, error) {
30+
func New384(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
3031
uapi, err := uints.New[uints.U64](api)
3132
if err != nil {
3233
return nil, err
3334
}
3435
return &digest{
36+
api: api,
3537
uapi: uapi,
3638
state: newState(),
3739
dsbyte: 0x06,
@@ -43,12 +45,13 @@ func New384(api frontend.API) (hash.BinaryHasher, error) {
4345
// New512 creates a new SHA3-512 hash.
4446
// Its generic security strength is 512 bits against preimage attacks,
4547
// and 256 bits against collision attacks.
46-
func New512(api frontend.API) (hash.BinaryHasher, error) {
48+
func New512(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
4749
uapi, err := uints.New[uints.U64](api)
4850
if err != nil {
4951
return nil, err
5052
}
5153
return &digest{
54+
api: api,
5255
uapi: uapi,
5356
state: newState(),
5457
dsbyte: 0x06,
@@ -61,12 +64,13 @@ func New512(api frontend.API) (hash.BinaryHasher, error) {
6164
//
6265
// Only use this function if you require compatibility with an existing cryptosystem
6366
// that uses non-standard padding. All other users should use New256 instead.
64-
func NewLegacyKeccak256(api frontend.API) (hash.BinaryHasher, error) {
67+
func NewLegacyKeccak256(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
6568
uapi, err := uints.New[uints.U64](api)
6669
if err != nil {
6770
return nil, err
6871
}
6972
return &digest{
73+
api: api,
7074
uapi: uapi,
7175
state: newState(),
7276
dsbyte: 0x01,
@@ -79,12 +83,13 @@ func NewLegacyKeccak256(api frontend.API) (hash.BinaryHasher, error) {
7983
//
8084
// Only use this function if you require compatibility with an existing cryptosystem
8185
// that uses non-standard padding. All other users should use New512 instead.
82-
func NewLegacyKeccak512(api frontend.API) (hash.BinaryHasher, error) {
86+
func NewLegacyKeccak512(api frontend.API) (hash.BinaryFixedLengthHasher, error) {
8387
uapi, err := uints.New[uints.U64](api)
8488
if err != nil {
8589
return nil, err
8690
}
8791
return &digest{
92+
api: api,
8893
uapi: uapi,
8994
state: newState(),
9095
dsbyte: 0x01,

std/hash/sha3/sha3.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
package sha3
22

33
import (
4+
"math/big"
5+
6+
"github.com/consensys/gnark/frontend"
7+
"github.com/consensys/gnark/std/math/cmp"
48
"github.com/consensys/gnark/std/math/uints"
59
"github.com/consensys/gnark/std/permutation/keccakf"
610
)
711

812
type digest struct {
13+
api frontend.API
914
uapi *uints.BinaryField[uints.U64]
1015
state [25]uints.U64 // 1600 bits state: 25 x 64
1116
in []uints.U8 // input to be digested
@@ -27,11 +32,20 @@ func (d *digest) Reset() {
2732

2833
func (d *digest) Sum() []uints.U8 {
2934
padded := d.padding()
35+
3036
blocks := d.composeBlocks(padded)
3137
d.absorbing(blocks)
3238
return d.squeezeBlocks()
3339
}
3440

41+
func (d *digest) FixedLengthSum(length frontend.Variable) []uints.U8 {
42+
padded, numberOfBlocks := d.paddingFixedWidth(length)
43+
44+
blocks := d.composeBlocks(padded)
45+
d.absorbingFixedWidth(blocks, numberOfBlocks)
46+
return d.squeezeBlocks()
47+
}
48+
3549
func (d *digest) padding() []uints.U8 {
3650
padded := make([]uints.U8, len(d.in))
3751
copy(padded[:], d.in[:])
@@ -51,6 +65,34 @@ func (d *digest) padding() []uints.U8 {
5165
return padded
5266
}
5367

68+
func (d *digest) paddingFixedWidth(length frontend.Variable) (padded []uints.U8, numberOfBlocks frontend.Variable) {
69+
numberOfBlocks = frontend.Variable(0)
70+
padded = make([]uints.U8, len(d.in))
71+
copy(padded[:], d.in[:])
72+
padded = append(padded, uints.NewU8Array(make([]uint8, d.rate))...)
73+
74+
for i := 0; i <= len(padded)-d.rate; i++ {
75+
reachEnd := cmp.IsEqual(d.api, i, length)
76+
switch q := d.rate - ((i) % d.rate); q {
77+
case 1:
78+
padded[i].Val = d.api.Select(reachEnd, d.dsbyte^0x80, padded[i].Val)
79+
numberOfBlocks = d.api.Select(reachEnd, (i+1)/d.rate, numberOfBlocks)
80+
case 2:
81+
padded[i].Val = d.api.Select(reachEnd, d.dsbyte, padded[i].Val)
82+
padded[i+1].Val = d.api.Select(reachEnd, 0x80, padded[i+1].Val)
83+
numberOfBlocks = d.api.Select(reachEnd, (i+2)/d.rate, numberOfBlocks)
84+
default:
85+
padded[i].Val = d.api.Select(reachEnd, d.dsbyte, padded[i].Val)
86+
for j := 0; j < q-1; j++ {
87+
padded[i+1+j].Val = d.api.Select(reachEnd, 0, padded[i+1+j].Val)
88+
}
89+
padded[i+q-1].Val = d.api.Select(reachEnd, 0x80, padded[i+q-1].Val)
90+
numberOfBlocks = d.api.Select(reachEnd, (i+q)/d.rate, numberOfBlocks)
91+
}
92+
}
93+
return padded, numberOfBlocks
94+
}
95+
5496
func (d *digest) composeBlocks(padded []uints.U8) [][]uints.U64 {
5597
blocks := make([][]uints.U64, len(padded)/d.rate)
5698

@@ -76,6 +118,30 @@ func (d *digest) absorbing(blocks [][]uints.U64) {
76118
}
77119
}
78120

121+
func (d *digest) absorbingFixedWidth(blocks [][]uints.U64, nbBlocks frontend.Variable) {
122+
var state [25]uints.U64
123+
var resultState [25]uints.U64
124+
copy(resultState[:], d.state[:])
125+
copy(state[:], d.state[:])
126+
127+
comparator := cmp.NewBoundedComparator(d.api, big.NewInt(int64(len(blocks))), false)
128+
129+
for i, block := range blocks {
130+
for j := range block {
131+
state[j] = d.uapi.Xor(state[j], block[j])
132+
}
133+
state = keccakf.Permute(d.uapi, state)
134+
isInRange := comparator.IsLess(i, nbBlocks)
135+
// only select blocks that are in range
136+
for j := 0; j < 25; j++ {
137+
for k := 0; k < 8; k++ {
138+
resultState[j][k].Val = d.api.Select(isInRange, state[j][k].Val, resultState[j][k].Val)
139+
}
140+
}
141+
}
142+
copy(d.state[:], resultState[:])
143+
}
144+
79145
func (d *digest) squeezeBlocks() (result []uints.U8) {
80146
for i := 0; i < d.outputLen/8; i++ {
81147
result = append(result, d.uapi.UnpackLSB(d.state[i])...)

std/hash/sha3/sha3_test.go

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@ import (
66
"hash"
77
"testing"
88

9+
"golang.org/x/crypto/sha3"
10+
911
"github.com/consensys/gnark-crypto/ecc"
1012
"github.com/consensys/gnark/frontend"
1113
zkhash "github.com/consensys/gnark/std/hash"
1214
"github.com/consensys/gnark/std/math/uints"
1315
"github.com/consensys/gnark/test"
14-
"golang.org/x/crypto/sha3"
1516
)
1617

1718
type testCase struct {
18-
zk func(api frontend.API) (zkhash.BinaryHasher, error)
19+
zk func(api frontend.API) (zkhash.BinaryFixedLengthHasher, error)
1920
native func() hash.Hash
2021
}
2122

@@ -88,3 +89,70 @@ func TestSHA3(t *testing.T) {
8889
}, name)
8990
}
9091
}
92+
93+
type sha3FixedLengthSumCircuit struct {
94+
In []uints.U8
95+
Expected []uints.U8
96+
Length frontend.Variable
97+
hasher string
98+
}
99+
100+
func (c *sha3FixedLengthSumCircuit) Define(api frontend.API) error {
101+
newHasher, ok := testCases[c.hasher]
102+
if !ok {
103+
return fmt.Errorf("hash function unknown: %s", c.hasher)
104+
}
105+
h, err := newHasher.zk(api)
106+
if err != nil {
107+
return err
108+
}
109+
uapi, err := uints.New[uints.U64](api)
110+
if err != nil {
111+
return err
112+
}
113+
h.Write(c.In)
114+
res := h.FixedLengthSum(c.Length)
115+
116+
for i := range c.Expected {
117+
uapi.ByteAssertEq(c.Expected[i], res[i])
118+
}
119+
return nil
120+
}
121+
122+
func TestSHA3FixedLengthSum(t *testing.T) {
123+
assert := test.NewAssert(t)
124+
in := make([]byte, 310)
125+
_, err := rand.Reader.Read(in)
126+
assert.NoError(err)
127+
128+
for name := range testCases {
129+
assert.Run(func(assert *test.Assert) {
130+
name := name
131+
strategy := testCases[name]
132+
for _, length := range []int{0, 1, 31, 32, 33, 135, 136, 137, len(in)} {
133+
assert.Run(func(assert *test.Assert) {
134+
h := strategy.native()
135+
h.Write(in[:length])
136+
expected := h.Sum(nil)
137+
138+
circuit := &sha3FixedLengthSumCircuit{
139+
In: make([]uints.U8, len(in)),
140+
Expected: make([]uints.U8, len(expected)),
141+
Length: 0,
142+
hasher: name,
143+
}
144+
145+
witness := &sha3FixedLengthSumCircuit{
146+
In: uints.NewU8Array(in),
147+
Expected: uints.NewU8Array(expected),
148+
Length: length,
149+
}
150+
151+
if err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()); err != nil {
152+
t.Fatalf("%s: %s", name, err)
153+
}
154+
}, fmt.Sprintf("length=%d", length))
155+
}
156+
}, fmt.Sprintf("hash=%s", name))
157+
}
158+
}

std/math/cmp/generic.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ import (
77
"math/big"
88
)
99

10+
// IsEqual returns 1 if a = b, and returns 0 if a != b. a and b should be
11+
// integers in range [0, P-1], where P is the order of the underlying field used
12+
// by the proof system.
13+
func IsEqual(api frontend.API, a, b frontend.Variable) frontend.Variable {
14+
return api.IsZero(api.Sub(a, b))
15+
}
16+
1017
// IsLess returns 1 if a < b, and returns 0 if a >= b. a and b should be
1118
// integers in range [0, P-1], where P is the order of the underlying field used
1219
// by the proof system.

0 commit comments

Comments
 (0)