Skip to content

Update BLS signature module #1581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 5 additions & 25 deletions eth/beacon/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,18 @@
pipe
)

from eth_typing import (
Hash32,
)

from eth.utils import bls
from eth.utils.bitfield import (
set_voted,
)
from eth.beacon.utils.hash import hash_


def create_signing_message(slot: int,
parent_hashes: Iterable[Hash32],
shard_id: int,
shard_block_hash: Hash32,
justified_slot: int) -> bytes:
"""
Return the signing message for attesting.
"""
# TODO: Will be updated with SSZ encoded attestation.
return hash_(
slot.to_bytes(8, byteorder='big') +
b''.join(parent_hashes) +
shard_id.to_bytes(2, byteorder='big') +
shard_block_hash +
justified_slot.to_bytes(8, 'big')
)
from eth.beacon.enums import SignatureDomain


def verify_votes(
message: bytes,
votes: Iterable[Tuple[int, bytes, int]]) -> Tuple[Tuple[bytes, ...], Tuple[int, ...]]:
votes: Iterable[Tuple[int, bytes, int]],
domain: SignatureDomain) -> Tuple[Tuple[bytes, ...], Tuple[int, ...]]:
"""
Verify the given votes.

Expand All @@ -47,7 +27,7 @@ def verify_votes(
(sig, committee_index)
for (committee_index, sig, public_key)
in votes
if bls.verify(message, public_key, sig)
if bls.verify(message, public_key, sig, domain)
)
try:
sigs, committee_indices = zip(*sigs_with_committe_info)
Expand Down Expand Up @@ -75,4 +55,4 @@ def aggregate_votes(bitfield: bytes,
)
)

return bitfield, bls.aggregate_sigs(sigs)
return bitfield, bls.aggregate_signatures(sigs)
210 changes: 145 additions & 65 deletions eth/utils/bls.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from typing import ( # noqa: F401
Dict,
Iterable,
Sequence,
Tuple,
Union,
)

from eth_utils import (
big_endian_to_int,
ValidationError,
)

from py_ecc.optimized_bn128 import ( # NOQA
from py_ecc.optimized_bls12_381 import ( # NOQA
G1,
G2,
Z1,
Expand All @@ -20,124 +24,200 @@
FQP,
pairing,
normalize,
field_modulus,
field_modulus as q,
b,
b2,
is_on_curve,
curve_order,
final_exponentiate
)
from eth.beacon.utils.hash import hash_
from eth.utils.bn128 import (
FQP_point_to_FQ2_point,
)


# Cofactor of the G2 subgroup: multiplying an arbitrary curve point by this
# clears the non-prime-order component, landing the result in G2 proper.
G2_cofactor = 305502333931268344200999753193121504214466019254188142667664032982267604182971884026507427359259977847832272839041616661285803823378372096355777062779109  # noqa: E501

# Order of the multiplicative group of FQ2 (q**2 - 1).
FQ2_order = q ** 2 - 1

# The eight 8th roots of unity in FQ2, used by ``modular_squareroot`` to
# classify which coset a candidate root falls in.
eighth_roots_of_unity = [
    FQ2([1, 1]) ** ((FQ2_order * k) // 8)
    for k in range(8)
]


#
# Helpers
#
def FQP_point_to_FQ2_point(pt: Tuple[FQP, FQP, FQP]) -> Tuple[FQ2, FQ2, FQ2]:
    """
    Transform FQP to FQ2 for type hinting.

    Rebuilds each Jacobian coordinate of ``pt`` as an ``FQ2`` from its
    coefficient list; the numeric value of the point is unchanged.
    """
    return (
        FQ2(pt[0].coeffs),
        FQ2(pt[1].coeffs),
        FQ2(pt[2].coeffs),
    )


def modular_squareroot(value: FQP) -> FQP:
    """
    ``modular_squareroot(x)`` returns the value ``y`` such that ``y**2 % q == x``,
    and None if this is not possible. In cases where there are two solutions,
    the value with higher imaginary component is favored;
    if both solutions have equal imaginary component the value with higher real
    component is favored.

    NOTE: ``value`` is an FQ2 element (callers pass ``x**3 + b2``), not a
    plain int.
    """
    candidate_squareroot = value ** ((FQ2_order + 8) // 16)
    check = candidate_squareroot ** 2 / value
    # A square root exists only when the residue ``check`` is one of the
    # *even* 8th roots of unity (i.e. a 4th root of unity).
    if check in eighth_roots_of_unity[::2]:
        x1 = candidate_squareroot / eighth_roots_of_unity[eighth_roots_of_unity.index(check) // 2]
        x2 = FQ2([-x1.coeffs[0], -x1.coeffs[1]])  # x2 = -x1
        # Tie-break: prefer the root with the larger (imaginary, real) pair.
        return x1 if (x1.coeffs[1], x1.coeffs[0]) > (x2.coeffs[1], x2.coeffs[0]) else x2
    return None


def hash_to_G2(message: bytes, domain: int) -> Tuple[FQ2, FQ2, FQ2]:
    """
    Hash ``message`` (namespaced by ``domain``) to a point in the G2 subgroup
    using try-and-increment: derive a candidate x coordinate from the hash,
    bump it by one until ``x**3 + 4(i + 1)`` is a quadratic residue, then
    multiply by the cofactor to land in G2.

    WARNING: this hash-to-curve construction has not been standardized.
    """
    domain_in_bytes = domain.to_bytes(8, 'big')

    # Initial candidate x coordinate
    x_re = big_endian_to_int(hash_(domain_in_bytes + b'\x01' + message))
    x_im = big_endian_to_int(hash_(domain_in_bytes + b'\x02' + message))
    x_coordinate = FQ2([x_re, x_im])  # x_re + x_im * i

    # Test candidate y coordinates until one is found
    while 1:
        y_coordinate_squared = x_coordinate ** 3 + FQ2([4, 4])  # The curve is y^2 = x^3 + 4(i + 1)
        y_coordinate = modular_squareroot(y_coordinate_squared)
        if y_coordinate is not None:  # Check if quadratic residue found
            break
        x_coordinate += FQ2([1, 0])  # Add 1 and try again

    return multiply(
        (x_coordinate, y_coordinate, FQ2([1, 0])),
        G2_cofactor
    )


#
# G1
#
def compress_G1(pt: Tuple[FQ, FQ, FQ]) -> int:
    """
    Compress a G1 point into a single int: the affine x coordinate in the
    low 383 bits, with the parity of y stored in bit 383.
    """
    x_affine, y_affine = normalize(pt)
    y_parity = y_affine.n % 2
    return x_affine.n + y_parity * 2**383


def decompress_G1(pt: int) -> Tuple[FQ, FQ, FQ]:
    """
    Decompress a G1 point: the low 383 bits of ``pt`` hold the x coordinate
    and bit 383 holds the parity of y.  ``0`` encodes the point at infinity.

    Raises ValueError if the recovered coordinates are not on the curve.
    """
    if pt == 0:
        # Convention: 0 encodes the point at infinity (z == 0).
        return (FQ(1), FQ(1), FQ(0))
    x = pt % 2**383
    y_mod_2 = pt // 2**383
    # Candidate square root via the (q + 1) // 4 exponent
    # (assumes q % 4 == 3 for the base field modulus).
    y = pow((x**3 + b.n) % q, (q + 1) // 4, q)

    if pow(y, 2, q) != (x**3 + b.n) % q:
        raise ValueError(
            "The given point is not on G1: y**2 = x**3 + b"
        )
    # Pick the root whose parity matches the stored bit.
    if y % 2 != y_mod_2:
        y = q - y
    return (FQ(x), FQ(y), FQ(1))


#
# G2
#
def compress_G2(pt: Tuple[FQP, FQP, FQP]) -> Tuple[int, int]:
    """
    Compress a G2 point into two ints: the real part of x with the parity
    of y's real part stored in bit 383, and the imaginary part of x.

    Raises ValueError if ``pt`` is not on the twisted curve.
    """
    if not is_on_curve(pt, b2):
        raise ValueError(
            "The given point is not on the twisted curve over FQ**2"
        )
    x, y = normalize(pt)
    return (
        int(x.coeffs[0] + 2**383 * (y.coeffs[0] % 2)),
        int(x.coeffs[1])
    )


def decompress_G2(p: bytes) -> Tuple[FQP, FQP, FQP]:
    """
    Decompress a G2 point from the two-int form produced by ``compress_G2``:
    ``p[0]`` carries x's real part plus the parity of y's real part in bit
    383; ``p[1]`` carries x's imaginary part.  ``x == 0`` encodes infinity.

    Raises ValueError if the recovered point is not on the twisted curve.
    """
    x1 = p[0] % 2**383
    y1_mod_2 = p[0] // 2**383
    x2 = p[1]
    x = FQ2([x1, x2])
    if x == FQ2([0, 0]):
        # Point at infinity (z == 0).
        return FQ2([1, 0]), FQ2([1, 0]), FQ2([0, 0])
    y = modular_squareroot(x**3 + b2)
    # Pick the root whose real-part parity matches the stored bit.
    if y.coeffs[0] % 2 != y1_mod_2:
        y = FQ2((y * -1).coeffs)
    if not is_on_curve((x, y, FQ2([1, 0])), b2):
        raise ValueError(
            "The given point is not on the twisted curve over FQ**2"
        )
    return x, y, FQ2([1, 0])


#
# APIs
#
def sign(message: bytes,
         privkey: int,
         domain: int) -> Tuple[int, int]:
    """
    Sign ``message`` under ``domain``: hash the message to a G2 point,
    multiply by ``privkey``, and return the compressed signature.
    """
    return compress_G2(
        multiply(
            hash_to_G2(message, domain),
            privkey
        )
    )


def privtopub(k: int) -> int:
    """Derive the compressed G1 public key for private key ``k``."""
    pubkey_point = multiply(G1, k)
    return compress_G1(pubkey_point)


def verify(message: bytes, pubkey: int, signature: bytes, domain: int) -> bool:
    """
    Verify a BLS signature by checking the pairing equation
    e(sig, G1) * e(H(message), -pub) == 1 via a final exponentiation.
    """
    final_exponentiation = final_exponentiate(
        pairing(FQP_point_to_FQ2_point(decompress_G2(signature)), G1, False) *
        pairing(
            FQP_point_to_FQ2_point(hash_to_G2(message, domain)),
            neg(decompress_G1(pubkey)),
            False
        )
    )
    return final_exponentiation == FQ12.one()


def aggregate_signatures(signatures: Sequence[bytes]) -> Tuple[int, int]:
    """
    Aggregate compressed G2 signatures by point addition, returning the
    compressed sum.
    """
    o = Z2
    for s in signatures:
        o = FQP_point_to_FQ2_point(add(o, decompress_G2(s)))
    return compress_G2(o)


def aggregate_pubkeys(pubkeys: Sequence[int]) -> int:
    """
    Aggregate compressed G1 public keys by point addition, returning the
    compressed sum.
    """
    o = Z1
    for p in pubkeys:
        o = add(o, decompress_G1(p))
    return compress_G1(o)


def verify_multiple(pubkeys: Sequence[int],
                    messages: Sequence[bytes],
                    signature: bytes,
                    domain: int) -> bool:
    """
    Verify an aggregate ``signature`` over possibly-distinct ``messages``,
    where ``pubkeys[i]`` signed ``messages[i]``: public keys signing the
    same message are aggregated first, then the product of pairings is
    checked against e(sig, -G1) via a final exponentiation.

    Raises eth_utils.ValidationError when ``pubkeys`` and ``messages``
    differ in length.
    """
    len_msgs = len(messages)

    if len(pubkeys) != len_msgs:
        raise ValidationError(
            "len(pubkeys) (%s) should be equal to len(messages) (%s)" % (
                len(pubkeys), len_msgs
            )
        )

    o = FQ12([1] + [0] * 11)
    for m_pubs in set(messages):
        # Aggregate (by point addition) the pubkeys that signed this
        # message, so one pairing covers all of them.
        group_pub = Z1
        for i in range(len_msgs):
            if messages[i] == m_pubs:
                group_pub = add(group_pub, decompress_G1(pubkeys[i]))

        o *= pairing(hash_to_G2(m_pubs, domain), group_pub, False)
    o *= pairing(decompress_G2(signature), neg(G1), False)

    final_exponentiation = final_exponentiate(o)
    return final_exponentiation == FQ12.one()
13 changes: 9 additions & 4 deletions tests/beacon/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,23 @@ def test_aggregate_votes(votes_count, random, privkeys, pubkeys):
bit_count = 10
pre_bitfield = get_empty_bitfield(bit_count)
pre_sigs = ()
domain = 0

random_votes = random.sample(range(bit_count), votes_count)
message = b'hello'

# Get votes: (committee_index, sig, public_key)
votes = [
(committee_index, bls.sign(message, privkeys[committee_index]), pubkeys[committee_index])
(
committee_index,
bls.sign(message, privkeys[committee_index], domain),
pubkeys[committee_index],
)
for committee_index in random_votes
]

# Verify
sigs, committee_indices = verify_votes(message, votes)
sigs, committee_indices = verify_votes(message, votes, domain)

# Aggregate the votes
bitfield, sigs = aggregate_votes(
Expand All @@ -64,5 +69,5 @@ def test_aggregate_votes(votes_count, random, privkeys, pubkeys):
]
assert len(voted_index) == len(votes)

aggregated_pubs = bls.aggregate_pubs(pubs)
assert bls.verify(message, aggregated_pubs, sigs)
aggregated_pubs = bls.aggregate_pubkeys(pubs)
assert bls.verify(message, aggregated_pubs, sigs, domain)
Loading