Skip to content

Commit 1c22598

Browse files
authored
feat: Test engine to store elements in Montgomery format (#1695)
1 parent 7c23f06 commit 1c22598

File tree

3 files changed

+377
-73
lines changed

3 files changed

+377
-73
lines changed

test/blueprint_solver.go

Lines changed: 157 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,86 @@ import (
88
"github.com/consensys/gnark/internal/utils"
99
)
1010

11+
// modulus encapsulates field modulus and Montgomery conversion parameters
12+
type modulus[E constraint.Element] struct {
13+
q *big.Int
14+
rInv *big.Int
15+
qInv *big.Int // -q⁻¹ mod R, for Montgomery multiplication
16+
rMask *big.Int // 2^logR - 1, for efficient mod R operation
17+
logR uint
18+
bytesLen int
19+
}
20+
21+
// newModulus creates a typed modulus and computes Montgomery parameters
22+
func newModulus[E constraint.Element](q *big.Int) *modulus[E] {
23+
res := &modulus[E]{q: q}
24+
if smallfields.IsSmallField(q) {
25+
res.logR = 32
26+
res.bytesLen = 4
27+
} else {
28+
nbBits := q.BitLen()
29+
nbLimbs := (nbBits + 63) / 64
30+
res.logR = uint(nbLimbs * 64)
31+
res.bytesLen = 48
32+
}
33+
34+
// Compute R = 2^logR
35+
r := new(big.Int).Lsh(big.NewInt(1), res.logR)
36+
37+
// Compute R⁻¹ mod q
38+
res.rInv = new(big.Int).ModInverse(r, q)
39+
40+
// Compute q⁻¹ mod R
41+
res.qInv = new(big.Int).ModInverse(q, r)
42+
43+
// Compute qInv = -q⁻¹ mod R
44+
res.qInv.Sub(r, res.qInv)
45+
46+
// Compute rMask = R - 1 = 2^logR - 1 for efficient mod R
47+
res.rMask = new(big.Int).Sub(r, big.NewInt(1))
48+
49+
return res
50+
}
51+
52+
// toMontBigInt extracts element bytes as Montgomery form big.Int (no conversion)
53+
func (m *modulus[E]) toMontBigInt(f E) *big.Int {
54+
fBytes := f.Bytes()
55+
return new(big.Int).SetBytes(fBytes[:])
56+
}
57+
58+
// montBigIntToElement converts Montgomery big.Int directly to element (no conversion)
59+
func (m *modulus[E]) montBigIntToElement(mont *big.Int) E {
60+
bytes := mont.Bytes()
61+
if len(bytes) > m.bytesLen {
62+
panic("value too big")
63+
}
64+
paddedBytes := make([]byte, m.bytesLen)
65+
copy(paddedBytes[m.bytesLen-len(bytes):], bytes[:])
66+
return constraint.NewElement[E](paddedBytes[:])
67+
}
68+
69+
// ToBigInt converts element (Montgomery form) to canonical big.Int
70+
func (m *modulus[E]) ToBigInt(f E) *big.Int {
71+
x := m.toMontBigInt(f)
72+
x.Mul(x, m.rInv).Mod(x, m.q)
73+
return x
74+
}
75+
76+
// bigIntToElement converts canonical big.Int to Montgomery form element
77+
func (m *modulus[E]) bigIntToElement(b *big.Int) E {
78+
if b.Sign() == -1 {
79+
panic("negative value")
80+
}
81+
x := new(big.Int).Lsh(b, m.logR)
82+
x.Mod(x, m.q)
83+
return m.montBigIntToElement(x)
84+
}
85+
1186
// blueprintSolver is a constraint.Solver that can be used to test a circuit
1287
// it is a separate type to avoid method collisions with the engine.
1388
type blueprintSolver[E constraint.Element] struct {
1489
internalVariables []*big.Int
15-
q *big.Int
90+
*modulus[E]
1691
}
1792

1893
// implements constraint.Solver
@@ -40,47 +115,75 @@ func (s *blueprintSolver[E]) IsSolved(vID uint32) bool {
40115

41116
func (s *blueprintSolver[E]) FromInterface(i interface{}) E {
42117
b := utils.FromInterface(i)
43-
return s.toElement(&b)
118+
return s.bigIntToElement(&b)
44119
}
45120

46-
func (s *blueprintSolver[E]) ToBigInt(f E) *big.Int {
47-
r := new(big.Int)
48-
fBytes := f.Bytes()
49-
r.SetBytes(fBytes[:])
50-
return r
51-
}
52121
func (s *blueprintSolver[E]) Mul(a, b E) E {
53-
ba, bb := s.ToBigInt(a), s.ToBigInt(b)
54-
ba.Mul(ba, bb).Mod(ba, s.q)
55-
return s.toElement(ba)
122+
ba, bb := s.toMontBigInt(a), s.toMontBigInt(b)
123+
124+
// Montgomery multiplication using REDC algorithm
125+
// Computes (a·R) · (b·R) / R mod q = a·b·R mod q
126+
127+
// Step 1: t = a · b
128+
t := new(big.Int).Mul(ba, bb)
129+
130+
// Step 2: m = (t · qInv) mod R
131+
// Since R = 2^logR, we use bit masking for mod R
132+
// Optimize: reduce t mod R first to make multiplication smaller
133+
m := new(big.Int).And(t, s.modulus.rMask)
134+
m.Mul(m, s.modulus.qInv)
135+
m.And(m, s.modulus.rMask)
136+
137+
// Step 3: m = (t + m·q) / R
138+
m.Mul(m, s.modulus.q)
139+
m.Add(m, t)
140+
m.Rsh(m, s.modulus.logR) // divide by R = 2^logR
141+
142+
// Step 4: Final reduction
143+
if m.Cmp(s.modulus.q) >= 0 {
144+
m.Sub(m, s.modulus.q)
145+
}
146+
147+
return s.montBigIntToElement(m)
56148
}
57149
func (s *blueprintSolver[E]) Add(a, b E) E {
58-
ba, bb := s.ToBigInt(a), s.ToBigInt(b)
59-
ba.Add(ba, bb).Mod(ba, s.q)
60-
return s.toElement(ba)
150+
// Addition works the same in Montgomery form: (a·R + b·R) mod m = (a+b)·R mod m
151+
ba, bb := s.toMontBigInt(a), s.toMontBigInt(b)
152+
ba.Add(ba, bb).Mod(ba, s.modulus.q)
153+
return s.montBigIntToElement(ba)
61154
}
62155
func (s *blueprintSolver[E]) Sub(a, b E) E {
63-
ba, bb := s.ToBigInt(a), s.ToBigInt(b)
64-
ba.Sub(ba, bb).Mod(ba, s.q)
65-
return s.toElement(ba)
156+
// Subtraction works the same in Montgomery form: (a·R - b·R) mod m = (a-b)·R mod m
157+
ba, bb := s.toMontBigInt(a), s.toMontBigInt(b)
158+
ba.Sub(ba, bb).Mod(ba, s.modulus.q)
159+
return s.montBigIntToElement(ba)
66160
}
67161
func (s *blueprintSolver[E]) Neg(a E) E {
68-
ba := s.ToBigInt(a)
69-
ba.Neg(ba).Mod(ba, s.q)
70-
return s.toElement(ba)
162+
var zero E
163+
if a == zero {
164+
return zero
165+
}
166+
ba := s.toMontBigInt(a)
167+
ba.Sub(s.modulus.q, ba)
168+
return s.montBigIntToElement(ba)
71169
}
72170
func (s *blueprintSolver[E]) Inverse(a E) (E, bool) {
73-
ba := s.ToBigInt(a)
74-
r := ba.ModInverse(ba, s.q)
75-
return s.toElement(ba), r != nil
171+
r := s.toMontBigInt(a)
172+
r = r.ModInverse(r, s.modulus.q)
173+
if r == nil {
174+
var zero E
175+
return zero, false
176+
}
177+
r.Lsh(r, s.modulus.logR).
178+
Mod(r, s.modulus.q)
179+
return s.bigIntToElement(r), true
76180
}
77181
func (s *blueprintSolver[E]) One() E {
78182
b := new(big.Int).SetUint64(1)
79-
return s.toElement(b)
183+
return s.bigIntToElement(b)
80184
}
81185
func (s *blueprintSolver[E]) IsOne(a E) bool {
82-
b := s.ToBigInt(a)
83-
return b.IsUint64() && b.Uint64() == 1
186+
return a == s.One()
84187
}
85188

86189
func (s *blueprintSolver[E]) String(a E) string {
@@ -94,66 +197,58 @@ func (s *blueprintSolver[E]) Uint64(a E) (uint64, bool) {
94197
}
95198

96199
func (s *blueprintSolver[E]) Read(calldata []uint32) (E, int) {
97-
// We encoded big.Int as constraint.Element on 12 uint32 words.
200+
// Read canonical bytes from calldata, convert to Montgomery form element
98201
var r E
202+
var canonicalValue *big.Int
203+
var nWords int
204+
99205
switch t := any(&r).(type) {
100206
case *constraint.U64:
207+
// Read canonical bytes from calldata
101208
for i := 0; i < len(r); i++ {
102209
index := i * 2
103210
t[i] = uint64(calldata[index])<<32 | uint64(calldata[index+1])
104211
}
105-
return r, len(r) * 2
212+
canonicalValue = new(big.Int).SetBytes(r.Bytes())
213+
nWords = len(r) * 2
106214
case *constraint.U32:
107215
t[0] = uint32(calldata[0])
108-
return r, 1
216+
canonicalValue = new(big.Int).SetUint64(uint64(t[0]))
217+
nWords = 1
109218
default:
110219
panic("unsupported type")
111220
}
221+
222+
// Convert canonical to Montgomery and return as element
223+
return s.bigIntToElement(canonicalValue), nWords
112224
}
113225

114-
func (s *blueprintSolver[E]) toElement(b *big.Int) E {
115-
return bigIntToElement[E](b)
226+
// wrappedBigInt is a wrapper around big.Int to implement the frontend.CanonicalVariable interface
227+
type wrappedBigInt[E constraint.Element] struct {
228+
*big.Int
229+
*modulus[E]
116230
}
117231

118-
func bigIntToElement[E constraint.Element](b *big.Int) E {
119-
if b.Sign() == -1 {
232+
// Compress writes canonical bytes to calldata (no Montgomery conversion)
233+
func (w wrappedBigInt[E]) Compress(to *[]uint32) {
234+
if w.Sign() == -1 {
120235
panic("negative value")
121236
}
122-
bytes := b.Bytes()
123-
var bytesLen int
124-
var r E
125-
switch any(r).(type) {
126-
case constraint.U32:
127-
bytesLen = 4
128-
case constraint.U64:
129-
bytesLen = 48
130-
default:
131-
panic("unsupported type")
132-
}
133-
if len(bytes) > bytesLen {
134-
panic("value too big")
135-
}
136-
paddedBytes := make([]byte, bytesLen)
137-
copy(paddedBytes[bytesLen-len(bytes):], bytes[:])
138-
return constraint.NewElement[E](paddedBytes[:])
139-
}
140237

141-
// wrappedBigInt is a wrapper around big.Int to implement the frontend.CanonicalVariable interface
142-
type wrappedBigInt struct {
143-
*big.Int
144-
modulus *big.Int
145-
}
238+
// Use montBigIntToElement to handle byte padding and type switching
239+
e := w.modulus.montBigIntToElement(w.Int)
146240

147-
func (w wrappedBigInt) Compress(to *[]uint32) {
148-
if smallfields.IsSmallField(w.modulus) {
149-
e := bigIntToElement[constraint.U32](w.Int)
241+
// Extract uint32 values from the element
242+
switch e := any(e).(type) {
243+
case constraint.U32:
150244
*to = append(*to, uint32(e[0]))
151-
} else {
152-
e := bigIntToElement[constraint.U64](w.Int)
245+
case constraint.U64:
153246
// append the uint32 words to the slice
154-
for i := 0; i < len(e); i++ {
247+
for i := range e {
155248
*to = append(*to, uint32(e[i]>>32))
156249
*to = append(*to, uint32(e[i]&0xffffffff))
157250
}
251+
default:
252+
panic("unsupported type")
158253
}
159254
}

0 commit comments

Comments
 (0)