Skip to content

Commit e45e0a7

Browse files
authored
perf: optimize range checks for small number of small field ops (#1699)
1 parent 668ad24 commit e45e0a7

File tree

6 files changed

+102
-24
lines changed

6 files changed

+102
-24
lines changed

std/math/emulated/field.go

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,14 @@ import (
1919
)
2020

2121
const (
22+
// rangeCheckBaseLengthForSmallField is the base length used for range
23+
// checking when using small field optimization. We start enforcing
24+
// the base length only when the number of range checks exceeds
25+
// thresholdOptimizeOptimizedOverflow.
2226
rangeCheckBaseLengthForSmallField = 16
27+
// thresholdForInexactOverflow is the number of range checks after
28+
// which we start enforcing the base length for small field optimization.
29+
thresholdForInexactOverflow = 55000
2330
)
2431

2532
// Field holds the configuration for non-native field operations. The field
@@ -51,8 +58,11 @@ type Field[T FieldParams] struct {
5158

5259
log zerolog.Logger
5360

54-
constrainedLimbs map[[16]byte]struct{}
61+
// constrainedLimbs keeps track of already range checked limbs. The map
62+
// value indicates the range check width.
63+
constrainedLimbs map[[16]byte]int
5564
checker frontend.Rangechecker
65+
nbRangeChecks int
5666

5767
deferredChecks []deferredChecker
5868

@@ -81,7 +91,7 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
8191
f := &Field[T]{
8292
api: native,
8393
log: logger.Logger(),
84-
constrainedLimbs: make(map[[16]byte]struct{}),
94+
constrainedLimbs: make(map[[16]byte]int),
8595
checker: rangecheck.New(native),
8696
fParams: newStaticFieldParams[T](native.Compiler().Field()),
8797
}
@@ -93,15 +103,6 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
93103
}
94104
f.extensionApi = extapi
95105
}
96-
if f.useSmallFieldOptimization() {
97-
// in case of emulated small fields we use base length 16 to reduce
98-
// needing to range check for [v_lo, v_hi, 2*v_hi].
99-
//
100-
// But this means that hints could output values which are bigger than
101-
// the emulated modulus bitwidth (for example 31 bits). This means we
102-
// have to set the overflow of returned elements correctly.
103-
f.checker = rangecheck.New(native, rangecheck.WithBaseLength(rangeCheckBaseLengthForSmallField))
104-
}
105106

106107
// ensure prime is correctly set
107108
if f.fParams.IsPrime() {
@@ -265,7 +266,7 @@ func (f *Field[T]) enforceWidthConditional(a *Element[T]) (didConstrain bool) {
265266
// that we should enforce width for the whole element. But we
266267
// still iterate over all limbs just to mark them in the table.
267268
didConstrain = true
268-
f.constrainedLimbs[h] = struct{}{}
269+
break
269270
}
270271
} else {
271272
// we have no way of knowing if the limb has been constrained. To be
@@ -392,3 +393,48 @@ func (f *Field[T]) useSmallFieldOptimization() bool {
392393
})
393394
return f.smallFieldMode
394395
}
396+
397+
// rangeCheck performs a range check on v to ensure it fits in nbBits.
398+
// It also keeps track of the number of range checks done, and after a certain
399+
// threshold switches to using base length range checking for small field
400+
// optimization.
401+
//
402+
// It returns a boolean indicating if the range check was actually performed (i.e. if
403+
// the limb was not already constrained).
404+
func (f *Field[T]) rangeCheck(v frontend.Variable, nbBits int) bool {
405+
if h, ok := v.(interface{ HashCode() [16]byte }); ok {
406+
// if the variable has a hashcode, then we can use it to see if we have
407+
// already range checked it.
408+
hc := h.HashCode()
409+
if existingWidth, ok := f.constrainedLimbs[hc]; ok {
410+
// already range checked with a certain width
411+
if existingWidth <= nbBits {
412+
return false
413+
}
414+
}
415+
// mark as range checked
416+
f.constrainedLimbs[hc] = nbBits
417+
}
418+
// update the number of range checks done. This is only to keep track if we
419+
// should switch to the case where instead of exact width we range check
420+
// multiple of base length. This reduces number of range checks when
421+
// emulating small field.
422+
f.nbRangeChecks++
423+
424+
if f.nbRangeChecks == thresholdForInexactOverflow {
425+
// the threshold is reached, set the range checker to use base length.
426+
// Now we know that when constructing non-native elements, then we should
427+
// set overflow=f.smallAdditionalOverflow()
428+
if f.useSmallFieldOptimization() {
429+
// in case of emulated small fields we use base length 16 to reduce
430+
// needing to range check for [v_lo, v_hi, 2*v_hi].
431+
//
432+
// But this means that hints could output values which are bigger than
433+
// the emulated modulus bitwidth (for example 31 bits). This means we
434+
// have to set the overflow of returned elements correctly.
435+
f.checker = rangecheck.New(f.api, rangecheck.WithBaseLength(rangeCheckBaseLengthForSmallField))
436+
}
437+
}
438+
f.checker.Check(v, nbBits)
439+
return true
440+
}

