From c17926fbcfe88ec23fa76685587e145d2803da77 Mon Sep 17 00:00:00 2001 From: Hsiao-Wei Wang Date: Tue, 11 Dec 2018 00:42:27 +0800 Subject: [PATCH 1/6] Use BLS12-318 curve and add multi_verify --- eth/utils/bls.py | 163 ++++++++++++++++++++----------- tests/core/bls-utils/test_bls.py | 54 +++++++--- 2 files changed, 149 insertions(+), 68 deletions(-) diff --git a/eth/utils/bls.py b/eth/utils/bls.py index 0ea8008ef9..30f8fcd6dc 100644 --- a/eth/utils/bls.py +++ b/eth/utils/bls.py @@ -5,8 +5,11 @@ Union, ) +from eth_utils import ( + big_endian_to_int, +) -from py_ecc.optimized_bn128 import ( # NOQA +from py_ecc.optimized_bls12_381 import ( # NOQA G1, G2, Z1, @@ -20,7 +23,7 @@ FQP, pairing, normalize, - field_modulus, + field_modulus as q, b, b2, is_on_curve, @@ -28,103 +31,134 @@ final_exponentiate ) from eth.beacon.utils.hash import hash_ -from eth.utils.bn128 import ( - FQP_point_to_FQ2_point, -) -CACHE = {} # type: Dict[bytes, Tuple[FQ2, FQ2, FQ2]] -# 16th root of unity -HEX_ROOT = FQ2([21573744529824266246521972077326577680729363968861965890554801909984373949499, - 16854739155576650954933913186877292401521110422362946064090026408937773542853]) +G2_cofactor = 305502333931268344200999753193121504214466019254188142667664032982267604182971884026507427359259977847832272839041616661285803823378372096355777062779109 # noqa: E501 +qmod = q ** 2 - 1 +eighth_roots_of_unity = [ + FQ2([1, 1]) ** ((qmod * k) // 8) + for k in range(8) +] -assert HEX_ROOT ** 8 != FQ2([1, 0]) -assert HEX_ROOT ** 16 == FQ2([1, 0]) +# +# Helpers +# +def FQP_point_to_FQ2_point(pt: Tuple[FQP, FQP, FQP]) -> Tuple[FQ2, FQ2, FQ2]: + """ + Transform FQP to FQ2 for type hinting. + """ + return ( + FQ2(pt[0].coeffs), + FQ2(pt[1].coeffs), + FQ2(pt[2].coeffs), + ) +def modular_squareroot(value: int) -> int: + """ + ``modular_squareroot(x)`` returns the value ``y`` such that ``y**2 % q == x``, + and None if this is not possible. In cases where there are two solutions, + the value with higher imaginary component is favored; + if both solutions have equal imaginary component the value with higher real + component is favored. + """ + candidate_squareroot = value ** ((qmod + 8) // 16) + check = candidate_squareroot ** 2 / value + if check in eighth_roots_of_unity[::2]: + x1 = candidate_squareroot / eighth_roots_of_unity[eighth_roots_of_unity.index(check) // 2] + x2 = FQ2([-x1.coeffs[0], -x1.coeffs[1]]) + # x2 = - x2 + return x1 if (x1.coeffs[1], x1.coeffs[0]) > (x2.coeffs[1], x2.coeffs[0]) else x2 + return None + + +def hash_to_G2(message: bytes, domain: int) -> Tuple[FQ2, FQ2, FQ2]: + domain_in_bytes = domain.to_bytes(8, 'big') + x1 = big_endian_to_int(hash_(domain_in_bytes + b'\x01' + message)) + x2 = big_endian_to_int(hash_(domain_in_bytes + b'\x02' + message)) + x_coordinate = FQ2([x1, x2]) # x1 + x2 * i + while 1: + x_cubed_plus_b2 = x_coordinate ** 3 + FQ2([4, 4]) + y_coordinate = modular_squareroot(x_cubed_plus_b2) + if y_coordinate is not None: + break + x_coordinate += FQ2([1, 0]) # Add one until we get a quadratic residue + + return multiply( + (x_coordinate, y_coordinate, FQ2([1, 0])), + G2_cofactor + ) + + +# +# G1 +# def compress_G1(pt: Tuple[FQ, FQ, FQ]) -> int: x, y = normalize(pt) - return x.n + 2**255 * (y.n % 2) + return x.n + 2**383 * (y.n % 2) def decompress_G1(p: int) -> Tuple[FQ, FQ, FQ]: if p == 0: return (FQ(1), FQ(1), FQ(0)) - x = p % 2**255 - y_mod_2 = p // 2**255 - y = pow((x**3 + b.n) % field_modulus, (field_modulus + 1) // 4, field_modulus) - assert pow(y, 2, field_modulus) == (x**3 + b.n) % field_modulus + x = p % 2**383 + y_mod_2 = p // 2**383 + y = pow((x**3 + b.n) % q, (q + 1) // 4, q) + assert pow(y, 2, q) == (x**3 + b.n) % q if y % 2 != y_mod_2: - y = field_modulus - y + y = q - y return (FQ(x), FQ(y), FQ(1)) -def sqrt_fq2(x: FQP) -> FQ2: - y = x ** ((field_modulus ** 2 + 15) // 32) - while y**2 != x: - y *= HEX_ROOT - return FQ2(y.coeffs) - - -def hash_to_G2(m: bytes) -> Tuple[FQ2, FQ2, FQ2]: - """ - WARNING: this function has not been standardized yet. - """ - if m in CACHE: - return CACHE[m] - k2 = m - while 1: - k1 = hash_(k2) - k2 = hash_(k1) - x1 = int.from_bytes(k1, 'big') % field_modulus - x2 = int.from_bytes(k2, 'big') % field_modulus - x = FQ2([x1, x2]) - xcb = x**3 + b2 - if xcb ** ((field_modulus ** 2 - 1) // 2) == FQ2([1, 0]): - break - y = sqrt_fq2(xcb) - - o = FQP_point_to_FQ2_point(multiply((x, y, FQ2([1, 0])), 2 * field_modulus - curve_order)) - CACHE[m] = o - return o - - +# +# G2 +# def compress_G2(pt: Tuple[FQP, FQP, FQP]) -> Tuple[int, int]: assert is_on_curve(pt, b2) x, y = normalize(pt) return ( - int(x.coeffs[0] + 2**255 * (y.coeffs[0] % 2)), + int(x.coeffs[0] + 2**383 * (y.coeffs[0] % 2)), int(x.coeffs[1]) ) def decompress_G2(p: bytes) -> Tuple[FQP, FQP, FQP]: - x1 = p[0] % 2**255 - y1_mod_2 = p[0] // 2**255 + x1 = p[0] % 2**383 + y1_mod_2 = p[0] // 2**383 x2 = p[1] x = FQ2([x1, x2]) if x == FQ2([0, 0]): return FQ2([1, 0]), FQ2([1, 0]), FQ2([0, 0]) - y = sqrt_fq2(x**3 + b2) + y = modular_squareroot(x**3 + b2) if y.coeffs[0] % 2 != y1_mod_2: y = FQ2((y * -1).coeffs) assert is_on_curve((x, y, FQ2([1, 0])), b2) return x, y, FQ2([1, 0]) -def sign(m: bytes, k: int) -> Tuple[int, int]: - return compress_G2(multiply(hash_to_G2(m), k)) +# +# APIs +# +def sign(message: bytes, + privkey: int, + domain: int) -> Tuple[int, int]: + return compress_G2( + multiply( + hash_to_G2(message, domain), + privkey + ) + ) def privtopub(k: int) -> int: return compress_G1(multiply(G1, k)) -def verify(m: bytes, pub: int, sig: bytes) -> bool: +def verify(m: bytes, pub: int, sig: bytes, domain: int) -> bool: final_exponentiation = final_exponentiate( pairing(FQP_point_to_FQ2_point(decompress_G2(sig)), G1, False) * - pairing(FQP_point_to_FQ2_point(hash_to_G2(m)), neg(decompress_G1(pub)), False) + pairing(FQP_point_to_FQ2_point(hash_to_G2(m, domain)), neg(decompress_G1(pub)), False) ) return final_exponentiation == FQ12.one() @@ -141,3 +175,22 @@ def aggregate_pubs(pubs: Iterable[int]) -> int: for p in pubs: o = add(o, decompress_G1(p)) return compress_G1(o) + + +def multi_verify(pubs, msgs, sig, domain): + len_msgs = len(msgs) + assert len(pubs) == len_msgs + + o = FQ12([1] + [0] * 11) + for m in set(msgs): + # aggregate the pubs + group_pub = Z1 + for i in range(len_msgs): + if msgs[i] == m: + group_pub = add(group_pub, decompress_G1(pubs[i])) + + o *= pairing(hash_to_G2(m, domain), group_pub, False) + o *= pairing(decompress_G2(sig), neg(G1), False) + + final_exponentiation = final_exponentiate(o) + return final_exponentiation == FQ12.one() diff --git a/tests/core/bls-utils/test_bls.py b/tests/core/bls-utils/test_bls.py index b6d76a6b6a..81abc04bbc 100644 --- a/tests/core/bls-utils/test_bls.py +++ b/tests/core/bls-utils/test_bls.py @@ -1,6 +1,5 @@ import pytest -pytest.importorskip('eth.utils.bls') # noqa E402 from eth.utils.bls import ( G1, G2, @@ -15,15 +14,11 @@ privtopub, aggregate_sigs, aggregate_pubs, - verify -) - -from tests.core.helpers import ( - greater_equal_python36, + verify, + multi_verify, ) -@greater_equal_python36 @pytest.mark.parametrize( 'privkey', [ @@ -37,19 +32,20 @@ ] ) def test_bls_core(privkey): + domain = 0 p1 = multiply(G1, privkey) p2 = multiply(G2, privkey) msg = str(privkey).encode('utf-8') - msghash = hash_to_G2(msg) + msghash = hash_to_G2(msg, domain=domain) + assert normalize(decompress_G1(compress_G1(p1))) == normalize(p1) assert normalize(decompress_G2(compress_G2(p2))) == normalize(p2) assert normalize(decompress_G2(compress_G2(msghash))) == normalize(msghash) - sig = sign(msg, privkey) + sig = sign(msg, privkey, domain=domain) pub = privtopub(privkey) - assert verify(msg, pub, sig) + assert verify(msg, pub, sig, domain=domain) -@greater_equal_python36 @pytest.mark.parametrize( 'msg, privkeys', [ @@ -58,8 +54,40 @@ def test_bls_core(privkey): ] ) def test_signature_aggregation(msg, privkeys): - sigs = [sign(msg, k) for k in privkeys] + domain = 0 + sigs = [sign(msg, k, domain=domain) for k in privkeys] pubs = [privtopub(k) for k in privkeys] aggsig = aggregate_sigs(sigs) aggpub = aggregate_pubs(pubs) - assert verify(msg, aggpub, aggsig) + assert verify(msg, aggpub, aggsig, domain=domain) + + +@pytest.mark.parametrize( + 'msg_1, msg_2, privkeys', + [ + (b'cow', b'wow', [1, 5, 124, 735, 127409812145, 90768492698215092512159, 0]), + ] +) +def test_multi_aggregation(msg_1, msg_2, privkeys): + domain = 0 + sigs_1 = [sign(msg_1, k, domain=domain) for k in privkeys] + + pubs_1 = [privtopub(k) for k in privkeys] + aggsig_1 = aggregate_sigs(sigs_1) + aggpub_1 = aggregate_pubs(pubs_1) + + sigs_2 = [sign(msg_2, k, domain=domain) for k in privkeys] + pubs_2 = [privtopub(k) for k in privkeys] + aggsig_2 = aggregate_sigs(sigs_2) + aggpub_2 = aggregate_pubs(pubs_2) + + msgs = [msg_1, msg_2] + pubs = [aggpub_1, aggpub_2] + aggsig = aggregate_sigs([aggsig_1, aggsig_2]) + + assert multi_verify( + pubs=pubs, + msgs=msgs, + sig=aggsig, + domain=domain, + ) From 19ebc7bdce1ee22e5dedd1625bccc64e66ecbe01 Mon Sep 17 00:00:00 2001 From: Hsiao-Wei Wang Date: Wed, 12 Dec 2018 11:58:04 +0800 Subject: [PATCH 2/6] Add `domain` to beacon chain aggregation APIs --- eth/beacon/aggregation.py | 7 +++++-- eth/utils/bls.py | 13 ++++++++----- tests/beacon/test_aggregation.py | 11 ++++++++--- tests/core/bls-utils/test_bls.py | 4 ++-- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/eth/beacon/aggregation.py b/eth/beacon/aggregation.py index 0def06cb8a..c5f83f31ef 100644 --- a/eth/beacon/aggregation.py +++ b/eth/beacon/aggregation.py @@ -16,6 +16,8 @@ ) from eth.beacon.utils.hash import hash_ +from eth.beacon.enums.signature_domain import SignatureDomain + def create_signing_message(slot: int, parent_hashes: Iterable[Hash32], @@ -37,7 +39,8 @@ def create_signing_message(slot: int, def verify_votes( message: bytes, - votes: Iterable[Tuple[int, bytes, int]]) -> Tuple[Tuple[bytes, ...], Tuple[int, ...]]: + votes: Iterable[Tuple[int, bytes, int]], + domain: SignatureDomain) -> Tuple[Tuple[bytes, ...], Tuple[int, ...]]: """ Verify the given votes. @@ -47,7 +50,7 @@ def verify_votes( (sig, committee_index) for (committee_index, sig, public_key) in votes - if bls.verify(message, public_key, sig) + if bls.verify(message, public_key, sig, domain) ) try: sigs, committee_indices = zip(*sigs_with_committe_info) diff --git a/eth/utils/bls.py b/eth/utils/bls.py index 30f8fcd6dc..d2970fbab6 100644 --- a/eth/utils/bls.py +++ b/eth/utils/bls.py @@ -1,6 +1,6 @@ from typing import ( # noqa: F401 Dict, - Iterable, + Sequence, Tuple, Union, ) @@ -55,7 +55,7 @@ def FQP_point_to_FQ2_point(pt: Tuple[FQP, FQP, FQP]) -> Tuple[FQ2, FQ2, FQ2]: ) -def modular_squareroot(value: int) -> int: +def modular_squareroot(value: int) -> FQP: """ ``modular_squareroot(x)`` returns the value ``y`` such that ``y**2 % q == x``, and None if this is not possible. In cases where there are two solutions, @@ -163,21 +163,24 @@ def verify(m: bytes, pub: int, sig: bytes, domain: int) -> bool: return final_exponentiation == FQ12.one() -def aggregate_sigs(sigs: Iterable[bytes]) -> Tuple[int, int]: +def aggregate_sigs(sigs: Sequence[bytes]) -> Tuple[int, int]: o = Z2 for s in sigs: o = FQP_point_to_FQ2_point(add(o, decompress_G2(s))) return compress_G2(o) -def aggregate_pubs(pubs: Iterable[int]) -> int: +def aggregate_pubs(pubs: Sequence[int]) -> int: o = Z1 for p in pubs: o = add(o, decompress_G1(p)) return compress_G1(o) -def multi_verify(pubs, msgs, sig, domain): +def verify_multiple(pubs: Sequence[int], + msgs: Sequence[bytes], + sig: bytes, + domain: int) -> bool: len_msgs = len(msgs) assert len(pubs) == len_msgs diff --git a/tests/beacon/test_aggregation.py b/tests/beacon/test_aggregation.py index f0e05109d6..8033eaab7d 100644 --- a/tests/beacon/test_aggregation.py +++ b/tests/beacon/test_aggregation.py @@ -31,18 +31,23 @@ def test_aggregate_votes(votes_count, random, privkeys, pubkeys): bit_count = 10 pre_bitfield = get_empty_bitfield(bit_count) pre_sigs = () + domain = 0 random_votes = random.sample(range(bit_count), votes_count) message = b'hello' # Get votes: (committee_index, sig, public_key) votes = [ - (committee_index, bls.sign(message, privkeys[committee_index]), pubkeys[committee_index]) + ( + committee_index, + bls.sign(message, privkeys[committee_index], domain), + pubkeys[committee_index], + ) for committee_index in random_votes ] # Verify - sigs, committee_indices = verify_votes(message, votes) + sigs, committee_indices = verify_votes(message, votes, domain) # Aggregate the votes bitfield, sigs = aggregate_votes( @@ -65,4 +70,4 @@ def test_aggregate_votes(votes_count, random, privkeys, pubkeys): assert len(voted_index) == len(votes) aggregated_pubs = bls.aggregate_pubs(pubs) - assert bls.verify(message, aggregated_pubs, sigs) + assert bls.verify(message, aggregated_pubs, sigs, domain) diff --git a/tests/core/bls-utils/test_bls.py b/tests/core/bls-utils/test_bls.py index 81abc04bbc..2d488fc67d 100644 --- a/tests/core/bls-utils/test_bls.py +++ b/tests/core/bls-utils/test_bls.py @@ -15,7 +15,7 @@ aggregate_sigs, aggregate_pubs, verify, - multi_verify, + verify_multiple, ) @@ -85,7 +85,7 @@ def test_multi_aggregation(msg_1, msg_2, privkeys): pubs = [aggpub_1, aggpub_2] aggsig = aggregate_sigs([aggsig_1, aggsig_2]) - assert multi_verify( + assert verify_multiple( pubs=pubs, msgs=msgs, sig=aggsig, From c0a322666c45a1d7044ffac34db9e200314cbf43 Mon Sep 17 00:00:00 2001 From: Hsiao-Wei Wang Date: Wed, 12 Dec 2018 13:31:53 +0800 Subject: [PATCH 3/6] Remove `create_signing_message` --- eth/beacon/aggregation.py | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/eth/beacon/aggregation.py b/eth/beacon/aggregation.py index c5f83f31ef..33eb4ddad6 100644 --- a/eth/beacon/aggregation.py +++ b/eth/beacon/aggregation.py @@ -6,35 +6,12 @@ pipe ) -from eth_typing import ( - Hash32, -) - from eth.utils import bls from eth.utils.bitfield import ( set_voted, ) -from eth.beacon.utils.hash import hash_ - -from eth.beacon.enums.signature_domain import SignatureDomain - -def create_signing_message(slot: int, - parent_hashes: Iterable[Hash32], - shard_id: int, - shard_block_hash: Hash32, - justified_slot: int) -> bytes: - """ - Return the signining message for attesting. - """ - # TODO: Will be updated with SSZ encoded attestation. - return hash_( - slot.to_bytes(8, byteorder='big') + - b''.join(parent_hashes) + - shard_id.to_bytes(2, byteorder='big') + - shard_block_hash + - justified_slot.to_bytes(8, 'big') - ) +from eth.beacon.enums import SignatureDomain def verify_votes( From 30be4954f483a0d0b1e443baa91d7f897d5e9149 Mon Sep 17 00:00:00 2001 From: Hsiao-Wei Wang Date: Thu, 13 Dec 2018 15:14:16 +0800 Subject: [PATCH 4/6] PR feedback and update parameters name --- eth/beacon/aggregation.py | 2 +- eth/utils/bls.py | 65 ++++++++++++++++++-------------- tests/core/bls-utils/test_bls.py | 40 ++++++++++---------- 3 files changed, 58 insertions(+), 49 deletions(-) diff --git a/eth/beacon/aggregation.py b/eth/beacon/aggregation.py index 33eb4ddad6..9f9f6bae65 100644 --- a/eth/beacon/aggregation.py +++ b/eth/beacon/aggregation.py @@ -55,4 +55,4 @@ def aggregate_votes(bitfield: bytes, ) ) - return bitfield, bls.aggregate_sigs(sigs) + return bitfield, bls.aggregate_signatures(sigs) diff --git a/eth/utils/bls.py b/eth/utils/bls.py index d2970fbab6..20b5222815 100644 --- a/eth/utils/bls.py +++ b/eth/utils/bls.py @@ -34,9 +34,9 @@ G2_cofactor = 305502333931268344200999753193121504214466019254188142667664032982267604182971884026507427359259977847832272839041616661285803823378372096355777062779109 # noqa: E501 -qmod = q ** 2 - 1 +FQ2_order = q ** 2 - 1 eighth_roots_of_unity = [ - FQ2([1, 1]) ** ((qmod * k) // 8) + FQ2([1, 1]) ** ((FQ2_order * k) // 8) for k in range(8) ] @@ -63,27 +63,30 @@ def modular_squareroot(value: int) -> FQP: if both solutions have equal imaginary component the value with higher real component is favored. """ - candidate_squareroot = value ** ((qmod + 8) // 16) + candidate_squareroot = value ** ((FQ2_order + 8) // 16) check = candidate_squareroot ** 2 / value if check in eighth_roots_of_unity[::2]: x1 = candidate_squareroot / eighth_roots_of_unity[eighth_roots_of_unity.index(check) // 2] - x2 = FQ2([-x1.coeffs[0], -x1.coeffs[1]]) - # x2 = - x2 + x2 = FQ2([-x1.coeffs[0], -x1.coeffs[1]]) # x2 = -x1 return x1 if (x1.coeffs[1], x1.coeffs[0]) > (x2.coeffs[1], x2.coeffs[0]) else x2 return None def hash_to_G2(message: bytes, domain: int) -> Tuple[FQ2, FQ2, FQ2]: domain_in_bytes = domain.to_bytes(8, 'big') - x1 = big_endian_to_int(hash_(domain_in_bytes + b'\x01' + message)) - x2 = big_endian_to_int(hash_(domain_in_bytes + b'\x02' + message)) - x_coordinate = FQ2([x1, x2]) # x1 + x2 * i + + # Initial candidate x coordinate + x_re = big_endian_to_int(hash_(domain_in_bytes + b'\x01' + message)) + x_im = big_endian_to_int(hash_(domain_in_bytes + b'\x02' + message)) + x_coordinate = FQ2([x_re, x_im]) # x_re + x_im * i + + # Test candidate y coordinates until a one is found while 1: - x_cubed_plus_b2 = x_coordinate ** 3 + FQ2([4, 4]) - y_coordinate = modular_squareroot(x_cubed_plus_b2) - if y_coordinate is not None: + y_coordinate_squared = x_coordinate ** 3 + FQ2([4, 4]) # The curve is y^2 = x^3 + 4(i + 1) + y_coordinate = modular_squareroot(y_coordinate_squared) + if y_coordinate is not None: # Check if quadratic residue found break - x_coordinate += FQ2([1, 0]) # Add one until we get a quadratic residue + x_coordinate += FQ2([1, 0]) # Add 1 and try again return multiply( (x_coordinate, y_coordinate, FQ2([1, 0])), @@ -155,45 +158,49 @@ def privtopub(k: int) -> int: return compress_G1(multiply(G1, k)) -def verify(m: bytes, pub: int, sig: bytes, domain: int) -> bool: +def verify(message: bytes, pubkey: int, signature: bytes, domain: int) -> bool: final_exponentiation = final_exponentiate( - pairing(FQP_point_to_FQ2_point(decompress_G2(sig)), G1, False) * - pairing(FQP_point_to_FQ2_point(hash_to_G2(m, domain)), neg(decompress_G1(pub)), False) + pairing(FQP_point_to_FQ2_point(decompress_G2(signature)), G1, False) * + pairing( + FQP_point_to_FQ2_point(hash_to_G2(message, domain)), + neg(decompress_G1(pubkey)), + False + ) ) return final_exponentiation == FQ12.one() -def aggregate_sigs(sigs: Sequence[bytes]) -> Tuple[int, int]: +def aggregate_signatures(signatures: Sequence[bytes]) -> Tuple[int, int]: o = Z2 - for s in sigs: + for s in signatures: o = FQP_point_to_FQ2_point(add(o, decompress_G2(s))) return compress_G2(o) -def aggregate_pubs(pubs: Sequence[int]) -> int: +def aggregate_pubkeys(pubkeys: Sequence[int]) -> int: o = Z1 - for p in pubs: + for p in pubkeys: o = add(o, decompress_G1(p)) return compress_G1(o) -def verify_multiple(pubs: Sequence[int], - msgs: Sequence[bytes], - sig: bytes, +def verify_multiple(pubkeys: Sequence[int], + messages: Sequence[bytes], + signature: bytes, domain: int) -> bool: - len_msgs = len(msgs) - assert len(pubs) == len_msgs + len_msgs = len(messages) + assert len(pubkeys) == len_msgs o = FQ12([1] + [0] * 11) - for m in set(msgs): + for m_pubs in set(messages): # aggregate the pubs group_pub = Z1 for i in range(len_msgs): - if msgs[i] == m: - group_pub = add(group_pub, decompress_G1(pubs[i])) + if messages[i] == m_pubs: + group_pub = add(group_pub, decompress_G1(pubkeys[i])) - o *= pairing(hash_to_G2(m, domain), group_pub, False) - o *= pairing(decompress_G2(sig), neg(G1), False) + o *= pairing(hash_to_G2(m_pubs, domain), group_pub, False) + o *= pairing(decompress_G2(signature), neg(G1), False) final_exponentiation = final_exponentiate(o) return final_exponentiation == FQ12.one() diff --git a/tests/core/bls-utils/test_bls.py b/tests/core/bls-utils/test_bls.py index 2d488fc67d..8ed98ede7d 100644 --- a/tests/core/bls-utils/test_bls.py +++ b/tests/core/bls-utils/test_bls.py @@ -12,8 +12,8 @@ multiply, sign, privtopub, - aggregate_sigs, - aggregate_pubs, + aggregate_signatures, + aggregate_pubkeys, verify, verify_multiple, ) @@ -57,37 +57,39 @@ def test_signature_aggregation(msg, privkeys): domain = 0 sigs = [sign(msg, k, domain=domain) for k in privkeys] pubs = [privtopub(k) for k in privkeys] - aggsig = aggregate_sigs(sigs) - aggpub = aggregate_pubs(pubs) + aggsig = aggregate_signatures(sigs) + aggpub = aggregate_pubkeys(pubs) assert verify(msg, aggpub, aggsig, domain=domain) @pytest.mark.parametrize( - 'msg_1, msg_2, privkeys', + 'msg_1, msg_2, privkeys_1, privkeys_2', [ - (b'cow', b'wow', [1, 5, 124, 735, 127409812145, 90768492698215092512159, 0]), + (b'cow', b'wow', tuple(range(10)), tuple(range(10))), + (b'cow', b'wow', (0, 1, 2, 3), (4, 5, 6, 7)), + (b'cow', b'wow', (0, 1, 2, 3), (2, 3, 4, 5)), ] ) -def test_multi_aggregation(msg_1, msg_2, privkeys): +def test_multi_aggregation(msg_1, msg_2, privkeys_1, privkeys_2): domain = 0 - sigs_1 = [sign(msg_1, k, domain=domain) for k in privkeys] - pubs_1 = [privtopub(k) for k in privkeys] - aggsig_1 = aggregate_sigs(sigs_1) - aggpub_1 = aggregate_pubs(pubs_1) + sigs_1 = [sign(msg_1, k, domain=domain) for k in privkeys_1] + pubs_1 = [privtopub(k) for k in privkeys_1] + aggsig_1 = aggregate_signatures(sigs_1) + aggpub_1 = aggregate_pubkeys(pubs_1) - sigs_2 = [sign(msg_2, k, domain=domain) for k in privkeys] - pubs_2 = [privtopub(k) for k in privkeys] - aggsig_2 = aggregate_sigs(sigs_2) - aggpub_2 = aggregate_pubs(pubs_2) + sigs_2 = [sign(msg_2, k, domain=domain) for k in privkeys_2] + pubs_2 = [privtopub(k) for k in privkeys_2] + aggsig_2 = aggregate_signatures(sigs_2) + aggpub_2 = aggregate_pubkeys(pubs_2) msgs = [msg_1, msg_2] pubs = [aggpub_1, aggpub_2] - aggsig = aggregate_sigs([aggsig_1, aggsig_2]) + aggsig = aggregate_signatures([aggsig_1, aggsig_2]) assert verify_multiple( - pubs=pubs, - msgs=msgs, - sig=aggsig, + pubkeys=pubs, + messages=msgs, + signature=aggsig, domain=domain, ) From 118f20f5fd6b0eafb920cc644a1affc81fca9c94 Mon Sep 17 00:00:00 2001 From: Hsiao-Wei Wang Date: Thu, 13 Dec 2018 15:31:49 +0800 Subject: [PATCH 5/6] fix --- tests/beacon/test_aggregation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/beacon/test_aggregation.py b/tests/beacon/test_aggregation.py index 8033eaab7d..79c2ee3f84 100644 --- a/tests/beacon/test_aggregation.py +++ b/tests/beacon/test_aggregation.py @@ -69,5 +69,5 @@ def test_aggregate_votes(votes_count, random, privkeys, pubkeys): ] assert len(voted_index) == len(votes) - aggregated_pubs = bls.aggregate_pubs(pubs) + aggregated_pubs = bls.aggregate_pubkeys(pubs) assert bls.verify(message, aggregated_pubs, sigs, domain) From 9df72312f36ccecc6a196631f00a6d0b53e8ace8 Mon Sep 17 00:00:00 2001 From: Hsiao-Wei Wang Date: Thu, 13 Dec 2018 15:40:42 +0800 Subject: [PATCH 6/6] Raise exception in BLS --- eth/utils/bls.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/eth/utils/bls.py b/eth/utils/bls.py index 20b5222815..4b006d549f 100644 --- a/eth/utils/bls.py +++ b/eth/utils/bls.py @@ -7,6 +7,7 @@ from eth_utils import ( big_endian_to_int, + ValidationError, ) from py_ecc.optimized_bls12_381 import ( # NOQA @@ -102,13 +103,17 @@ def compress_G1(pt: Tuple[FQ, FQ, FQ]) -> int: return x.n + 2**383 * (y.n % 2) -def decompress_G1(p: int) -> Tuple[FQ, FQ, FQ]: - if p == 0: +def decompress_G1(pt: int) -> Tuple[FQ, FQ, FQ]: + if pt == 0: return (FQ(1), FQ(1), FQ(0)) - x = p % 2**383 - y_mod_2 = p // 2**383 + x = pt % 2**383 + y_mod_2 = pt // 2**383 y = pow((x**3 + b.n) % q, (q + 1) // 4, q) - assert pow(y, 2, q) == (x**3 + b.n) % q + + if pow(y, 2, q) != (x**3 + b.n) % q: + raise ValueError( + "he given point is not on G1: y**2 = x**3 + b" + ) if y % 2 != y_mod_2: y = q - y return (FQ(x), FQ(y), FQ(1)) @@ -118,7 +123,10 @@ def decompress_G1(p: int) -> Tuple[FQ, FQ, FQ]: # G2 # def compress_G2(pt: Tuple[FQP, FQP, FQP]) -> Tuple[int, int]: - assert is_on_curve(pt, b2) + if not is_on_curve(pt, b2): + raise ValueError( + "The given point is not on the twisted curve over FQ**2" + ) x, y = normalize(pt) return ( int(x.coeffs[0] + 2**383 * (y.coeffs[0] % 2)), @@ -136,7 +144,10 @@ def decompress_G2(p: bytes) -> Tuple[FQP, FQP, FQP]: y = modular_squareroot(x**3 + b2) if y.coeffs[0] % 2 != y1_mod_2: y = FQ2((y * -1).coeffs) - assert is_on_curve((x, y, FQ2([1, 0])), b2) + if not is_on_curve((x, y, FQ2([1, 0])), b2): + raise ValueError( + "The given point is not on the twisted curve over FQ**2" + ) return x, y, FQ2([1, 0]) @@ -189,7 +200,13 @@ def verify_multiple(pubkeys: Sequence[int], signature: bytes, domain: int) -> bool: len_msgs = len(messages) - assert len(pubkeys) == len_msgs + + if len(pubkeys) != len_msgs: + raise ValidationError( + "len(pubkeys) (%s) should be equal to len(messages) (%s)" % ( + len(pubkeys), len_msgs + ) + ) o = FQ12([1] + [0] * 11) for m_pubs in set(messages):