Skip to content

Commit d2fa6c3

Browse files
weijiguoggq89ivokub
authored
Perf: optimize selector.Mux with recursive BinaryMux for various sizes (#1420)
Co-authored-by: Albert·Gou <ggq89@qq.com> Co-authored-by: Ivo Kubjas <ivo.kubjas@consensys.net>
1 parent aab575e commit d2fa6c3

File tree

4 files changed

+191
-82
lines changed

4 files changed

+191
-82
lines changed

internal/stats/latest_stats.csv

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,73 @@ scalar_mul_secp256k1,bls24_315,plonk,0,0
251251
scalar_mul_secp256k1,bls24_317,plonk,0,0
252252
scalar_mul_secp256k1,bw6_761,plonk,0,0
253253
scalar_mul_secp256k1,bw6_633,plonk,0,0
254+
selector/binaryMux_4,bn254,groth16,5,3
255+
selector/binaryMux_4,bls12_377,groth16,5,3
256+
selector/binaryMux_4,bls12_381,groth16,5,3
257+
selector/binaryMux_4,bls24_315,groth16,5,3
258+
selector/binaryMux_4,bls24_317,groth16,5,3
259+
selector/binaryMux_4,bw6_761,groth16,5,3
260+
selector/binaryMux_4,bw6_633,groth16,5,3
261+
selector/binaryMux_4,bn254,plonk,11,9
262+
selector/binaryMux_4,bls12_377,plonk,11,9
263+
selector/binaryMux_4,bls12_381,plonk,11,9
264+
selector/binaryMux_4,bls24_315,plonk,11,9
265+
selector/binaryMux_4,bls24_317,plonk,11,9
266+
selector/binaryMux_4,bw6_761,plonk,11,9
267+
selector/binaryMux_4,bw6_633,plonk,11,9
268+
selector/binaryMux_8,bn254,groth16,10,7
269+
selector/binaryMux_8,bls12_377,groth16,10,7
270+
selector/binaryMux_8,bls12_381,groth16,10,7
271+
selector/binaryMux_8,bls24_315,groth16,10,7
272+
selector/binaryMux_8,bls24_317,groth16,10,7
273+
selector/binaryMux_8,bw6_761,groth16,10,7
274+
selector/binaryMux_8,bw6_633,groth16,10,7
275+
selector/binaryMux_8,bn254,plonk,24,21
276+
selector/binaryMux_8,bls12_377,plonk,24,21
277+
selector/binaryMux_8,bls12_381,plonk,24,21
278+
selector/binaryMux_8,bls24_315,plonk,24,21
279+
selector/binaryMux_8,bls24_317,plonk,24,21
280+
selector/binaryMux_8,bw6_761,plonk,24,21
281+
selector/binaryMux_8,bw6_633,plonk,24,21
282+
selector/mux_3,bn254,groth16,8,6
283+
selector/mux_3,bls12_377,groth16,8,6
284+
selector/mux_3,bls12_381,groth16,8,6
285+
selector/mux_3,bls24_315,groth16,8,6
286+
selector/mux_3,bls24_317,groth16,8,6
287+
selector/mux_3,bw6_761,groth16,8,6
288+
selector/mux_3,bw6_633,groth16,8,6
289+
selector/mux_3,bn254,plonk,15,13
290+
selector/mux_3,bls12_377,plonk,15,13
291+
selector/mux_3,bls12_381,plonk,15,13
292+
selector/mux_3,bls24_315,plonk,15,13
293+
selector/mux_3,bls24_317,plonk,15,13
294+
selector/mux_3,bw6_761,plonk,15,13
295+
selector/mux_3,bw6_633,plonk,15,13
296+
selector/mux_4,bn254,groth16,6,5
297+
selector/mux_4,bls12_377,groth16,6,5
298+
selector/mux_4,bls12_381,groth16,6,5
299+
selector/mux_4,bls24_315,groth16,6,5
300+
selector/mux_4,bls24_317,groth16,6,5
301+
selector/mux_4,bw6_761,groth16,6,5
302+
selector/mux_4,bw6_633,groth16,6,5
303+
selector/mux_4,bn254,plonk,13,12
304+
selector/mux_4,bls12_377,plonk,13,12
305+
selector/mux_4,bls12_381,plonk,13,12
306+
selector/mux_4,bls24_315,plonk,13,12
307+
selector/mux_4,bls24_317,plonk,13,12
308+
selector/mux_4,bw6_761,plonk,13,12
309+
selector/mux_4,bw6_633,plonk,13,12
310+
selector/mux_5,bn254,groth16,12,10
311+
selector/mux_5,bls12_377,groth16,12,10
312+
selector/mux_5,bls12_381,groth16,12,10
313+
selector/mux_5,bls24_315,groth16,12,10
314+
selector/mux_5,bls24_317,groth16,12,10
315+
selector/mux_5,bw6_761,groth16,12,10
316+
selector/mux_5,bw6_633,groth16,12,10
317+
selector/mux_5,bn254,plonk,25,23
318+
selector/mux_5,bls12_377,plonk,25,23
319+
selector/mux_5,bls12_381,plonk,25,23
320+
selector/mux_5,bls24_315,plonk,25,23
321+
selector/mux_5,bls24_317,plonk,25,23
322+
selector/mux_5,bw6_761,plonk,25,23
323+
selector/mux_5,bw6_633,plonk,25,23

internal/stats/snippet.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/consensys/gnark/std/hash/mimc"
1717
"github.com/consensys/gnark/std/math/bits"
1818
"github.com/consensys/gnark/std/math/emulated"
19+
"github.com/consensys/gnark/std/selector"
1920
)
2021

