Skip to content

Commit 8b9caa1

Browse files
committed
Use BLS12-318 curve and add multi_verify
1 parent 7e3b313 commit 8b9caa1

File tree

2 files changed

+150
-69
lines changed

2 files changed

+150
-69
lines changed

eth/utils/bls.py

Lines changed: 109 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
Union,
66
)
77

8+
from eth_utils import (
9+
big_endian_to_int,
10+
)
811

9-
from py_ecc.optimized_bn128 import ( # NOQA
12+
from py_ecc.optimized_bls12_381 import ( # NOQA
1013
G1,
1114
G2,
1215
Z1,
@@ -20,111 +23,142 @@
2023
FQP,
2124
pairing,
2225
normalize,
23-
field_modulus,
26+
field_modulus as q,
2427
b,
2528
b2,
2629
is_on_curve,
2730
curve_order,
2831
final_exponentiate
2932
)
30-
from eth.utils.blake import blake
31-
from eth.utils.bn128 import (
32-
FQP_point_to_FQ2_point,
33-
)
33+
from eth_hash.auto import keccak as hash
3434

3535

36-
CACHE = {} # type: Dict[bytes, Tuple[FQ2, FQ2, FQ2]]
37-
# 16th root of unity
38-
HEX_ROOT = FQ2([21573744529824266246521972077326577680729363968861965890554801909984373949499,
39-
16854739155576650954933913186877292401521110422362946064090026408937773542853])
36+
G2_cofactor = 305502333931268344200999753193121504214466019254188142667664032982267604182971884026507427359259977847832272839041616661285803823378372096355777062779109 # noqa: E501
37+
qmod = q ** 2 - 1
38+
eighth_roots_of_unity = [
39+
FQ2([1, 1]) ** ((qmod * k) // 8)
40+
for k in range(8)
41+
]
4042

4143

42-
assert HEX_ROOT ** 8 != FQ2([1, 0])
43-
assert HEX_ROOT ** 16 == FQ2([1, 0])
44+
#
45+
# Helpers
46+
#
47+
def FQP_point_to_FQ2_point(pt: Tuple[FQP, FQP, FQP]) -> Tuple[FQ2, FQ2, FQ2]:
48+
"""
49+
Transform FQP to FQ2 for type hinting.
50+
"""
51+
return (
52+
FQ2(pt[0].coeffs),
53+
FQ2(pt[1].coeffs),
54+
FQ2(pt[2].coeffs),
55+
)
4456

4557

58+
def modular_squareroot(value: int) -> int:
59+
"""
60+
``modular_squareroot(x)`` returns the value ``y`` such that ``y**2 % q == x``,
61+
and None if this is not possible. In cases where there are two solutions,
62+
the value with higher imaginary component is favored;
63+
if both solutions have equal imaginary component the value with higher real
64+
component is favored.
65+
"""
66+
candidate_squareroot = value ** ((qmod + 8) // 16)
67+
check = candidate_squareroot ** 2 / value
68+
if check in eighth_roots_of_unity[::2]:
69+
x1 = candidate_squareroot / eighth_roots_of_unity[eighth_roots_of_unity.index(check) // 2]
70+
x2 = FQ2([-x1.coeffs[0], -x1.coeffs[1]])
71+
# x2 = - x2
72+
return x1 if (x1.coeffs[1], x1.coeffs[0]) > (x2.coeffs[1], x2.coeffs[0]) else x2
73+
return None
74+
75+
76+
def hash_to_G2(message: bytes, domain: int) -> Tuple[FQ2, FQ2, FQ2]:
77+
domain_in_bytes = domain.to_bytes(8, 'big')
78+
x1 = big_endian_to_int(hash(domain_in_bytes + b'\x01' + message))
79+
x2 = big_endian_to_int(hash(domain_in_bytes + b'\x02' + message))
80+
x_coordinate = FQ2([x1, x2]) # x1 + x2 * i
81+
while 1:
82+
x_cubed_plus_b2 = x_coordinate ** 3 + FQ2([4, 4])
83+
y_coordinate = modular_squareroot(x_cubed_plus_b2)
84+
if y_coordinate is not None:
85+
break
86+
x_coordinate += FQ2([1, 0]) # Add one until we get a quadratic residue
87+
88+
return multiply(
89+
(x_coordinate, y_coordinate, FQ2([1, 0])),
90+
G2_cofactor
91+
)
92+
93+
94+
#
95+
# G1
96+
#
4697
def compress_G1(pt: Tuple[FQ, FQ, FQ]) -> int:
4798
x, y = normalize(pt)
48-
return x.n + 2**255 * (y.n % 2)
99+
return x.n + 2**383 * (y.n % 2)
49100

50101

51102
def decompress_G1(p: int) -> Tuple[FQ, FQ, FQ]:
52103
if p == 0:
53104
return (FQ(1), FQ(1), FQ(0))
54-
x = p % 2**255
55-
y_mod_2 = p // 2**255
56-
y = pow((x**3 + b.n) % field_modulus, (field_modulus + 1) // 4, field_modulus)
57-
assert pow(y, 2, field_modulus) == (x**3 + b.n) % field_modulus
105+
x = p % 2**383
106+
y_mod_2 = p // 2**383
107+
y = pow((x**3 + b.n) % q, (q + 1) // 4, q)
108+
assert pow(y, 2, q) == (x**3 + b.n) % q
58109
if y % 2 != y_mod_2:
59-
y = field_modulus - y
110+
y = q - y
60111
return (FQ(x), FQ(y), FQ(1))
61112

62113

63-
def sqrt_fq2(x: FQP) -> FQ2:
64-
y = x ** ((field_modulus ** 2 + 15) // 32)
65-
while y**2 != x:
66-
y *= HEX_ROOT
67-
return FQ2(y.coeffs)
68-
69-
70-
def hash_to_G2(m: bytes) -> Tuple[FQ2, FQ2, FQ2]:
71-
"""
72-
WARNING: this function has not been standardized yet.
73-
"""
74-
if m in CACHE:
75-
return CACHE[m]
76-
k2 = m
77-
while 1:
78-
k1 = blake(k2)
79-
k2 = blake(k1)
80-
x1 = int.from_bytes(k1, 'big') % field_modulus
81-
x2 = int.from_bytes(k2, 'big') % field_modulus
82-
x = FQ2([x1, x2])
83-
xcb = x**3 + b2
84-
if xcb ** ((field_modulus ** 2 - 1) // 2) == FQ2([1, 0]):
85-
break
86-
y = sqrt_fq2(xcb)
87-
88-
o = FQP_point_to_FQ2_point(multiply((x, y, FQ2([1, 0])), 2 * field_modulus - curve_order))
89-
CACHE[m] = o
90-
return o
91-
92-
114+
#
115+
# G2
116+
#
93117
def compress_G2(pt: Tuple[FQP, FQP, FQP]) -> Tuple[int, int]:
94118
assert is_on_curve(pt, b2)
95119
x, y = normalize(pt)
96120
return (
97-
int(x.coeffs[0] + 2**255 * (y.coeffs[0] % 2)),
121+
int(x.coeffs[0] + 2**383 * (y.coeffs[0] % 2)),
98122
int(x.coeffs[1])
99123
)
100124

101125

102126
def decompress_G2(p: bytes) -> Tuple[FQP, FQP, FQP]:
103-
x1 = p[0] % 2**255
104-
y1_mod_2 = p[0] // 2**255
127+
x1 = p[0] % 2**383
128+
y1_mod_2 = p[0] // 2**383
105129
x2 = p[1]
106130
x = FQ2([x1, x2])
107131
if x == FQ2([0, 0]):
108132
return FQ2([1, 0]), FQ2([1, 0]), FQ2([0, 0])
109-
y = sqrt_fq2(x**3 + b2)
133+
y = modular_squareroot(x**3 + b2)
110134
if y.coeffs[0] % 2 != y1_mod_2:
111135
y = FQ2((y * -1).coeffs)
112136
assert is_on_curve((x, y, FQ2([1, 0])), b2)
113137
return x, y, FQ2([1, 0])
114138

115139

116-
def sign(m: bytes, k: int) -> Tuple[int, int]:
117-
return compress_G2(multiply(hash_to_G2(m), k))
140+
#
141+
# APIs
142+
#
143+
def sign(message: bytes,
144+
privkey: int,
145+
domain: int) -> Tuple[int, int]:
146+
return compress_G2(
147+
multiply(
148+
hash_to_G2(message, domain),
149+
privkey
150+
)
151+
)
118152

119153

120154
def privtopub(k: int) -> int:
121155
return compress_G1(multiply(G1, k))
122156

123157

124-
def verify(m: bytes, pub: int, sig: bytes) -> bool:
158+
def verify(m: bytes, pub: int, sig: bytes, domain: int) -> bool:
125159
final_exponentiation = final_exponentiate(
126160
pairing(FQP_point_to_FQ2_point(decompress_G2(sig)), G1, False) *
127-
pairing(FQP_point_to_FQ2_point(hash_to_G2(m)), neg(decompress_G1(pub)), False)
161+
pairing(FQP_point_to_FQ2_point(hash_to_G2(m, domain)), neg(decompress_G1(pub)), False)
128162
)
129163
return final_exponentiation == FQ12.one()
130164

@@ -141,3 +175,22 @@ def aggregate_pubs(pubs: Iterable[int]) -> int:
141175
for p in pubs:
142176
o = add(o, decompress_G1(p))
143177
return compress_G1(o)
178+
179+
180+
def multi_verify(pubs, msgs, sig, domain):
181+
len_msgs = len(msgs)
182+
assert len(pubs) == len_msgs
183+
184+
o = FQ12([1] + [0] * 11)
185+
for m in set(msgs):
186+
# aggregate the pubs
187+
group_pub = Z1
188+
for i in range(len_msgs):
189+
if msgs[i] == m:
190+
group_pub = add(group_pub, decompress_G1(pubs[i]))
191+
192+
o *= pairing(hash_to_G2(m, domain), group_pub, False)
193+
o *= pairing(decompress_G2(sig), neg(G1), False)
194+
195+
final_exponentiation = final_exponentiate(o)
196+
return final_exponentiation == FQ12.one()

tests/core/bls-utils/test_bls.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import pytest
22

3-
pytest.importorskip('eth.utils.bls') # noqa E402
43
from eth.utils.bls import (
54
G1,
65
G2,
@@ -15,15 +14,11 @@
1514
privtopub,
1615
aggregate_sigs,
1716
aggregate_pubs,
18-
verify
19-
)
20-
21-
from tests.core.helpers import (
22-
greater_equal_python36,
17+
verify,
18+
multi_verify,
2319
)
2420

2521

26-
@greater_equal_python36
2722
@pytest.mark.parametrize(
2823
'privkey',
2924
[
@@ -37,19 +32,20 @@
3732
]
3833
)
3934
def test_bls_core(privkey):
35+
domain = 0
4036
p1 = multiply(G1, privkey)
4137
p2 = multiply(G2, privkey)
4238
msg = str(privkey).encode('utf-8')
43-
msghash = hash_to_G2(msg)
39+
msghash = hash_to_G2(msg, domain=domain)
40+
4441
assert normalize(decompress_G1(compress_G1(p1))) == normalize(p1)
4542
assert normalize(decompress_G2(compress_G2(p2))) == normalize(p2)
4643
assert normalize(decompress_G2(compress_G2(msghash))) == normalize(msghash)
47-
sig = sign(msg, privkey)
44+
sig = sign(msg, privkey, domain=domain)
4845
pub = privtopub(privkey)
49-
assert verify(msg, pub, sig)
46+
assert verify(msg, pub, sig, domain=domain)
5047

5148

52-
@greater_equal_python36
5349
@pytest.mark.parametrize(
5450
'msg, privkeys',
5551
[
@@ -58,8 +54,40 @@ def test_bls_core(privkey):
5854
]
5955
)
6056
def test_signature_aggregation(msg, privkeys):
61-
sigs = [sign(msg, k) for k in privkeys]
57+
domain = 0
58+
sigs = [sign(msg, k, domain=domain) for k in privkeys]
6259
pubs = [privtopub(k) for k in privkeys]
6360
aggsig = aggregate_sigs(sigs)
6461
aggpub = aggregate_pubs(pubs)
65-
assert verify(msg, aggpub, aggsig)
62+
assert verify(msg, aggpub, aggsig, domain=domain)
63+
64+
65+
@pytest.mark.parametrize(
66+
'msg_1, msg_2, privkeys',
67+
[
68+
(b'cow', b'wow', [1, 5, 124, 735, 127409812145, 90768492698215092512159, 0]),
69+
]
70+
)
71+
def test_multi_aggregation(msg_1, msg_2, privkeys):
72+
domain = 0
73+
sigs_1 = [sign(msg_1, k, domain=domain) for k in privkeys]
74+
75+
pubs_1 = [privtopub(k) for k in privkeys]
76+
aggsig_1 = aggregate_sigs(sigs_1)
77+
aggpub_1 = aggregate_pubs(pubs_1)
78+
79+
sigs_2 = [sign(msg_2, k, domain=domain) for k in privkeys]
80+
pubs_2 = [privtopub(k) for k in privkeys]
81+
aggsig_2 = aggregate_sigs(sigs_2)
82+
aggpub_2 = aggregate_pubs(pubs_2)
83+
84+
msgs = [msg_1, msg_2]
85+
pubs = [aggpub_1, aggpub_2]
86+
aggsig = aggregate_sigs([aggsig_1, aggsig_2])
87+
88+
assert multi_verify(
89+
pubs=pubs,
90+
msgs=msgs,
91+
sig=aggsig,
92+
domain=domain,
93+
)

0 commit comments

Comments
 (0)