Skip to content

Commit 6bfec77

Browse files
authored
Reduce allocations in plonky2 Poseidon2 hash.Hash impl (#16)
1 parent bcf4e3c commit 6bfec77

File tree

2 files changed

+141
-32
lines changed

2 files changed

+141
-32
lines changed

hash/poseidon2_goldilocks_plonky2/poseidon2.go

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,34 @@ func HashNToMNoPad(input []g.GoldilocksField, numOutputs int) []g.GoldilocksFiel
9797
}
9898
}
9999

100+
func HashNToMNoPadBytes(input []byte, numOutputs int) []g.GoldilocksField {
101+
if len(input)%g.Bytes != 0 {
102+
panic("input length should be multiple of 8")
103+
}
104+
105+
inputLen := len(input) / g.Bytes
106+
107+
var perm [WIDTH]g.GoldilocksField
108+
for i := 0; i < inputLen; i += RATE {
109+
for j := 0; j < RATE && i+j < inputLen; j++ {
110+
index := (i + j) * g.Bytes
111+
perm[j] = g.FromCanonicalLittleEndianBytesF(input[index : index+g.Bytes])
112+
}
113+
Permute(&perm)
114+
}
115+
116+
outputs := make([]g.GoldilocksField, 0, numOutputs)
117+
for {
118+
for i := 0; i < RATE; i++ {
119+
outputs = append(outputs, perm[i])
120+
if len(outputs) == numOutputs {
121+
return outputs
122+
}
123+
}
124+
Permute(&perm)
125+
}
126+
}
127+
100128
func Permute(input *[WIDTH]g.GoldilocksField) {
101129
externalLinearLayer(input)
102130
fullRounds(input, 0)
@@ -185,37 +213,45 @@ func sboxP(index int, state *[WIDTH]g.GoldilocksField) {
185213
state[index] = g.MulF(tmpSixth, tmp)
186214
}
187215

188-
const BlockSize = g.Bytes // BlockSize size that poseidon consumes
216+
const BlockSize = g.Bytes * WIDTH // BlockSize size that poseidon consumes
189217

190218
type digest struct {
191-
data []g.GoldilocksField
219+
data []byte
220+
len int
192221
}
193222

194223
func NewPoseidon2() hash.Hash {
195224
d := new(digest)
196-
d.Reset()
197225
return d
198226
}
199227

200228
// Reset resets the Hash to its initial state.
201229
func (d *digest) Reset() {
202-
d.data = nil
230+
d.data = d.data[:0]
231+
d.len = 0
203232
}
204233

205234
// Get element by element.
206235
func (d *digest) Write(p []byte) (n int, err error) {
207-
if len(p)%g.Bytes != 0 {
208-
return 0, fmt.Errorf("input bytes len should be multiple of 8 but is %d", len(p))
209-
}
236+
d.data = append(d.data, p...)
237+
d.len += len(p)
210238

211-
gArr := make([]g.GoldilocksField, len(p)/g.Bytes)
212-
for i := 0; i < len(p); i += g.Bytes {
213-
gArr[i/g.Bytes] = g.FromCanonicalLittleEndianBytesF(p[i : i+g.Bytes])
214-
}
215-
d.data = append(d.data, gArr...)
216239
return len(p), nil
217240
}
218241

242+
// Sum appends the current hash to b and returns the resulting slice.
243+
// It does not change the underlying hash state.
244+
func (d *digest) Sum(b []byte) []byte {
245+
h := HashNToMNoPadBytes(d.data, 4)
246+
d.Reset()
247+
248+
for _, elem := range h {
249+
b = append(b, g.ToLittleEndianBytesF(elem)...)
250+
}
251+
252+
return b
253+
}
254+
219255
func (d *digest) Size() int {
220256
return BlockSize
221257
}
@@ -224,11 +260,3 @@ func (d *digest) Size() int {
224260
func (d *digest) BlockSize() int {
225261
return BlockSize
226262
}
227-
228-
// Sum appends the current hash to b and returns the resulting slice.
229-
// It does not change the underlying hash state.
230-
func (d *digest) Sum(b []byte) []byte {
231-
b = append(b, HashNToHashNoPad(d.data).ToLittleEndianBytes()...)
232-
d.data = nil
233-
return b
234-
}

hash/poseidon2_test.go

Lines changed: 93 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"math/rand/v2"
88
"os"
9+
"runtime"
910
"strconv"
1011
"strings"
1112
"testing"
@@ -94,21 +95,60 @@ func TestPoseidon2Bench(t *testing.T) {
9495
t.FailNow()
9596
}
9697

97-
results := make([]g.GoldilocksField, 0, 4*len(inputs))
98-
start := time.Now()
99-
for _, input := range inputs {
100-
res := poseidon2_plonky2.HashNToHashNoPad(input)
101-
results = append(results, res[:]...)
98+
for i := 0; i < 10; i++ {
99+
PrintMemUsage()
100+
101+
results := make([]g.GoldilocksField, 0, 4*len(inputs))
102+
start := time.Now()
103+
for _, input := range inputs {
104+
res := poseidon2_plonky2.HashNToHashNoPad(input)
105+
results = append(results, res[:]...)
106+
}
107+
duration := time.Since(start)
108+
t.Logf("HashNToHashNoPad plonky2 took %s for %d inputs", duration, totalInputs)
109+
110+
sha2 := sha256.New()
111+
for _, res := range results {
112+
sha2.Write(g.ToLittleEndianBytesF(res))
113+
}
114+
t.Logf("Hash: %x\n", sha2.Sum(nil))
102115
}
103-
duration := time.Since(start)
104-
t.Logf("HashNToHashNoPad plonky2 took %s for %d inputs", duration, totalInputs)
116+
}
105117

106-
sha2 := sha256.New()
107-
for _, res := range results {
108-
sha2.Write(g.ToLittleEndianBytesF(res))
118+
func TestPoseidon2HasherBench(t *testing.T) {
119+
inputs, err := readBenchInputsBytes("bench_vector")
120+
totalInputs := len(inputs)
121+
if err != nil {
122+
t.Logf("Error: %v\n", err)
123+
t.FailNow()
109124
}
110-
hash := sha2.Sum(nil)
111-
t.Logf("Hash: %x\n", hash)
125+
126+
hasher := poseidon2_plonky2.NewPoseidon2()
127+
start1 := time.Now()
128+
129+
for i := 0; i < 10; i++ {
130+
PrintMemUsage()
131+
132+
results := make([]byte, 0, 4*8*len(inputs))
133+
start := time.Now()
134+
for _, input := range inputs {
135+
for _, b := range input {
136+
hasher.Write(b)
137+
}
138+
res := hasher.Sum(nil)
139+
hasher.Reset()
140+
results = append(results, res...)
141+
}
142+
duration := time.Since(start)
143+
t.Logf("Hasher plonky2 took %s for %d inputs", duration, totalInputs)
144+
145+
sha2 := sha256.New()
146+
sha2.Write(results)
147+
t.Logf("Hash: %x\n", sha2.Sum(nil))
148+
}
149+
150+
duration := time.Since(start1)
151+
t.Logf("===> Hasher plonky2 took %s", duration)
112152
}
113153

114154
func TestPoseidon2BenchOld(t *testing.T) {
@@ -167,6 +207,37 @@ func readBenchInputs(filename string) ([][]g.GoldilocksField, error) {
167207
return inputs, nil
168208
}
169209

210+
func readBenchInputsBytes(filename string) ([][][]byte, error) {
211+
file, err := os.Open(filename)
212+
if err != nil {
213+
return nil, fmt.Errorf("failed to open file: %v", err)
214+
}
215+
defer file.Close()
216+
217+
scanner := bufio.NewScanner(file)
218+
var inputs [][][]byte
219+
220+
for scanner.Scan() {
221+
line := scanner.Text()
222+
strVals := strings.Split(line, ",")
223+
var input [][]byte
224+
for _, strVal := range strVals {
225+
val, err := strconv.ParseUint(strVal, 10, 64)
226+
if err != nil {
227+
return nil, fmt.Errorf("failed to parse uint64: %v", err)
228+
}
229+
input = append(input, g.ToLittleEndianBytesF(g.GoldilocksField(val)))
230+
}
231+
inputs = append(inputs, input)
232+
}
233+
234+
if err := scanner.Err(); err != nil {
235+
return nil, fmt.Errorf("failed to read file: %v", err)
236+
}
237+
238+
return inputs, nil
239+
}
240+
170241
func readBenchInputsOld(filename string) ([][]g.Element, error) {
171242
file, err := os.Open(filename)
172243
if err != nil {
@@ -197,3 +268,13 @@ func readBenchInputsOld(filename string) ([][]g.Element, error) {
197268

198269
return inputs, nil
199270
}
271+
272+
func PrintMemUsage() {
273+
var m runtime.MemStats
274+
runtime.ReadMemStats(&m)
275+
// For info on each, see: https://golang.org/pkg/runtime/#MemStats
276+
fmt.Printf("Alloc = %v Bytes", m.Alloc)
277+
fmt.Printf("\tTotalAlloc = %v Bytes", m.TotalAlloc)
278+
fmt.Printf("\tSys = %v Bytes", m.Sys)
279+
fmt.Printf("\tNumGC = %v\n", m.NumGC)
280+
}

0 commit comments

Comments
 (0)