diff --git a/eth/beacon/aggregation.py b/eth/beacon/aggregation.py index 0def06cb8a..9f9f6bae65 100644 --- a/eth/beacon/aggregation.py +++ b/eth/beacon/aggregation.py @@ -6,38 +6,18 @@ pipe ) -from eth_typing import ( - Hash32, -) - from eth.utils import bls from eth.utils.bitfield import ( set_voted, ) -from eth.beacon.utils.hash import hash_ - -def create_signing_message(slot: int, - parent_hashes: Iterable[Hash32], - shard_id: int, - shard_block_hash: Hash32, - justified_slot: int) -> bytes: - """ - Return the signining message for attesting. - """ - # TODO: Will be updated with SSZ encoded attestation. - return hash_( - slot.to_bytes(8, byteorder='big') + - b''.join(parent_hashes) + - shard_id.to_bytes(2, byteorder='big') + - shard_block_hash + - justified_slot.to_bytes(8, 'big') - ) +from eth.beacon.enums import SignatureDomain def verify_votes( message: bytes, - votes: Iterable[Tuple[int, bytes, int]]) -> Tuple[Tuple[bytes, ...], Tuple[int, ...]]: + votes: Iterable[Tuple[int, bytes, int]], + domain: SignatureDomain) -> Tuple[Tuple[bytes, ...], Tuple[int, ...]]: """ Verify the given votes. @@ -47,7 +27,7 @@ def verify_votes( (sig, committee_index) for (committee_index, sig, public_key) in votes - if bls.verify(message, public_key, sig) + if bls.verify(message, public_key, sig, domain) ) try: sigs, committee_indices = zip(*sigs_with_committe_info) @@ -75,4 +55,4 @@ def aggregate_votes(bitfield: bytes, ) ) - return bitfield, bls.aggregate_sigs(sigs) + return bitfield, bls.aggregate_signatures(sigs) diff --git a/eth/utils/bls.py b/eth/utils/bls.py index 0ea8008ef9..4b006d549f 100644 --- a/eth/utils/bls.py +++ b/eth/utils/bls.py @@ -1,12 +1,16 @@ from typing import ( # noqa: F401 Dict, - Iterable, + Sequence, Tuple, Union, ) +from eth_utils import ( + big_endian_to_int, + ValidationError, +) -from py_ecc.optimized_bn128 import ( # NOQA +from py_ecc.optimized_bls12_381 import ( # NOQA G1, G2, Z1, @@ -20,7 +24,7 @@ FQP, pairing, normalize, - field_modulus, + field_modulus as q, b, b2, is_on_curve, @@ -28,116 +32,192 @@ final_exponentiate ) from eth.beacon.utils.hash import hash_ -from eth.utils.bn128 import ( - FQP_point_to_FQ2_point, -) - -CACHE = {} # type: Dict[bytes, Tuple[FQ2, FQ2, FQ2]] -# 16th root of unity -HEX_ROOT = FQ2([21573744529824266246521972077326577680729363968861965890554801909984373949499, - 16854739155576650954933913186877292401521110422362946064090026408937773542853]) +G2_cofactor = 305502333931268344200999753193121504214466019254188142667664032982267604182971884026507427359259977847832272839041616661285803823378372096355777062779109 # noqa: E501 +FQ2_order = q ** 2 - 1 +eighth_roots_of_unity = [ + FQ2([1, 1]) ** ((FQ2_order * k) // 8) + for k in range(8) +] -assert HEX_ROOT ** 8 != FQ2([1, 0]) -assert HEX_ROOT ** 16 == FQ2([1, 0]) - -def compress_G1(pt: Tuple[FQ, FQ, FQ]) -> int: - x, y = normalize(pt) - return x.n + 2**255 * (y.n % 2) +# +# Helpers +# +def FQP_point_to_FQ2_point(pt: Tuple[FQP, FQP, FQP]) -> Tuple[FQ2, FQ2, FQ2]: + """ + Transform FQP to FQ2 for type hinting. + """ + return ( + FQ2(pt[0].coeffs), + FQ2(pt[1].coeffs), + FQ2(pt[2].coeffs), + ) -def decompress_G1(p: int) -> Tuple[FQ, FQ, FQ]: - if p == 0: - return (FQ(1), FQ(1), FQ(0)) - x = p % 2**255 - y_mod_2 = p // 2**255 - y = pow((x**3 + b.n) % field_modulus, (field_modulus + 1) // 4, field_modulus) - assert pow(y, 2, field_modulus) == (x**3 + b.n) % field_modulus - if y % 2 != y_mod_2: - y = field_modulus - y - return (FQ(x), FQ(y), FQ(1)) +def modular_squareroot(value: int) -> FQP: + """ + ``modular_squareroot(x)`` returns the value ``y`` such that ``y**2 % q == x``, + and None if this is not possible. In cases where there are two solutions, + the value with higher imaginary component is favored; + if both solutions have equal imaginary component the value with higher real + component is favored. + """ + candidate_squareroot = value ** ((FQ2_order + 8) // 16) + check = candidate_squareroot ** 2 / value + if check in eighth_roots_of_unity[::2]: + x1 = candidate_squareroot / eighth_roots_of_unity[eighth_roots_of_unity.index(check) // 2] + x2 = FQ2([-x1.coeffs[0], -x1.coeffs[1]]) # x2 = -x1 + return x1 if (x1.coeffs[1], x1.coeffs[0]) > (x2.coeffs[1], x2.coeffs[0]) else x2 + return None -def sqrt_fq2(x: FQP) -> FQ2: - y = x ** ((field_modulus ** 2 + 15) // 32) - while y**2 != x: - y *= HEX_ROOT - return FQ2(y.coeffs) +def hash_to_G2(message: bytes, domain: int) -> Tuple[FQ2, FQ2, FQ2]: + domain_in_bytes = domain.to_bytes(8, 'big') + # Initial candidate x coordinate + x_re = big_endian_to_int(hash_(domain_in_bytes + b'\x01' + message)) + x_im = big_endian_to_int(hash_(domain_in_bytes + b'\x02' + message)) + x_coordinate = FQ2([x_re, x_im]) # x_re + x_im * i -def hash_to_G2(m: bytes) -> Tuple[FQ2, FQ2, FQ2]: - """ - WARNING: this function has not been standardized yet. - """ - if m in CACHE: - return CACHE[m] - k2 = m + # Test candidate y coordinates until a one is found while 1: - k1 = hash_(k2) - k2 = hash_(k1) - x1 = int.from_bytes(k1, 'big') % field_modulus - x2 = int.from_bytes(k2, 'big') % field_modulus - x = FQ2([x1, x2]) - xcb = x**3 + b2 - if xcb ** ((field_modulus ** 2 - 1) // 2) == FQ2([1, 0]): + y_coordinate_squared = x_coordinate ** 3 + FQ2([4, 4]) # The curve is y^2 = x^3 + 4(i + 1) + y_coordinate = modular_squareroot(y_coordinate_squared) + if y_coordinate is not None: # Check if quadratic residue found break - y = sqrt_fq2(xcb) + x_coordinate += FQ2([1, 0]) # Add 1 and try again - o = FQP_point_to_FQ2_point(multiply((x, y, FQ2([1, 0])), 2 * field_modulus - curve_order)) - CACHE[m] = o - return o + return multiply( + (x_coordinate, y_coordinate, FQ2([1, 0])), + G2_cofactor + ) +# +# G1 +# +def compress_G1(pt: Tuple[FQ, FQ, FQ]) -> int: + x, y = normalize(pt) + return x.n + 2**383 * (y.n % 2) + + +def decompress_G1(pt: int) -> Tuple[FQ, FQ, FQ]: + if pt == 0: + return (FQ(1), FQ(1), FQ(0)) + x = pt % 2**383 + y_mod_2 = pt // 2**383 + y = pow((x**3 + b.n) % q, (q + 1) // 4, q) + + if pow(y, 2, q) != (x**3 + b.n) % q: + raise ValueError( + "he given point is not on G1: y**2 = x**3 + b" + ) + if y % 2 != y_mod_2: + y = q - y + return (FQ(x), FQ(y), FQ(1)) + + +# +# G2 +# def compress_G2(pt: Tuple[FQP, FQP, FQP]) -> Tuple[int, int]: - assert is_on_curve(pt, b2) + if not is_on_curve(pt, b2): + raise ValueError( + "The given point is not on the twisted curve over FQ**2" + ) x, y = normalize(pt) return ( - int(x.coeffs[0] + 2**255 * (y.coeffs[0] % 2)), + int(x.coeffs[0] + 2**383 * (y.coeffs[0] % 2)), int(x.coeffs[1]) ) def decompress_G2(p: bytes) -> Tuple[FQP, FQP, FQP]: - x1 = p[0] % 2**255 - y1_mod_2 = p[0] // 2**255 + x1 = p[0] % 2**383 + y1_mod_2 = p[0] // 2**383 x2 = p[1] x = FQ2([x1, x2]) if x == FQ2([0, 0]): return FQ2([1, 0]), FQ2([1, 0]), FQ2([0, 0]) - y = sqrt_fq2(x**3 + b2) + y = modular_squareroot(x**3 + b2) if y.coeffs[0] % 2 != y1_mod_2: y = FQ2((y * -1).coeffs) - assert is_on_curve((x, y, FQ2([1, 0])), b2) + if not is_on_curve((x, y, FQ2([1, 0])), b2): + raise ValueError( + "The given point is not on the twisted curve over FQ**2" + ) return x, y, FQ2([1, 0]) -def sign(m: bytes, k: int) -> Tuple[int, int]: - return compress_G2(multiply(hash_to_G2(m), k)) +# +# APIs +# +def sign(message: bytes, + privkey: int, + domain: int) -> Tuple[int, int]: + return compress_G2( + multiply( + hash_to_G2(message, domain), + privkey + ) + ) def privtopub(k: int) -> int: return compress_G1(multiply(G1, k)) -def verify(m: bytes, pub: int, sig: bytes) -> bool: +def verify(message: bytes, pubkey: int, signature: bytes, domain: int) -> bool: final_exponentiation = final_exponentiate( - pairing(FQP_point_to_FQ2_point(decompress_G2(sig)), G1, False) * - pairing(FQP_point_to_FQ2_point(hash_to_G2(m)), neg(decompress_G1(pub)), False) + pairing(FQP_point_to_FQ2_point(decompress_G2(signature)), G1, False) * + pairing( + FQP_point_to_FQ2_point(hash_to_G2(message, domain)), + neg(decompress_G1(pubkey)), + False + ) ) return final_exponentiation == FQ12.one() -def aggregate_sigs(sigs: Iterable[bytes]) -> Tuple[int, int]: +def aggregate_signatures(signatures: Sequence[bytes]) -> Tuple[int, int]: o = Z2 - for s in sigs: + for s in signatures: o = FQP_point_to_FQ2_point(add(o, decompress_G2(s))) return compress_G2(o) -def aggregate_pubs(pubs: Iterable[int]) -> int: +def aggregate_pubkeys(pubkeys: Sequence[int]) -> int: o = Z1 - for p in pubs: + for p in pubkeys: o = add(o, decompress_G1(p)) return compress_G1(o) + + +def verify_multiple(pubkeys: Sequence[int], + messages: Sequence[bytes], + signature: bytes, + domain: int) -> bool: + len_msgs = len(messages) + + if len(pubkeys) != len_msgs: + raise ValidationError( + "len(pubkeys) (%s) should be equal to len(messages) (%s)" % ( + len(pubkeys), len_msgs + ) + ) + + o = FQ12([1] + [0] * 11) + for m_pubs in set(messages): + # aggregate the pubs + group_pub = Z1 + for i in range(len_msgs): + if messages[i] == m_pubs: + group_pub = add(group_pub, decompress_G1(pubkeys[i])) + + o *= pairing(hash_to_G2(m_pubs, domain), group_pub, False) + o *= pairing(decompress_G2(signature), neg(G1), False) + + final_exponentiation = final_exponentiate(o) + return final_exponentiation == FQ12.one() diff --git a/tests/beacon/test_aggregation.py b/tests/beacon/test_aggregation.py index f0e05109d6..79c2ee3f84 100644 --- a/tests/beacon/test_aggregation.py +++ b/tests/beacon/test_aggregation.py @@ -31,18 +31,23 @@ def test_aggregate_votes(votes_count, random, privkeys, pubkeys): bit_count = 10 pre_bitfield = get_empty_bitfield(bit_count) pre_sigs = () + domain = 0 random_votes = random.sample(range(bit_count), votes_count) message = b'hello' # Get votes: (committee_index, sig, public_key) votes = [ - (committee_index, bls.sign(message, privkeys[committee_index]), pubkeys[committee_index]) + ( + committee_index, + bls.sign(message, privkeys[committee_index], domain), + pubkeys[committee_index], + ) for committee_index in random_votes ] # Verify - sigs, committee_indices = verify_votes(message, votes) + sigs, committee_indices = verify_votes(message, votes, domain) # Aggregate the votes bitfield, sigs = aggregate_votes( @@ -64,5 +69,5 @@ def test_aggregate_votes(votes_count, random, privkeys, pubkeys): ] assert len(voted_index) == len(votes) - aggregated_pubs = bls.aggregate_pubs(pubs) - assert bls.verify(message, aggregated_pubs, sigs) + aggregated_pubs = bls.aggregate_pubkeys(pubs) + assert bls.verify(message, aggregated_pubs, sigs, domain) diff --git a/tests/core/bls-utils/test_bls.py b/tests/core/bls-utils/test_bls.py index b6d76a6b6a..8ed98ede7d 100644 --- a/tests/core/bls-utils/test_bls.py +++ b/tests/core/bls-utils/test_bls.py @@ -1,6 +1,5 @@ import pytest -pytest.importorskip('eth.utils.bls') # noqa E402 from eth.utils.bls import ( G1, G2, @@ -13,17 +12,13 @@ multiply, sign, privtopub, - aggregate_sigs, - aggregate_pubs, - verify -) - -from tests.core.helpers import ( - greater_equal_python36, + aggregate_signatures, + aggregate_pubkeys, + verify, + verify_multiple, ) -@greater_equal_python36 @pytest.mark.parametrize( 'privkey', [ @@ -37,19 +32,20 @@ ] ) def test_bls_core(privkey): + domain = 0 p1 = multiply(G1, privkey) p2 = multiply(G2, privkey) msg = str(privkey).encode('utf-8') - msghash = hash_to_G2(msg) + msghash = hash_to_G2(msg, domain=domain) + assert normalize(decompress_G1(compress_G1(p1))) == normalize(p1) assert normalize(decompress_G2(compress_G2(p2))) == normalize(p2) assert normalize(decompress_G2(compress_G2(msghash))) == normalize(msghash) - sig = sign(msg, privkey) + sig = sign(msg, privkey, domain=domain) pub = privtopub(privkey) - assert verify(msg, pub, sig) + assert verify(msg, pub, sig, domain=domain) -@greater_equal_python36 @pytest.mark.parametrize( 'msg, privkeys', [ @@ -58,8 +54,42 @@ def test_bls_core(privkey): ] ) def test_signature_aggregation(msg, privkeys): - sigs = [sign(msg, k) for k in privkeys] + domain = 0 + sigs = [sign(msg, k, domain=domain) for k in privkeys] pubs = [privtopub(k) for k in privkeys] - aggsig = aggregate_sigs(sigs) - aggpub = aggregate_pubs(pubs) - assert verify(msg, aggpub, aggsig) + aggsig = aggregate_signatures(sigs) + aggpub = aggregate_pubkeys(pubs) + assert verify(msg, aggpub, aggsig, domain=domain) + + +@pytest.mark.parametrize( + 'msg_1, msg_2, privkeys_1, privkeys_2', + [ + (b'cow', b'wow', tuple(range(10)), tuple(range(10))), + (b'cow', b'wow', (0, 1, 2, 3), (4, 5, 6, 7)), + (b'cow', b'wow', (0, 1, 2, 3), (2, 3, 4, 5)), + ] +) +def test_multi_aggregation(msg_1, msg_2, privkeys_1, privkeys_2): + domain = 0 + + sigs_1 = [sign(msg_1, k, domain=domain) for k in privkeys_1] + pubs_1 = [privtopub(k) for k in privkeys_1] + aggsig_1 = aggregate_signatures(sigs_1) + aggpub_1 = aggregate_pubkeys(pubs_1) + + sigs_2 = [sign(msg_2, k, domain=domain) for k in privkeys_2] + pubs_2 = [privtopub(k) for k in privkeys_2] + aggsig_2 = aggregate_signatures(sigs_2) + aggpub_2 = aggregate_pubkeys(pubs_2) + + msgs = [msg_1, msg_2] + pubs = [aggpub_1, aggpub_2] + aggsig = aggregate_signatures([aggsig_1, aggsig_2]) + + assert verify_multiple( + pubkeys=pubs, + messages=msgs, + signature=aggsig, + domain=domain, + )