Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions hash/poseidon2_goldilocks_plonky2/poseidon2.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,34 @@ func HashNToMNoPad(input []g.GoldilocksField, numOutputs int) []g.GoldilocksFiel
}
}

func HashNToMNoPadBytes(input []byte, numOutputs int) []g.GoldilocksField {
if len(input)%g.Bytes != 0 {
panic("input length should be multiple of 8")
}

inputLen := len(input) / g.Bytes

var perm [WIDTH]g.GoldilocksField
for i := 0; i < inputLen; i += RATE {
for j := 0; j < RATE && i+j < inputLen; j++ {
index := (i + j) * g.Bytes
perm[j] = g.FromCanonicalLittleEndianBytesF(input[index : index+g.Bytes])
}
Permute(&perm)
}

outputs := make([]g.GoldilocksField, 0, numOutputs)
for {
for i := 0; i < RATE; i++ {
outputs = append(outputs, perm[i])
if len(outputs) == numOutputs {
return outputs
}
}
Permute(&perm)
}
}

func Permute(input *[WIDTH]g.GoldilocksField) {
externalLinearLayer(input)
fullRounds(input, 0)
Expand Down Expand Up @@ -185,37 +213,45 @@ func sboxP(index int, state *[WIDTH]g.GoldilocksField) {
state[index] = g.MulF(tmpSixth, tmp)
}

const BlockSize = g.Bytes // BlockSize size that poseidon consumes
const BlockSize = g.Bytes * WIDTH // BlockSize size that poseidon consumes

type digest struct {
data []g.GoldilocksField
data []byte
len int
}

func NewPoseidon2() hash.Hash {
d := new(digest)
d.Reset()
return d
}

// Reset resets the Hash to its initial state.
func (d *digest) Reset() {
d.data = nil
d.data = d.data[:0]
d.len = 0
}

// Get element by element.
func (d *digest) Write(p []byte) (n int, err error) {
if len(p)%g.Bytes != 0 {
return 0, fmt.Errorf("input bytes len should be multiple of 8 but is %d", len(p))
}
d.data = append(d.data, p...)
d.len += len(p)

gArr := make([]g.GoldilocksField, len(p)/g.Bytes)
for i := 0; i < len(p); i += g.Bytes {
gArr[i/g.Bytes] = g.FromCanonicalLittleEndianBytesF(p[i : i+g.Bytes])
}
d.data = append(d.data, gArr...)
return len(p), nil
}

// Sum appends the current hash to b and returns the resulting slice.
// It does not change the underlying hash state.
func (d *digest) Sum(b []byte) []byte {
h := HashNToMNoPadBytes(d.data, 4)
d.Reset()

for _, elem := range h {
b = append(b, g.ToLittleEndianBytesF(elem)...)
}

return b
}

func (d *digest) Size() int {
return BlockSize
}
Expand All @@ -224,11 +260,3 @@ func (d *digest) Size() int {
func (d *digest) BlockSize() int {
return BlockSize
}

// Sum appends the current hash to b and returns the resulting slice.
// It does not change the underlying hash state.
func (d *digest) Sum(b []byte) []byte {
b = append(b, HashNToHashNoPad(d.data).ToLittleEndianBytes()...)
d.data = nil
return b
}
105 changes: 93 additions & 12 deletions hash/poseidon2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"math/rand/v2"
"os"
"runtime"
"strconv"
"strings"
"testing"
Expand Down Expand Up @@ -94,21 +95,60 @@ func TestPoseidon2Bench(t *testing.T) {
t.FailNow()
}

results := make([]g.GoldilocksField, 0, 4*len(inputs))
start := time.Now()
for _, input := range inputs {
res := poseidon2_plonky2.HashNToHashNoPad(input)
results = append(results, res[:]...)
for i := 0; i < 10; i++ {
PrintMemUsage()

results := make([]g.GoldilocksField, 0, 4*len(inputs))
start := time.Now()
for _, input := range inputs {
res := poseidon2_plonky2.HashNToHashNoPad(input)
results = append(results, res[:]...)
}
duration := time.Since(start)
t.Logf("HashNToHashNoPad plonky2 took %s for %d inputs", duration, totalInputs)

sha2 := sha256.New()
for _, res := range results {
sha2.Write(g.ToLittleEndianBytesF(res))
}
t.Logf("Hash: %x\n", sha2.Sum(nil))
}
duration := time.Since(start)
t.Logf("HashNToHashNoPad plonky2 took %s for %d inputs", duration, totalInputs)
}

sha2 := sha256.New()
for _, res := range results {
sha2.Write(g.ToLittleEndianBytesF(res))
func TestPoseidon2HasherBench(t *testing.T) {
inputs, err := readBenchInputsBytes("bench_vector")
totalInputs := len(inputs)
if err != nil {
t.Logf("Error: %v\n", err)
t.FailNow()
}
hash := sha2.Sum(nil)
t.Logf("Hash: %x\n", hash)

hasher := poseidon2_plonky2.NewPoseidon2()
start1 := time.Now()

for i := 0; i < 10; i++ {
PrintMemUsage()

results := make([]byte, 0, 4*8*len(inputs))
start := time.Now()
for _, input := range inputs {
for _, b := range input {
hasher.Write(b)
}
res := hasher.Sum(nil)
hasher.Reset()
results = append(results, res...)
}
duration := time.Since(start)
t.Logf("Hasher plonky2 took %s for %d inputs", duration, totalInputs)

sha2 := sha256.New()
sha2.Write(results)
t.Logf("Hash: %x\n", sha2.Sum(nil))
}

duration := time.Since(start1)
t.Logf("===> Hasher plonky2 took %s", duration)
}

func TestPoseidon2BenchOld(t *testing.T) {
Expand Down Expand Up @@ -167,6 +207,37 @@ func readBenchInputs(filename string) ([][]g.GoldilocksField, error) {
return inputs, nil
}

func readBenchInputsBytes(filename string) ([][][]byte, error) {
file, err := os.Open(filename)
if err != nil {
return nil, fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()

scanner := bufio.NewScanner(file)
var inputs [][][]byte

for scanner.Scan() {
line := scanner.Text()
strVals := strings.Split(line, ",")
var input [][]byte
for _, strVal := range strVals {
val, err := strconv.ParseUint(strVal, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse uint64: %v", err)
}
input = append(input, g.ToLittleEndianBytesF(g.GoldilocksField(val)))
}
inputs = append(inputs, input)
}

if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("failed to read file: %v", err)
}

return inputs, nil
}

func readBenchInputsOld(filename string) ([][]g.Element, error) {
file, err := os.Open(filename)
if err != nil {
Expand Down Expand Up @@ -197,3 +268,13 @@ func readBenchInputsOld(filename string) ([][]g.Element, error) {

return inputs, nil
}

func PrintMemUsage() {
var m runtime.MemStats
runtime.ReadMemStats(&m)
// For info on each, see: https://golang.org/pkg/runtime/#MemStats
fmt.Printf("Alloc = %v Bytes", m.Alloc)
fmt.Printf("\tTotalAlloc = %v Bytes", m.TotalAlloc)
fmt.Printf("\tSys = %v Bytes", m.Sys)
fmt.Printf("\tNumGC = %v\n", m.NumGC)
}