std/math/emulated/field_assert.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) {
2727
// take only required bits from the most significant limb
2828
limbNbBits = ((f.fParams.Modulus().BitLen() - 1) % int(f.fParams.BitsPerLimb())) + 1
2929
}
30-
f.checker.Check(a.Limbs[i], limbNbBits)
30+
f.rangeCheck(a.Limbs[i], limbNbBits)
3131
}
3232
}
3333

@@ -37,7 +37,7 @@ func (f *Field[T]) smallEnforceWidth(a *Element[T], modWidth bool) {
3737
}
3838

3939
for i := range a.Limbs {
40-
f.checker.Check(a.Limbs[i], f.fParams.Modulus().BitLen()+int(a.overflow))
40+
f.rangeCheck(a.Limbs[i], f.fParams.Modulus().BitLen()+int(a.overflow))
4141
}
4242
}
4343

std/math/emulated/field_mul.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -614,9 +614,16 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error {
614614
return nil
615615
}
616616

617-
// Mul computes a*b and reduces it modulo the field order. The returned Element
618-
// has default number of limbs and zero overflow. If the result wouldn't fit
619-
// into Element, then locally reduces the inputs first. Doesn't mutate inputs.
617+
// Mul computes a*b. Depending on the emulated field it either reduces the result
618+
// modulo the field order or returns the full product.
619+
//
620+
// When emulating large field, the uses reducing multiplication by default.
621+
//
622+
// If the field is small (fits into single limb), then it uses non-reducing
623+
// multiplication by default for efficiency. It only falls back to reducing
624+
// multiplication when the overflow of the result would be too large.
625+
//
626+
// Doesn't mutate inputs.
620627
//
621628
// For multiplying by a constant, use [Field[T].MulConst] method which is more
622629
// efficient.
@@ -625,6 +632,12 @@ func (f *Field[T]) Mul(a, b *Element[T]) *Element[T] {
625632
if a.isStrictZero() || b.isStrictZero() {
626633
return f.Zero()
627634
}
635+
if f.useSmallFieldOptimization() {
636+
// for small fields, it is more efficient to use non-reducing multiplication by default
637+
// we only fall back to reducing multiplication when modular reduction is necessary
638+
// to reduce the overflow
639+
return f.reduceAndOp(f.mulNoReduce, f.mulPreCondNoReduce, a, b)
640+
}
628641
return f.reduceAndOp(func(a, b *Element[T], u uint) *Element[T] { return f.mulMod(a, b, u, nil) }, f.mulPreCondReduced, a, b)
629642
}
630643

