Skip to content

Commit 3896f57

Browse files
committed
fix: also sign in v
1 parent 10d28de commit 3896f57

File tree

1 file changed

+23
-30
lines changed

1 file changed

+23
-30
lines changed

std/evmprecompiles/01-ecrecover_test.go

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func (c *ecrecoverCircuit) Define(api frontend.API) error {
6060
return nil
6161
}
6262

63-
func testRoutineECRecover(t *testing.T, wantStrict bool, forceLargeS bool) (circ, wit *ecrecoverCircuit, largeS bool) {
63+
func testRoutineECRecover(t *testing.T, forceLargeS bool) (circ, wit *ecrecoverCircuit) {
6464
halfFr := new(big.Int).Sub(fr.Modulus(), big.NewInt(1))
6565
halfFr.Div(halfFr, big.NewInt(2))
6666

@@ -72,24 +72,22 @@ func testRoutineECRecover(t *testing.T, wantStrict bool, forceLargeS bool) (circ
7272
msg := []byte("test")
7373
var r, s *big.Int
7474
var v uint
75-
for {
76-
v, r, s, err = sk.SignForRecover(msg, nil)
77-
if err != nil {
78-
t.Fatal("sign", err)
79-
}
80-
// SignForRecover always returns s < r_mod/2. But in the tests we want
81-
// to check that the circuit fails when s > r_mod/2 in strict mode.
82-
if forceLargeS && s.Cmp(halfFr) <= 0 {
83-
s.Sub(fr.Modulus(), s)
84-
}
85-
86-
if !wantStrict || halfFr.Cmp(s) > 0 {
87-
break
88-
}
75+
v, r, s, err = sk.SignForRecover(msg, nil)
76+
if err != nil {
77+
t.Fatal("sign", err)
8978
}
90-
strict := 0
91-
if wantStrict {
92-
strict = 1
79+
// SignForRecover always returns s < r_mod/2. But in the tests we want
80+
// to check that the circuit fails when s > r_mod/2 in strict mode.
81+
if forceLargeS {
82+
// first we make s large
83+
s.Sub(fr.Modulus(), s)
84+
// but we also have to swap the sign of the recovered public key
85+
v ^= 1
86+
}
87+
88+
strict := 1
89+
if forceLargeS {
90+
strict = 0
9391
}
9492
circuit := ecrecoverCircuit{}
9593
witness := ecrecoverCircuit{
@@ -104,19 +102,19 @@ func testRoutineECRecover(t *testing.T, wantStrict bool, forceLargeS bool) (circ
104102
Y: emulated.ValueOf[emulated.Secp256k1Fp](pk.A.Y),
105103
},
106104
}
107-
return &circuit, &witness, halfFr.Cmp(s) <= 0
105+
return &circuit, &witness
108106
}
109107

110108
func TestECRecoverCircuitShortStrict(t *testing.T) {
111109
assert := test.NewAssert(t)
112-
circuit, witness, _ := testRoutineECRecover(t, true, false)
110+
circuit, witness := testRoutineECRecover(t, false)
113111
err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField())
114112
assert.NoError(err)
115113
}
116114

117115
func TestECRecoverCircuitShortLax(t *testing.T) {
118116
assert := test.NewAssert(t)
119-
circuit, witness, _ := testRoutineECRecover(t, false, false)
117+
circuit, witness := testRoutineECRecover(t, true)
120118
err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField())
121119
assert.NoError(err)
122120
}
@@ -126,20 +124,15 @@ func TestECRecoverCircuitShortMismatch(t *testing.T) {
126124
halfFr := new(big.Int).Sub(fr.Modulus(), big.NewInt(1))
127125
halfFr.Div(halfFr, big.NewInt(2))
128126
var circuit, witness *ecrecoverCircuit
129-
var largeS bool
130-
circuit, witness, largeS = testRoutineECRecover(t, false, true)
131-
if largeS {
132-
witness.Strict = 1
133-
} else {
134-
assert.Fail("test setup failed to produce large S")
135-
}
127+
circuit, witness = testRoutineECRecover(t, true)
128+
witness.Strict = 1
136129
err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField())
137130
assert.Error(err)
138131
}
139132

140133
func TestECRecoverCircuitFull(t *testing.T) {
141134
assert := test.NewAssert(t)
142-
circuit, witness, _ := testRoutineECRecover(t, false, false)
135+
circuit, witness := testRoutineECRecover(t, false)
143136

144137
assert.CheckCircuit(
145138
circuit,
@@ -261,7 +254,7 @@ func TestECRecoverInfinityWoFailure(t *testing.T) {
261254

262255
func TestInvalidFailureTag(t *testing.T) {
263256
assert := test.NewAssert(t)
264-
circuit, witness, _ := testRoutineECRecover(t, false, false)
257+
circuit, witness := testRoutineECRecover(t, false)
265258
witness.IsFailure = 1
266259
err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField())
267260
assert.Error(err)

0 commit comments

Comments
 (0)