Skip to content

Commit 4bd796d

Browse files
authored
Merge pull request #1581 from hwwhww/new_bls_2
Update BLS signature module
2 parents 0a3ac38 + 9df7231 commit 4bd796d

File tree

4 files changed

+206
-111
lines changed

4 files changed

+206
-111
lines changed

eth/beacon/aggregation.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,18 @@
66
pipe
77
)
88

9-
from eth_typing import (
10-
Hash32,
11-
)
12-
139
from eth.utils import bls
1410
from eth.utils.bitfield import (
1511
set_voted,
1612
)
17-
from eth.beacon.utils.hash import hash_
18-
1913

20-
def create_signing_message(slot: int,
21-
parent_hashes: Iterable[Hash32],
22-
shard_id: int,
23-
shard_block_hash: Hash32,
24-
justified_slot: int) -> bytes:
25-
"""
26-
Return the signining message for attesting.
27-
"""
28-
# TODO: Will be updated with SSZ encoded attestation.
29-
return hash_(
30-
slot.to_bytes(8, byteorder='big') +
31-
b''.join(parent_hashes) +
32-
shard_id.to_bytes(2, byteorder='big') +
33-
shard_block_hash +
34-
justified_slot.to_bytes(8, 'big')
35-
)
14+
from eth.beacon.enums import SignatureDomain
3615

3716