2122
var (
@@ -311,6 +312,26 @@ func initSnippets() {
311312

312313
}, ecc.BN254)
313314

315+
registerSnippet("selector/mux_3", func(api frontend.API, newVariable func() frontend.Variable) {
316+
selector.Mux(api, newVariable(), newVariable(), newVariable(), newVariable())
317+
})
318+
319+
registerSnippet("selector/mux_4", func(api frontend.API, newVariable func() frontend.Variable) {
320+
selector.Mux(api, newVariable(), newVariable(), newVariable(), newVariable(), newVariable())
321+
})
322+
323+
registerSnippet("selector/mux_5", func(api frontend.API, newVariable func() frontend.Variable) {
324+
selector.Mux(api, newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable())
325+
})
326+
327+
registerSnippet("selector/binaryMux_4", func(api frontend.API, newVariable func() frontend.Variable) {
328+
selector.BinaryMux(api, []frontend.Variable{newVariable(), newVariable()}, []frontend.Variable{newVariable(), newVariable(), newVariable(), newVariable()})
329+
})
330+
331+
registerSnippet("selector/binaryMux_8", func(api frontend.API, newVariable func() frontend.Variable) {
332+
selector.BinaryMux(api, []frontend.Variable{newVariable(), newVariable(), newVariable()}, []frontend.Variable{newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable(), newVariable()})
333+
})
334+
314335
}
315336