std/math/emulated/field_smallmul.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ func (f *Field[T]) smallMulMod(a, b *Element[T]) *Element[T] {
209209

210210
// Range check the remainder (quotient is range-checked via batched sum in check)
211211
modBits := f.fParams.Modulus().BitLen()
212-
f.checker.Check(r, modBits+f.smallAdditionalOverflow())
212+
f.rangeCheck(r, modBits+f.smallAdditionalOverflow())
213213

214214
// Compute the number of bits needed for the quotient.
215215
// For a*b = q*p + r:
@@ -413,8 +413,17 @@ func (f *Field[T]) toSingleLimbElement(a *Element[T]) *Element[T] {
413413
// range checking, but define that the non-native small field element can have
414414
// some additional overflow bits to accommodate this difference.
415415
func (f *Field[T]) smallAdditionalOverflow() int {
416+
// when we emulate large field, then we always construct elements with exact
417+
// overflow
416418
if !f.useSmallFieldOptimization() {
417419
return 0
418420
}
421+
// when we haven't performed too many range checks, then we still use exact
422+
// overflow
423+
if f.nbRangeChecks < thresholdForInexactOverflow {
424+
return 0
425+
}
426+
// otherwise, we use the additional overflow which reduced number of
427+
// decompositions during range checking
419428
return (rangeCheckBaseLengthForSmallField - (f.fParams.Modulus().BitLen() % rangeCheckBaseLengthForSmallField)) % rangeCheckBaseLengthForSmallField
420429
}

std/math/emulated/smallfield_test.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,15 @@ func BenchmarkSmallFieldMulConstraints(b *testing.B) {
425425
b.Run(bc.name, func(b *testing.B) {
426426
circuit := &SmallFieldMulBenchCircuit{A: make([]Element[emparams.KoalaBear], bc.nbMuls)}
427427

428+
for b.Loop() {
429+
// for some reason, when we don't run the loop here, then the benchmark suite
430+
// runs the whole benchmark multiple times. I guess it has something to do
431+
// with the `b.Run` above (i.e. it parallelizes etc). To avoid this, we run an
432+
// empty b.Loop() here to ensure we only run the compile once.
433+
//
434+
// this adds overhead as the `b.Loop()` will be run for `benchtime` period, but
435+
// by default it is small. Otherwise the benchmark will be very slow.
436+
}
428437
csr1, err := frontend.Compile(ecc.BLS12_377.ScalarField(), r1cs.NewBuilder, circuit)
429438
if err != nil {
430439
b.Fatal(err)
@@ -440,6 +449,7 @@ func BenchmarkSmallFieldMulConstraints(b *testing.B) {
440449
constraintsSCSPerMul := float64(css.GetNbConstraints()) / float64(bc.nbMuls)
441450
b.ReportMetric(constraintsSCSPerMul, "scs_constraints/mul")
442451
b.ReportMetric(float64(css.GetNbConstraints()), "scs_total_constraints")
452+
b.ReportMetric(0.0, "ns/op") // avoid ns/op reporting as we don't measure time here
443453
})
444454
}
445455
}
@@ -483,11 +493,11 @@ func (c *MaliciousMulCircuit) Define(api frontend.API) error {
483493
}
484494

485495
// 5 multiplications: ((A*B) * (C*D)) * (E*F)
486-
ab := f.Mul(&c.A, &c.B)
487-
cd := f.Mul(&c.C, &c.D)
488-
ef := f.Mul(&c.E, &c.F)
489-
abcd := f.Mul(ab, cd)
490-
result := f.Mul(abcd, ef)
496+
ab := f.MulMod(&c.A, &c.B)
497+
cd := f.MulMod(&c.C, &c.D)
498+
ef := f.MulMod(&c.E, &c.F)
499+
abcd := f.MulMod(ab, cd)
500+
result := f.MulMod(abcd, ef)
491501
f.AssertIsEqual(result, &c.Result)
492502
return nil
493503
}

std/math/emulated/subtraction_padding.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func (f *Field[T]) computeSubPaddingHint(overflow uint, nbLimbs uint, modulus *E
119119
// at least native_width-overflow) and should be nbBits+overflow+1 bits
120120
// wide (as expected padding is one bit wider than the maximum allowed
121121
// subtraction limb).
122-
f.checker.Check(f.api.Sub(res[i], maxLimb), int(f.fParams.BitsPerLimb()+overflow+1))
122+
f.rangeCheck(f.api.Sub(res[i], maxLimb), int(f.fParams.BitsPerLimb()+overflow+1))
123123
}
124124

125125
// ensure that condition 1 holds

0 commit comments

Comments
 (0)