Skip to content

Commit 5442047

Browse files
authored
port some optimizations from plonky2 fork (#18)
elliottech/plonky2#10 It is %20-25 faster
1 parent c67de0a commit 5442047

File tree

4 files changed

+213
-28
lines changed

4 files changed

+213
-28
lines changed

field/field_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,80 @@ func TestSquareF(t *testing.T) {
187187
}
188188
}
189189

190+
func TestMulAccF(t *testing.T) {
191+
for _, x := range inputs {
192+
for _, y := range inputs {
193+
for _, z := range inputs {
194+
fX := g.GoldilocksField(x)
195+
fY := g.GoldilocksField(y)
196+
fZ := g.GoldilocksField(z)
197+
mulAcc := g.MulAccF(fX, fY, fZ).ToCanonicalUint64()
198+
expected := SumMod(x, MulMod(y, z))
199+
if mulAcc != expected {
200+
t.Fatalf("Expected %d + %d * %d = %d, but got %d", x, y, z, expected, mulAcc)
201+
}
202+
}
203+
}
204+
}
205+
}
206+
207+
func TestReduce128Bit(t *testing.T) {
208+
testValues := []g.UInt128{
209+
{0, 0},
210+
{1, 0},
211+
{math.MaxUint64, math.MaxUint64},
212+
{0, g.ORDER + 1},
213+
{0, g.ORDER - 1},
214+
{0, 1},
215+
{0, math.MaxUint64},
216+
{math.MaxUint64, 0},
217+
}
218+
219+
for _, val := range testValues {
220+
reduced := g.Reduce128Bit(val)
221+
bigVal := new(big.Int).SetUint64(val.Hi)
222+
bigVal.Lsh(bigVal, 64)
223+
bigVal.Add(bigVal, new(big.Int).SetUint64(val.Lo))
224+
bigOrder := new(big.Int).SetUint64(g.ORDER)
225+
bigVal.Mod(bigVal, bigOrder)
226+
expected := bigVal.Uint64()
227+
if reduced.ToCanonicalUint64() != expected {
228+
t.Fatalf("Expected reduction of %v to be %d, but got %d", val, expected, reduced.ToCanonicalUint64())
229+
}
230+
}
231+
}
232+
233+
func TestReduce96Bit(t *testing.T) {
234+
testValues := []g.UInt128{
235+
{0, 0},
236+
{1, 0},
237+
{math.MaxUint32, math.MaxUint64},
238+
{0, g.ORDER + 1},
239+
{0, g.ORDER - 1},
240+
{0, 1},
241+
{0, math.MaxUint64},
242+
{math.MaxUint32, 0},
243+
}
244+
245+
for _, val := range testValues {
246+
reduced := g.Reduce96Bit(val)
247+
bigVal := new(big.Int).SetUint64(val.Hi)
248+
bigVal.Lsh(bigVal, 64)
249+
bigVal.Add(bigVal, new(big.Int).SetUint64(val.Lo))
250+
251+
if bigVal.BitLen() > 96 {
252+
t.Fatalf("Input %v exceeds 96 bits", val)
253+
}
254+
255+
bigOrder := new(big.Int).SetUint64(g.ORDER)
256+
bigVal.Mod(bigVal, bigOrder)
257+
expected := bigVal.Uint64()
258+
if reduced.ToCanonicalUint64() != expected {
259+
t.Fatalf("Expected reduction of %v to be %d, but got %d", val, expected, reduced.ToCanonicalUint64())
260+
}
261+
}
262+
}
263+
190264
func TestSubFDoubleWraparound(t *testing.T) {
191265
/*
192266
let (a, b) = (F::from_canonical_u64((F::ORDER + 1u64) / 2u64), F::TWO);

field/goldilocks/goldilocks_plonky2.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ func (z GoldilocksField) ToCanonicalUint64() uint64 {
4747
return x
4848
}
4949

50+
// lhs, rhs in non-canonical form
5051
func AddF(lhs, rhs GoldilocksField) GoldilocksField {
5152
sum, over := bits.Add64(uint64(lhs), uint64(rhs), 0)
5253
sum, over = bits.Add64(sum, over*EPSILON, 0)
@@ -58,10 +59,18 @@ func AddF(lhs, rhs GoldilocksField) GoldilocksField {
5859
return GoldilocksField(sum)
5960
}
6061

62+
// Assuming lhs or rhs is in the field, i.e. x < ORDER and other in non-canonical form(u64). This assumption can be used to remove second overflow check.
63+
func AddCanonicalUint64(lhs GoldilocksField, rhs uint64) GoldilocksField {
64+
sum, over := bits.Add64(uint64(lhs), uint64(rhs), 0)
65+
// if overflowed, sum := lhs + rhs - 2^64 => sum + EPSILON = lhs + rhs - 2^64 + 2^32 -1 = lhs + rhs - ORDER < ORDER + 2^64 - ORDER = 2^64, so there is no overflow in this case.
66+
return GoldilocksField(sum + over*EPSILON)
67+
}
68+
6169
func DoubleF(lhs GoldilocksField) GoldilocksField {
6270
return AddF(lhs, lhs)
6371
}
6472

73+
// lhs, rhs in non-canonical form
6574
func SubF(lhs, rhs GoldilocksField) GoldilocksField {
6675
diff, borrow := bits.Sub64(uint64(lhs), uint64(rhs), 0)
6776
diff, borrow = bits.Sub64(diff, borrow*EPSILON, 0)
@@ -73,6 +82,7 @@ func SubF(lhs, rhs GoldilocksField) GoldilocksField {
7382
return GoldilocksField(diff)
7483
}
7584

85+
// lhs, rhs in non-canonical form
7686
func MulF(lhs, rhs GoldilocksField) GoldilocksField {
7787
x_hi, x_lo := bits.Mul64(uint64(lhs), uint64(rhs))
7888

@@ -95,6 +105,12 @@ func SquareF(x GoldilocksField) GoldilocksField {
95105
return MulF(x, x)
96106
}
97107

108+
// Returns self + x * y
109+
func MulAccF(self, x, y GoldilocksField) GoldilocksField {
110+
// u64 + u64 * u64 cannot overflow.
111+
return Reduce128Bit(AddUInt128(AsUInt128(self), MulUInt64(uint64(x), uint64(y))))
112+
}
113+
98114
func ExpPowerOf2(x GoldilocksField, n uint) GoldilocksField {
99115
z := x
100116
for i := uint(0); i < n; i++ {
@@ -121,6 +137,7 @@ func SampleF() GoldilocksField {
121137
return GoldilocksField(rng.Uint64())
122138
}
123139

140+
// Canonical representation
124141
func ToLittleEndianBytesF(z GoldilocksField) []byte {
125142
res := make([]byte, Bytes)
126143
binary.LittleEndian.PutUint64(res, z.ToCanonicalUint64())
@@ -131,6 +148,53 @@ func FromCanonicalLittleEndianBytesF(b []byte) GoldilocksField {
131148
return GoldilocksField(binary.LittleEndian.Uint64(b))
132149
}
133150

151+
type UInt128 struct {
152+
Hi, Lo uint64
153+
}
154+
155+
// NonCanonical conversion
156+
func AsUInt128(f GoldilocksField) UInt128 {
157+
u := uint64(f)
158+
return UInt128{0, u}
159+
}
160+
161+
func AddUInt128(x, y UInt128) UInt128 {
162+
var carry uint64
163+
var z UInt128
164+
z.Lo, carry = bits.Add64(x.Lo, y.Lo, 0)
165+
z.Hi = x.Hi + y.Hi + carry
166+
return z
167+
}
168+
169+
func MulUInt64(x, y uint64) UInt128 {
170+
hi, lo := bits.Mul64(x, y)
171+
return UInt128{hi, lo}
172+
}
173+
174+
// Assumes x is 96-bit number
175+
func Reduce96Bit(x UInt128) GoldilocksField {
176+
t1 := x.Hi * EPSILON
177+
resWrapped, carry := bits.Add64(x.Lo, t1, 0)
178+
179+
return GoldilocksField(resWrapped) + GoldilocksField(carry*EPSILON)
180+
}
181+
182+
func Reduce128Bit(x UInt128) GoldilocksField {
183+
x_hi_hi := x.Hi >> 32
184+
x_hi_lo := x.Hi & EPSILON
185+
186+
t0, borrow := bits.Sub64(x.Lo, x_hi_hi, 0)
187+
if borrow == 1 {
188+
branchHint()
189+
t0 -= EPSILON
190+
}
191+
t1 := x_hi_lo * EPSILON
192+
193+
resWrapped, carry := bits.Add64(t0, t1, 0)
194+
t2 := resWrapped + EPSILON*carry
195+
return GoldilocksField(t2)
196+
}
197+
134198
// func (z *GoldilocksField) Inverse(x *GoldilocksField) *GoldilocksField {
135199
// if x.IsZero() {
136200
// z.SetZero()

hash/poseidon2_goldilocks_plonky2/poseidon2.go

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -149,51 +149,80 @@ func partialRounds(state *[WIDTH]g.GoldilocksField) {
149149
}
150150

151151
func externalLinearLayer(s *[WIDTH]g.GoldilocksField) {
152-
for i := 0; i < 3; i++ { // 4 size window
153-
var t0, t1, t2, t3, t4, t5, t6 g.GoldilocksField
154-
t0 = g.AddF(s[4*i], s[4*i+1]) // s0+s1
155-
t1 = g.AddF(s[4*i+2], s[4*i+3]) // s2+s3
156-
t2 = g.AddF(t0, t1) // t0+t1 = s0+s1+s2+s3
157-
t3 = g.AddF(t2, s[4*i+1]) // t2+s1 = s0+2s1+s2+s3
158-
t4 = g.AddF(t2, s[4*i+3]) // t2+s3 = s0+s1+s2+2s3
159-
t5 = g.DoubleF(s[4*i]) // 2s0
160-
t6 = g.DoubleF(s[4*i+2]) // 2s2
161-
s[4*i] = g.AddF(t3, t0)
162-
s[4*i+1] = g.AddF(t6, t3)
163-
s[4*i+2] = g.AddF(t1, t4)
164-
s[4*i+3] = g.AddF(t5, t4)
152+
s128 := [WIDTH]g.UInt128{}
153+
for i := 0; i < WIDTH; i++ {
154+
s128[i] = g.AsUInt128(s[i])
165155
}
166156

167-
sums := [4]g.GoldilocksField{}
168-
for k := 0; k < 4; k++ {
169-
for j := 0; j < WIDTH; j += 4 {
170-
sums[k] = g.AddF(sums[k], s[j+k])
171-
}
172-
}
157+
externalLinearLayer128(&s128)
158+
173159
for i := 0; i < WIDTH; i++ {
174-
s[i] = g.AddF(s[i], sums[i%4])
160+
s[i] = g.Reduce96Bit(s128[i])
175161
}
176162
}
177163

178-
func internalLinearLayer(state *[WIDTH]g.GoldilocksField) {
179-
sum := state[0]
180-
for i := 1; i < WIDTH; i++ {
181-
sum = g.AddF(sum, state[i])
164+
func externalLinearLayer128(s *[WIDTH]g.UInt128) {
165+
for i := 0; i < WIDTH; i += 4 {
166+
t01 := g.AddUInt128(s[i], s[i+1])
167+
t23 := g.AddUInt128(s[i+2], s[i+3])
168+
t0123 := g.AddUInt128(t01, t23)
169+
170+
x0 := s[i]
171+
x2 := s[i+2]
172+
173+
s[i] = g.AddUInt128(g.AddUInt128(t0123, t01), s[i+1])
174+
s[i+1] = g.AddUInt128(g.AddUInt128(g.AddUInt128(t0123, s[i+1]), x2), x2)
175+
s[i+2] = g.AddUInt128(g.AddUInt128(t0123, t23), s[i+3])
176+
s[i+3] = g.AddUInt128(g.AddUInt128(g.AddUInt128(t0123, s[i+3]), x0), x0)
182177
}
178+
179+
sums := [4]g.UInt128{}
180+
for i := 0; i < 4; i++ {
181+
sums[i] = g.AddUInt128(g.AddUInt128(s[i], s[i+4]), s[i+8])
182+
}
183+
183184
for i := 0; i < WIDTH; i++ {
184-
state[i] = g.MulF(state[i], MATRIX_DIAG_12_U64[i])
185-
state[i] = g.AddF(state[i], sum)
185+
s[i] = g.AddUInt128(s[i], sums[i%4])
186186
}
187187
}
188188

189+
func internalLinearLayer(state *[WIDTH]g.GoldilocksField) {
190+
sum := g.AsUInt128(state[0])
191+
sum = g.AddUInt128(sum, g.AsUInt128(state[1]))
192+
sum = g.AddUInt128(sum, g.AsUInt128(state[2]))
193+
sum = g.AddUInt128(sum, g.AsUInt128(state[3]))
194+
sum = g.AddUInt128(sum, g.AsUInt128(state[4]))
195+
sum = g.AddUInt128(sum, g.AsUInt128(state[5]))
196+
sum = g.AddUInt128(sum, g.AsUInt128(state[6]))
197+
sum = g.AddUInt128(sum, g.AsUInt128(state[7]))
198+
sum = g.AddUInt128(sum, g.AsUInt128(state[8]))
199+
sum = g.AddUInt128(sum, g.AsUInt128(state[9]))
200+
sum = g.AddUInt128(sum, g.AsUInt128(state[10]))
201+
sum = g.AddUInt128(sum, g.AsUInt128(state[11]))
202+
sumF := g.Reduce96Bit(sum)
203+
204+
state[0] = g.MulAccF(sumF, state[0], MATRIX_DIAG_12_U64[0])
205+
state[1] = g.MulAccF(sumF, state[1], MATRIX_DIAG_12_U64[1])
206+
state[2] = g.MulAccF(sumF, state[2], MATRIX_DIAG_12_U64[2])
207+
state[3] = g.MulAccF(sumF, state[3], MATRIX_DIAG_12_U64[3])
208+
state[4] = g.MulAccF(sumF, state[4], MATRIX_DIAG_12_U64[4])
209+
state[5] = g.MulAccF(sumF, state[5], MATRIX_DIAG_12_U64[5])
210+
state[6] = g.MulAccF(sumF, state[6], MATRIX_DIAG_12_U64[6])
211+
state[7] = g.MulAccF(sumF, state[7], MATRIX_DIAG_12_U64[7])
212+
state[8] = g.MulAccF(sumF, state[8], MATRIX_DIAG_12_U64[8])
213+
state[9] = g.MulAccF(sumF, state[9], MATRIX_DIAG_12_U64[9])
214+
state[10] = g.MulAccF(sumF, state[10], MATRIX_DIAG_12_U64[10])
215+
state[11] = g.MulAccF(sumF, state[11], MATRIX_DIAG_12_U64[11])
216+
}
217+
189218
func addRC(state *[WIDTH]g.GoldilocksField, externalRound int) {
190219
for i := 0; i < WIDTH; i++ {
191-
state[i] = g.AddF(state[i], EXTERNAL_CONSTANTS[externalRound][i])
220+
state[i] = g.AddCanonicalUint64(state[i], uint64(EXTERNAL_CONSTANTS[externalRound][i]))
192221
}
193222
}
194223

195224
func addRCI(state *[WIDTH]g.GoldilocksField, round int) {
196-
state[0] = g.AddF(state[0], INTERNAL_CONSTANTS[round])
225+
state[0] = g.AddCanonicalUint64(state[0], uint64(INTERNAL_CONSTANTS[round]))
197226
}
198227

199228
func sbox(state *[WIDTH]g.GoldilocksField) {

hash/poseidon2_goldilocks_plonky2/poseidon2_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,3 +283,21 @@ func TestHashToQuinticExtension(t *testing.T) {
283283
}
284284
}
285285
}
286+
287+
func TestConstantsAreInTheField(t *testing.T) {
288+
for r := 0; r < ROUNDS_F; r++ {
289+
for i := 0; i < WIDTH; i++ {
290+
if uint64(EXTERNAL_CONSTANTS[r][i]) >= g.ORDER {
291+
t.Logf("External constant at round %d, index %d is not in the field: %d", r, i, EXTERNAL_CONSTANTS[r][i])
292+
t.Fail()
293+
}
294+
}
295+
}
296+
297+
for r := 0; r < ROUNDS_P; r++ {
298+
if uint64(INTERNAL_CONSTANTS[r]) >= g.ORDER {
299+
t.Logf("Internal constant at round %d is not in the field: %d", r, INTERNAL_CONSTANTS[r])
300+
t.Fail()
301+
}
302+
}
303+
}

0 commit comments

Comments
 (0)