316337
type snippetCircuit struct {

std/selector/multiplexer.go

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ package selector
1111

1212
import (
1313
"fmt"
14+
"math/big"
15+
binary "math/bits"
16+
1417
"github.com/consensys/gnark/constraint/solver"
1518
"github.com/consensys/gnark/frontend"
1619
"github.com/consensys/gnark/std/math/bits"
17-
"math/big"
18-
binary "math/bits"
20+
"github.com/consensys/gnark/std/math/cmp"
1921
)
2022

2123
func init() {
@@ -53,12 +55,45 @@ func Map(api frontend.API, queryKey frontend.Variable,
5355
// sel needs to be between 0 and n - 1 (inclusive), where n is the number of
5456
// inputs, otherwise the proof will fail.
5557
func Mux(api frontend.API, sel frontend.Variable, inputs ...frontend.Variable) frontend.Variable {
56-
// we use BinaryMux when len(inputs) is a power of 2.
57-
if binary.OnesCount(uint(len(inputs))) == 1 {
58-
selBits := bits.ToBinary(api, sel, bits.WithNbDigits(binary.Len(uint(len(inputs)))-1))
58+
n := uint(len(inputs))
59+
if n == 1 {
60+
api.AssertIsEqual(sel, 0)
61+
return inputs[0]
62+
}
63+
nbBits := binary.Len(n - 1) // we use n-1 as sel is 0-indexed
64+
selBits := bits.ToBinary(api, sel, bits.WithNbDigits(nbBits)) // binary decomposition ensures sel < 2^nbBits
65+
66+
// We use BinaryMux when len(inputs) is a power of 2.
67+
if binary.OnesCount(n) == 1 {
5968
return BinaryMux(api, selBits, inputs)
6069
}
61-
return dotProduct(api, inputs, Decoder(api, len(inputs), sel))
70+
71+
bcmp := cmp.NewBoundedComparator(api, big.NewInt((1<<nbBits)-1), false)
72+
bcmp.AssertIsLessEq(sel, n-1)
73+
74+
// Otherwise, we split inputs into two sub-arrays, such that the first part's length is 2's power
75+
return muxRecursive(api, selBits, inputs)
76+
}
77+
78+
func muxRecursive(api frontend.API,
79+
selBits []frontend.Variable, inputs []frontend.Variable) frontend.Variable {
80+
81+
nbBits := len(selBits)
82+
leftCount := uint(1 << (nbBits - 1))
83+
left := BinaryMux(api, selBits[:nbBits-1], inputs[:leftCount])
84+
85+
rightCount := uint(len(inputs)) - leftCount
86+
nbRightBits := binary.Len(rightCount)
87+
88+
var right frontend.Variable
89+
if binary.OnesCount(rightCount) == 1 {
90+
right = BinaryMux(api, selBits[:nbRightBits-1], inputs[leftCount:])
91+
} else {
92+
right = muxRecursive(api, selBits[:nbRightBits], inputs[leftCount:])
93+
}
94+
95+
msb := selBits[nbBits-1]
96+
return api.Select(msb, right, left)
6297
}
6398

6499
// KeyDecoder is a decoder that associates keys to its output wires. It outputs

std/selector/multiplexer_test.go

Lines changed: 59 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,25 @@
1-
package selector_test
1+
package selector
22

33
import (
4+
"fmt"
5+
"math/rand/v2"
46
"testing"
5-
6-
"github.com/consensys/gnark-crypto/ecc"
7-
"github.com/consensys/gnark/frontend/cs/r1cs"
7+
"time"
88

99
"github.com/consensys/gnark/frontend"
10-
"github.com/consensys/gnark/std/selector"
1110
"github.com/consensys/gnark/test"
1211
)
1312

1413
type muxCircuit struct {
15-
SEL frontend.Variable
16-
I0, I1, I2, I3, I4 frontend.Variable
17-
OUT frontend.Variable
14+
Sel frontend.Variable
15+
Input []frontend.Variable
16+
Expected frontend.Variable
1817
}
1918

2019
func (c *muxCircuit) Define(api frontend.API) error {
2120

22-
out := selector.Mux(api, c.SEL, c.I0, c.I1, c.I2, c.I3, c.I4)
23-
24-
api.AssertIsEqual(out, c.OUT)
21+
out := Mux(api, c.Sel, c.Input...)
22+
api.AssertIsEqual(out, c.Expected)
2523

2624
return nil
2725
}
@@ -34,46 +32,63 @@ type ignoredOutputMuxCircuit struct {
3432

3533
func (c *ignoredOutputMuxCircuit) Define(api frontend.API) error {
3634
// We ignore the output
37-
_ = selector.Mux(api, c.SEL, c.I0, c.I1, c.I2)
35+
_ = Mux(api, c.SEL, c.I0, c.I1, c.I2)
3836

3937
return nil
4038
}
4139

42-
type mux2to1Circuit struct {
43-
SEL frontend.Variable
44-
I0, I1 frontend.Variable
45-
OUT frontend.Variable
46-
}
47-
48-
func (c *mux2to1Circuit) Define(api frontend.API) error {
49-
// We ignore the output
50-
out := selector.Mux(api, c.SEL, c.I0, c.I1)
51-
api.AssertIsEqual(out, c.OUT)
52-
return nil
53-
}
54-
55-
type mux4to1Circuit struct {
56-
SEL frontend.Variable
57-
In [4]frontend.Variable
58-
OUT frontend.Variable
59-
}
60-
61-
func (c *mux4to1Circuit) Define(api frontend.API) error {
62-
out := selector.Mux(api, c.SEL, c.In[:]...)
63-
api.AssertIsEqual(out, c.OUT)
64-
return nil
40+
func testMux(assert *test.Assert, len int, sel int) {
41+
// seed the random generator with the current time. Good enough for tests.
42+
rng := rand.New(rand.NewPCG(uint64(time.Now().Unix()), 1)) //nolint G404
43+
circuit := &muxCircuit{
44+
Input: make([]frontend.Variable, len),
45+
}
46+
47+
inputs := make([]frontend.Variable, len)
48+
for i := 0; i < len; i++ {
49+
inputs[i] = frontend.Variable(rng.Uint64())
50+
}
51+
// out-range invalid selector
52+
outRangeSel := uint64(len) + rng.Uint64N(100)
53+
opts := []test.TestingOption{
54+
test.WithValidAssignment(&muxCircuit{
55+
Sel: sel,
56+
Input: inputs,
57+
Expected: inputs[sel],
58+
}),
59+
test.WithInvalidAssignment(&muxCircuit{
60+
Sel: outRangeSel,
61+
Input: inputs,
62+
Expected: sel,
63+
}),
64+
}
65+
66+
// in-range invalid selector
67+
if len > 1 {
68+
invalidSel := rng.Uint64N(uint64(len))
69+
for invalidSel == uint64(sel) {
70+
invalidSel = rng.Uint64N(uint64(len))
71+
}
72+
opts = append(opts, test.WithInvalidAssignment(&muxCircuit{
73+
Sel: invalidSel,
74+
Input: inputs,
75+
Expected: sel,
76+
}))
77+
}
78+
79+
assert.CheckCircuit(circuit, opts...)
6580
}
6681

6782
func TestMux(t *testing.T) {
6883
assert := test.NewAssert(t)
6984

70-
assert.CheckCircuit(&muxCircuit{},
71-
test.WithValidAssignment(&muxCircuit{SEL: 2, I0: 10, I1: 11, I2: 12, I3: 13, I4: 14, OUT: 12}),
72-
test.WithValidAssignment(&muxCircuit{SEL: 0, I0: 10, I1: 11, I2: 12, I3: 13, I4: 14, OUT: 10}),
73-
test.WithValidAssignment(&muxCircuit{SEL: 4, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 24}),
74-
test.WithInvalidAssignment(&muxCircuit{SEL: 5, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 24}),
75-
test.WithInvalidAssignment(&muxCircuit{SEL: 0, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 21}),
76-
)
85+
for len := 0; len < 9; len++ {
86+
for sel := 0; sel < len+1; sel++ {
87+
assert.Run(func(assert *test.Assert) {
88+
testMux(assert, len+1, sel)
89+
}, fmt.Sprintf("len=%d/sel=%d", len+1, sel))
90+
}
91+
}
7792

7893
assert.CheckCircuit(&ignoredOutputMuxCircuit{},
7994
test.WithValidAssignment(&ignoredOutputMuxCircuit{SEL: 0, I0: 0, I1: 1, I2: 2}),
@@ -82,38 +97,6 @@ func TestMux(t *testing.T) {
8297
test.WithInvalidAssignment(&ignoredOutputMuxCircuit{SEL: -1, I0: 0, I1: 1, I2: 2}),
8398
)
8499

85-
assert.CheckCircuit(&mux2to1Circuit{},
86-
test.WithValidAssignment(&mux2to1Circuit{SEL: 1, I0: 10, I1: 20, OUT: 20}),
87-
test.WithValidAssignment(&mux2to1Circuit{SEL: 0, I0: 10, I1: 20, OUT: 10}),
88-
test.WithInvalidAssignment(&mux2to1Circuit{SEL: 2, I0: 10, I1: 20, OUT: 20}),
89-
)
90-
91-
assert.CheckCircuit(&mux4to1Circuit{},
92-
test.WithValidAssignment(&mux4to1Circuit{
93-
SEL: 3,
94-
In: [4]frontend.Variable{11, 22, 33, 44},
95-
OUT: 44,
96-
}),
97-
test.WithValidAssignment(&mux4to1Circuit{
98-
SEL: 1,
99-
In: [4]frontend.Variable{11, 22, 33, 44},
100-
OUT: 22,
101-
}),
102-
test.WithValidAssignment(&mux4to1Circuit{
103-
SEL: 0,
104-
In: [4]frontend.Variable{11, 22, 33, 44},
105-
OUT: 11,
106-
}),
107-
test.WithInvalidAssignment(&mux4to1Circuit{
108-
SEL: 4,
109-
In: [4]frontend.Variable{11, 22, 33, 44},
110-
OUT: 44,
111-
}),
112-
)
113-
114-
cs, _ := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &mux4to1Circuit{})
115-
// (4 - 1) + (2 + 1) + 1 == 7
116-
assert.Equal(7, cs.GetNbConstraints())
117100
}
118101

119102
// Map tests:
@@ -126,7 +109,7 @@ type mapCircuit struct {
126109

127110
func (c *mapCircuit) Define(api frontend.API) error {
128111

129-
out := selector.Map(api, c.SEL,
112+
out := Map(api, c.SEL,
130113
[]frontend.Variable{c.K0, c.K1, c.K2, c.K3},
131114
[]frontend.Variable{c.V0, c.V1, c.V2, c.V3})
132115

@@ -143,7 +126,7 @@ type ignoredOutputMapCircuit struct {
143126

144127
func (c *ignoredOutputMapCircuit) Define(api frontend.API) error {
145128

146-
_ = selector.Map(api, c.SEL,
129+
_ = Map(api, c.SEL,
147130
[]frontend.Variable{c.K0, c.K1},
148131
[]frontend.Variable{c.V0, c.V1})
149132

0 commit comments

Comments
 (0)