3817
def verify_votes(
3918
message: bytes,
40-
votes: Iterable[Tuple[int, bytes, int]]) -> Tuple[Tuple[bytes, ...], Tuple[int, ...]]:
19+
votes: Iterable[Tuple[int, bytes, int]],
20+
domain: SignatureDomain) -> Tuple[Tuple[bytes, ...], Tuple[int, ...]]:
4121
"""
4222
Verify the given votes.
4323
@@ -47,7 +27,7 @@ def verify_votes(
4727
(sig, committee_index)
4828
for (committee_index, sig, public_key)
4929
in votes
50-
if bls.verify(message, public_key, sig)
30+
if bls.verify(message, public_key, sig, domain)
5131
)
5232
try:
5333
sigs, committee_indices = zip(*sigs_with_committe_info)
@@ -75,4 +55,4 @@ def aggregate_votes(bitfield: bytes,
7555
)
7656
)
7757

78-
return bitfield, bls.aggregate_sigs(sigs)
58+
return bitfield, bls.aggregate_signatures(sigs)

eth/utils/bls.py

Lines changed: 145 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from typing import ( # noqa: F401
22
Dict,
3-
Iterable,
3+
Sequence,
44
Tuple,
55
Union,
66
)
77

8+
from eth_utils import (
9+
big_endian_to_int,
10+
ValidationError,
11+
)
812

9-
from py_ecc.optimized_bn128 import ( # NOQA
13+
from py_ecc.optimized_bls12_381 import ( # NOQA
1014
G1,
1115
G2,
1216
Z1,
@@ -20,124 +24,200 @@
2024
FQP,
2125
pairing,
2226
normalize,
23-
field_modulus,
27+
field_modulus as q,
2428
b,
2529
b2,
2630
is_on_curve,
2731
curve_order,
2832
final_exponentiate
2933
)
3034
from eth.beacon.utils.hash import hash_
31-
from eth.utils.bn128 import (
32-
FQP_point_to_FQ2_point,
33-
)
34-
3535

36-
CACHE = {} # type: Dict[bytes, Tuple[FQ2, FQ2, FQ2]]
37-
# 16th root of unity
38-
HEX_ROOT = FQ2([21573744529824266246521972077326577680729363968861965890554801909984373949499,
39-
16854739155576650954933913186877292401521110422362946064090026408937773542853])
4036

37+
G2_cofactor = 305502333931268344200999753193121504214466019254188142667664032982267604182971884026507427359259977847832272839041616661285803823378372096355777062779109 # noqa: E501
38+
FQ2_order = q ** 2 - 1
39+
eighth_roots_of_unity = [
40+
FQ2([1, 1]) ** ((FQ2_order * k) // 8)
41+
for k in range(8)
42+
]
4143

42-
assert HEX_ROOT ** 8 != FQ2([1, 0])
43-
assert HEX_ROOT ** 16 == FQ2([1, 0])
4444

45-
46-
def compress_G1(pt: Tuple[FQ, FQ, FQ]) -> int:
47-
x, y = normalize(pt)
48-
return x.n + 2**255 * (y.n % 2)
45+
#
46+
# Helpers
47+
#
48+
def FQP_point_to_FQ2_point(pt: Tuple[FQP, FQP, FQP]) -> Tuple[FQ2, FQ2, FQ2]:
49+
"""
50+
Transform FQP to FQ2 for type hinting.
51+
"""
52+
return (
53+
FQ2(pt[0].coeffs),
54+
FQ2(pt[1].coeffs),
55+
FQ2(pt[2].coeffs),
56+
)
4957

5058

51-
def decompress_G1(p: int) -> Tuple[FQ, FQ, FQ]:
52-
if p == 0:
53-
return (FQ(1), FQ(1), FQ(0))
54-
x = p % 2**255
55-
y_mod_2 = p // 2**255
56-
y = pow((x**3 + b.n) % field_modulus, (field_modulus + 1) // 4, field_modulus)
57-
assert pow(y, 2, field_modulus) == (x**3 + b.n) % field_modulus
58-
if y % 2 != y_mod_2:
59-
y = field_modulus - y
60-
return (FQ(x), FQ(y), FQ(1))
59+
def modular_squareroot(value: int) -> FQP:
60+
"""
61+
``modular_squareroot(x)`` returns the value ``y`` such that ``y**2 % q == x``,
62+
and None if this is not possible. In cases where there are two solutions,
63+
the value with higher imaginary component is favored;
64+
if both solutions have equal imaginary component the value with higher real
65+
component is favored.
66+
"""
67+
candidate_squareroot = value ** ((FQ2_order + 8) // 16)
68+
check = candidate_squareroot ** 2 / value
69+
if check in eighth_roots_of_unity[::2]:
70+
x1 = candidate_squareroot / eighth_roots_of_unity[eighth_roots_of_unity.index(check) // 2]
71+
x2 = FQ2([-x1.coeffs[0], -x1.coeffs[1]]) # x2 = -x1
72+
return x1 if (x1.coeffs[1], x1.coeffs[0]) > (x2.coeffs[1], x2.coeffs[0]) else x2
73+
return None
6174

6275

63-
def sqrt_fq2(x: FQP) -> FQ2:
64-
y = x ** ((field_modulus ** 2 + 15) // 32)
65-
while y**2 != x:
66-
y *= HEX_ROOT
67-
return FQ2(y.coeffs)
76+
def hash_to_G2(message: bytes, domain: int) -> Tuple[FQ2, FQ2, FQ2]:
77+
domain_in_bytes = domain.to_bytes(8, 'big')
6878

79+
# Initial candidate x coordinate
80+
x_re = big_endian_to_int(hash_(domain_in_bytes + b'\x01' + message))
81+
x_im = big_endian_to_int(hash_(domain_in_bytes + b'\x02' + message))
82+
x_coordinate = FQ2([x_re, x_im]) # x_re + x_im * i
6983

70-
def hash_to_G2(m: bytes) -> Tuple[FQ2, FQ2, FQ2]:
71-
"""
72-
WARNING: this function has not been standardized yet.
73-
"""
74-
if m in CACHE:
75-
return CACHE[m]
76-
k2 = m
84+
# Test candidate y coordinates until a one is found
7785
while 1:
78-
k1 = hash_(k2)
79-
k2 = hash_(k1)
80-
x1 = int.from_bytes(k1, 'big') % field_modulus
81-
x2 = int.from_bytes(k2, 'big') % field_modulus
82-
x = FQ2([x1, x2])
83-
xcb = x**3 + b2
84-
if xcb ** ((field_modulus ** 2 - 1) // 2) == FQ2([1, 0]):
86+
y_coordinate_squared = x_coordinate ** 3 + FQ2([4, 4]) # The curve is y^2 = x^3 + 4(i + 1)
87+
y_coordinate = modular_squareroot(y_coordinate_squared)
88+
if y_coordinate is not None: # Check if quadratic residue found
8589
break
86-
y = sqrt_fq2(xcb)
90+
x_coordinate += FQ2([1, 0]) # Add 1 and try again
8791

88-
o = FQP_point_to_FQ2_point(multiply((x, y, FQ2([1, 0])), 2 * field_modulus - curve_order))
89-
CACHE[m] = o
90-
return o
92+
return multiply(
93+
(x_coordinate, y_coordinate, FQ2([1, 0])),
94+
G2_cofactor
95+
)
9196

9297

98+
#
99+
# G1
100+
#
101+
def compress_G1(pt: Tuple[FQ, FQ, FQ]) -> int:
102+
x, y = normalize(pt)
103+
return x.n + 2**383 * (y.n % 2)
104+
105+
106+
def decompress_G1(pt: int) -> Tuple[FQ, FQ, FQ]:
107+
if pt == 0:
108+
return (FQ(1), FQ(1), FQ(0))
109+
x = pt % 2**383
110+
y_mod_2 = pt // 2**383
111+
y = pow((x**3 + b.n) % q, (q + 1) // 4, q)
112+
113+
if pow(y, 2, q) != (x**3 + b.n) % q:
114+
raise ValueError(
115+
"he given point is not on G1: y**2 = x**3 + b"
116+
)
117+
if y % 2 != y_mod_2:
118+
y = q - y
119+
return (FQ(x), FQ(y), FQ(1))
120+
121+
122+
#
123+
# G2
124+
#
93125
def compress_G2(pt: Tuple[FQP, FQP, FQP]) -> Tuple[int, int]:
94-
assert is_on_curve(pt, b2)
126+
if not is_on_curve(pt, b2):
127+
raise ValueError(
128+
"The given point is not on the twisted curve over FQ**2"
129+
)
95130
x, y = normalize(pt)
96131
return (
97-
int(x.coeffs[0] + 2**255 * (y.coeffs[0] % 2)),
132+
int(x.coeffs[0] + 2**383 * (y.coeffs[0] % 2)),
98133
int(x.coeffs[1])
99134
)
100135

101136

102137
def decompress_G2(p: bytes) -> Tuple[FQP, FQP, FQP]:
103-
x1 = p[0] % 2**255
104-
y1_mod_2 = p[0] // 2**255
138+
x1 = p[0] % 2**383
139+
y1_mod_2 = p[0] // 2**383
105140
x2 = p[1]
106141
x = FQ2([x1, x2])
107142
if x == FQ2([0, 0]):
108143
return FQ2([1, 0]), FQ2([1, 0]), FQ2([0, 0])
109-
y = sqrt_fq2(x**3 + b2)
144+
y = modular_squareroot(x**3 + b2)
110145
if y.coeffs[0] % 2 != y1_mod_2:
111146
y = FQ2((y * -1).coeffs)
112-
assert is_on_curve((x, y, FQ2([1, 0])), b2)
147+
if not is_on_curve((x, y, FQ2([1, 0])), b2):
148+
raise ValueError(
149+
"The given point is not on the twisted curve over FQ**2"
150+
)
113151
return x, y, FQ2([1, 0])
114152

115153

116-
def sign(m: bytes, k: int) -> Tuple[int, int]:
117-
return compress_G2(multiply(hash_to_G2(m), k))
154+
#
155+
# APIs
156+
#
157+
def sign(message: bytes,
158+
privkey: int,
159+
domain: int) -> Tuple[int, int]:
160+
return compress_G2(
161+
multiply(
162+
hash_to_G2(message, domain),
163+
privkey
164+
)
165+
)
118166

119167

120168
def privtopub(k: int) -> int:
121169
return compress_G1(multiply(G1, k))
122170

123171

124-
def verify(m: bytes, pub: int, sig: bytes) -> bool:
172+
def verify(message: bytes, pubkey: int, signature: bytes, domain: int) -> bool:
125173
final_exponentiation = final_exponentiate(
126-
pairing(FQP_point_to_FQ2_point(decompress_G2(sig)), G1, False) *
127-
pairing(FQP_point_to_FQ2_point(hash_to_G2(m)), neg(decompress_G1(pub)), False)
174+
pairing(FQP_point_to_FQ2_point(decompress_G2(signature)), G1, False) *
175+
pairing(
176+
FQP_point_to_FQ2_point(hash_to_G2(message, domain)),
177+
neg(decompress_G1(pubkey)),
178+
False
179+
)
128180
)
129181
return final_exponentiation == FQ12.one()
130182

131183

132-
def aggregate_sigs(sigs: Iterable[bytes]) -> Tuple[int, int]:
184+
def aggregate_signatures(signatures: Sequence[bytes]) -> Tuple[int, int]:
133185
o = Z2
134-
for s in sigs:
186+
for s in signatures:
135187
o = FQP_point_to_FQ2_point(add(o, decompress_G2(s)))
136188
return compress_G2(o)
137189

138190

139-
def aggregate_pubs(pubs: Iterable[int]) -> int:
191+
def aggregate_pubkeys(pubkeys: Sequence[int]) -> int:
140192
o = Z1
141-
for p in pubs:
193+
for p in pubkeys:
142194
o = add(o, decompress_G1(p))
143195
return compress_G1(o)
196+
197+
198+
def verify_multiple(pubkeys: Sequence[int],
199+
messages: Sequence[bytes],
200+
signature: bytes,
201+
domain: int) -> bool:
202+
len_msgs = len(messages)
203+
204+
if len(pubkeys) != len_msgs:
205+
raise ValidationError(
206+
"len(pubkeys) (%s) should be equal to len(messages) (%s)" % (
207+
len(pubkeys), len_msgs
208+
)
209+
)
210+
211+
o = FQ12([1] + [0] * 11)
212+
for m_pubs in set(messages):
213+
# aggregate the pubs
214+
group_pub = Z1
215+
for i in range(len_msgs):
216+
if messages[i] == m_pubs:
217+
group_pub = add(group_pub, decompress_G1(pubkeys[i]))
218+
219+
o *= pairing(hash_to_G2(m_pubs, domain), group_pub, False)
220+
o *= pairing(decompress_G2(signature), neg(G1), False)
221+
222+
final_exponentiation = final_exponentiate(o)
223+
return final_exponentiation == FQ12.one()

tests/beacon/test_aggregation.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,23 @@ def test_aggregate_votes(votes_count, random, privkeys, pubkeys):
3131
bit_count = 10
3232
pre_bitfield = get_empty_bitfield(bit_count)
3333
pre_sigs = ()
34+
domain = 0
3435

3536
random_votes = random.sample(range(bit_count), votes_count)
3637
message = b'hello'
3738

3839
# Get votes: (committee_index, sig, public_key)
3940
votes = [
40-
(committee_index, bls.sign(message, privkeys[committee_index]), pubkeys[committee_index])
41+
(
42+
committee_index,
43+
bls.sign(message, privkeys[committee_index], domain),
44+
pubkeys[committee_index],
45+
)
4146
for committee_index in random_votes
4247
]
4348

4449
# Verify
45-
sigs, committee_indices = verify_votes(message, votes)
50+
sigs, committee_indices = verify_votes(message, votes, domain)
4651

4752
# Aggregate the votes
4853
bitfield, sigs = aggregate_votes(
@@ -64,5 +69,5 @@ def test_aggregate_votes(votes_count, random, privkeys, pubkeys):
6469
]
6570
assert len(voted_index) == len(votes)
6671

67-
aggregated_pubs = bls.aggregate_pubs(pubs)
68-
assert bls.verify(message, aggregated_pubs, sigs)
72+
aggregated_pubs = bls.aggregate_pubkeys(pubs)
73+
assert bls.verify(message, aggregated_pubs, sigs, domain)

0 commit comments

Comments
 (0)