Skip to content

Commit 1a1c1ca

Browse files
committed
Use BLS12-318 curve and add multi_verify
1 parent 5055f5b commit 1a1c1ca

File tree

3 files changed

+141
-67
lines changed

3 files changed

+141
-67
lines changed

eth/utils/bls.py

Lines changed: 99 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
Tuple,
55
)
66

7+
from eth_utils import (
8+
big_endian_to_int,
9+
)
710

8-
from py_ecc.optimized_bn128 import ( # NOQA
11+
from py_ecc.optimized_bls12_381 import ( # NOQA
912
G1,
1013
G2,
1114
Z1,
@@ -18,104 +21,128 @@
1821
FQ12,
1922
pairing,
2023
normalize,
21-
field_modulus,
24+
field_modulus as q,
2225
b,
2326
b2,
2427
is_on_curve,
2528
curve_order,
2629
final_exponentiate
2730
)
28-
from eth.utils.blake import blake
31+
from eth_hash.auto import keccak as hash
32+
2933

34+
G2_cofactor = 305502333931268344200999753193121504214466019254188142667664032982267604182971884026507427359259977847832272839041616661285803823378372096355777062779109 # noqa: E501
35+
qmod = q ** 2 - 1
36+
eighth_roots_of_unity = [
37+
FQ2([1, 1]) ** ((qmod * k) // 8)
38+
for k in range(8)
39+
]
3040

31-
CACHE = {} # type: Dict[bytes, Tuple[FQ2, FQ2, FQ2]]
32-
# 16th root of unity
33-
HEX_ROOT = FQ2([21573744529824266246521972077326577680729363968861965890554801909984373949499,
34-
16854739155576650954933913186877292401521110422362946064090026408937773542853])
3541

42+
#
43+
# Helpers
44+
#
45+
def modular_squareroot(value: int) -> int:
46+
"""
47+
``modular_squareroot(x)`` returns the value ``y`` such that ``y**2 % q == x``,
48+
and None if this is not possible. In cases where there are two solutions,
49+
the value with higher imaginary component is favored;
50+
if both solutions have equal imaginary component the value with higher real
51+
component is favored.
52+
"""
53+
candidate_squareroot = value ** ((qmod + 8) // 16)
54+
check = candidate_squareroot ** 2 / value
55+
if check in eighth_roots_of_unity[::2]:
56+
x1 = candidate_squareroot / eighth_roots_of_unity[eighth_roots_of_unity.index(check) // 2]
57+
x2 = FQ2([-x1.coeffs[0], -x1.coeffs[1]])
58+
# x2 = - x2
59+
return x1 if (x1.coeffs[1], x1.coeffs[0]) > (x2.coeffs[1], x2.coeffs[0]) else x2
60+
return None
61+
62+
63+
def hash_to_G2(message: bytes, domain: int) -> Tuple[FQ2, FQ2, FQ2]:
64+
domain_in_bytes = domain.to_bytes(8, 'big')
65+
x1 = big_endian_to_int(hash(domain_in_bytes + b'\x01' + message))
66+
x2 = big_endian_to_int(hash(domain_in_bytes + b'\x02' + message))
67+
x_coordinate = FQ2([x1, x2]) # x1 + x2 * i
68+
while 1:
69+
x_cubed_plus_b2 = x_coordinate ** 3 + FQ2([4, 4])
70+
y_coordinate = modular_squareroot(x_cubed_plus_b2)
71+
if y_coordinate is not None:
72+
break
73+
x_coordinate += FQ2([1, 0]) # Add one until we get a quadratic residue
3674

37-
assert HEX_ROOT ** 8 != FQ2([1, 0])
38-
assert HEX_ROOT ** 16 == FQ2([1, 0])
75+
return multiply(
76+
(x_coordinate, y_coordinate, FQ2([1, 0])),
77+
G2_cofactor
78+
)
3979

4080

81+
#
82+
# G1
83+
#
4184
def compress_G1(pt: Tuple[FQ2, FQ2, FQ2]) -> int:
4285
x, y = normalize(pt)
43-
return x.n + 2**255 * (y.n % 2)
86+
return x.n + 2**383 * (y.n % 2)
4487

4588

4689
def decompress_G1(p: int) -> Tuple[FQ, FQ, FQ]:
4790
if p == 0:
4891
return (FQ(1), FQ(1), FQ(0))
49-
x = p % 2**255
50-
y_mod_2 = p // 2**255
51-
y = pow((x**3 + b.n) % field_modulus, (field_modulus + 1) // 4, field_modulus)
52-
assert pow(y, 2, field_modulus) == (x**3 + b.n) % field_modulus
92+
x = p % 2**383
93+
y_mod_2 = p // 2**383
94+
y = pow((x**3 + b.n) % q, (q + 1) // 4, q)
95+
assert pow(y, 2, q) == (x**3 + b.n) % q
5396
if y % 2 != y_mod_2:
54-
y = field_modulus - y
97+
y = q - y
5598
return (FQ(x), FQ(y), FQ(1))
5699

57100

58-
def sqrt_fq2(x: FQ2) -> FQ2:
59-
y = x ** ((field_modulus ** 2 + 15) // 32)
60-
while y**2 != x:
61-
y *= HEX_ROOT
62-
return y
63-
64-
65-
def hash_to_G2(m: bytes) -> Tuple[FQ2, FQ2, FQ2]:
66-
"""
67-
WARNING: this function has not been standardized yet.
68-
"""
69-
if m in CACHE:
70-
return CACHE[m]
71-
k2 = m
72-
while 1:
73-
k1 = blake(k2)
74-
k2 = blake(k1)
75-
x1 = int.from_bytes(k1, 'big') % field_modulus
76-
x2 = int.from_bytes(k2, 'big') % field_modulus
77-
x = FQ2([x1, x2])
78-
xcb = x**3 + b2
79-
if xcb ** ((field_modulus ** 2 - 1) // 2) == FQ2([1, 0]):
80-
break
81-
y = sqrt_fq2(xcb)
82-
o = multiply((x, y, FQ2([1, 0])), 2 * field_modulus - curve_order)
83-
CACHE[m] = o
84-
return o
85-
86-
101+
#
102+
# G2
103+
#
87104
def compress_G2(pt: Tuple[FQ2, FQ2, FQ2]) -> Tuple[int, int]:
88105
assert is_on_curve(pt, b2)
89106
x, y = normalize(pt)
90-
return (x.coeffs[0] + 2**255 * (y.coeffs[0] % 2), x.coeffs[1])
107+
return (x.coeffs[0] + 2**383 * (y.coeffs[0] % 2), x.coeffs[1])
91108

92109

93110
def decompress_G2(p: bytes) -> Tuple[FQ2, FQ2, FQ2]:
94-
x1 = p[0] % 2**255
95-
y1_mod_2 = p[0] // 2**255
111+
x1 = p[0] % 2**383
112+
y1_mod_2 = p[0] // 2**383
96113
x2 = p[1]
97114
x = FQ2([x1, x2])
98115
if x == FQ2([0, 0]):
99116
return FQ2([1, 0]), FQ2([1, 0]), FQ2([0, 0])
100-
y = sqrt_fq2(x**3 + b2)
117+
y = modular_squareroot(x**3 + b2)
101118
if y.coeffs[0] % 2 != y1_mod_2:
102119
y = y * -1
103-
assert is_on_curve((x, y, FQ2([1, 0])), b2)
120+
104121
return x, y, FQ2([1, 0])
105122

106123

107-
def sign(m: bytes, k: int) -> Tuple[int, int]:
108-
return compress_G2(multiply(hash_to_G2(m), k))
124+
#
125+
# APIs
126+
#
127+
def sign(message: bytes,
128+
privkey: int,
129+
domain: int) -> Tuple[int, int]:
130+
return compress_G2(
131+
multiply(
132+
hash_to_G2(message, domain),
133+
privkey
134+
)
135+
)
109136

110137

111138
def privtopub(k: int) -> int:
112139
return compress_G1(multiply(G1, k))
113140

114141

115-
def verify(m: bytes, pub: int, sig: bytes) -> bool:
142+
def verify(m: bytes, pub: int, sig: bytes, domain: int) -> bool:
116143
final_exponentiation = final_exponentiate(
117144
pairing(decompress_G2(sig), G1, False) *
118-
pairing(hash_to_G2(m), neg(decompress_G1(pub)), False)
145+
pairing(hash_to_G2(m, domain), neg(decompress_G1(pub)), False)
119146
)
120147
return final_exponentiation == FQ12.one()
121148

@@ -132,3 +159,22 @@ def aggregate_pubs(pubs: Iterable[int]) -> int:
132159
for p in pubs:
133160
o = add(o, decompress_G1(p))
134161
return compress_G1(o)
162+
163+
164+
def multi_verify(pubs, msgs, sig, domain):
165+
len_msgs = len(msgs)
166+
assert len(pubs) == len_msgs
167+
168+
o = FQ12([1] + [0] * 11)
169+
for m in set(msgs):
170+
# aggregate the pubs
171+
group_pub = Z1
172+
for i in range(len_msgs):
173+
if msgs[i] == m:
174+
group_pub = add(group_pub, decompress_G1(pubs[i]))
175+
176+
o *= pairing(hash_to_G2(m, domain), group_pub, False)
177+
o *= pairing(decompress_G2(sig), neg(G1), False)
178+
179+
final_exponentiation = final_exponentiate(o)
180+
return final_exponentiation == FQ12.one()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"eth-utils>=1.3.0b0,<2.0.0",
1414
"lru-dict>=1.1.6",
1515
"mypy_extensions>=0.4.1,<1.0.0",
16-
"py-ecc==1.4.3",
16+
"py-ecc>=1.4.6,<2.0.0",
1717
"pyethash>=0.1.27,<1.0.0",
1818
"rlp>=1.0.3,<2.0.0",
1919
"trie>=1.3.5,<2.0.0",

tests/core/bls-utils/test_bls.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import pytest
22

3-
pytest.importorskip('eth.utils.bls') # noqa E402
43
from eth.utils.bls import (
54
G1,
65
G2,
@@ -15,15 +14,11 @@
1514
privtopub,
1615
aggregate_sigs,
1716
aggregate_pubs,
18-
verify
19-
)
20-
21-
from tests.core.helpers import (
22-
greater_equal_python36,
17+
verify,
18+
multi_verify,
2319
)
2420

2521

26-
@greater_equal_python36
2722
@pytest.mark.parametrize(
2823
'privkey',
2924
[
@@ -37,19 +32,20 @@
3732
]
3833
)
3934
def test_bls_core(privkey):
35+
domain = 0
4036
p1 = multiply(G1, privkey)
4137
p2 = multiply(G2, privkey)
4238
msg = str(privkey).encode('utf-8')
43-
msghash = hash_to_G2(msg)
39+
msghash = hash_to_G2(msg, domain=domain)
40+
4441
assert normalize(decompress_G1(compress_G1(p1))) == normalize(p1)
4542
assert normalize(decompress_G2(compress_G2(p2))) == normalize(p2)
4643
assert normalize(decompress_G2(compress_G2(msghash))) == normalize(msghash)
47-
sig = sign(msg, privkey)
44+
sig = sign(msg, privkey, domain=domain)
4845
pub = privtopub(privkey)
49-
assert verify(msg, pub, sig)
46+
assert verify(msg, pub, sig, domain=domain)
5047

5148

52-
@greater_equal_python36
5349
@pytest.mark.parametrize(
5450
'msg, privkeys',
5551
[
@@ -58,8 +54,40 @@ def test_bls_core(privkey):
5854
]
5955
)
6056
def test_signature_aggregation(msg, privkeys):
61-
sigs = [sign(msg, k) for k in privkeys]
57+
domain = 0
58+
sigs = [sign(msg, k, domain=domain) for k in privkeys]
6259
pubs = [privtopub(k) for k in privkeys]
6360
aggsig = aggregate_sigs(sigs)
6461
aggpub = aggregate_pubs(pubs)
65-
assert verify(msg, aggpub, aggsig)
62+
assert verify(msg, aggpub, aggsig, domain=domain)
63+
64+
65+
@pytest.mark.parametrize(
66+
'msg_1, msg_2, privkeys',
67+
[
68+
(b'cow', b'wow', [1, 5, 124, 735, 127409812145, 90768492698215092512159, 0]),
69+
]
70+
)
71+
def test_multi_aggregation(msg_1, msg_2, privkeys):
72+
domain = 0
73+
sigs_1 = [sign(msg_1, k, domain=domain) for k in privkeys]
74+
75+
pubs_1 = [privtopub(k) for k in privkeys]
76+
aggsig_1 = aggregate_sigs(sigs_1)
77+
aggpub_1 = aggregate_pubs(pubs_1)
78+
79+
sigs_2 = [sign(msg_2, k, domain=domain) for k in privkeys]
80+
pubs_2 = [privtopub(k) for k in privkeys]
81+
aggsig_2 = aggregate_sigs(sigs_2)
82+
aggpub_2 = aggregate_pubs(pubs_2)
83+
84+
msgs = [msg_1, msg_2]
85+
pubs = [aggpub_1, aggpub_2]
86+
aggsig = aggregate_sigs([aggsig_1, aggsig_2])
87+
88+
assert multi_verify(
89+
pubs=pubs,
90+
msgs=msgs,
91+
sig=aggsig,
92+
domain=domain,
93+
)

0 commit comments

Comments
 (0)