diff --git a/eth/precompiles/ecpairing.py b/eth/precompiles/ecpairing.py index d7628570fc..d00b23ae52 100644 --- a/eth/precompiles/ecpairing.py +++ b/eth/precompiles/ecpairing.py @@ -21,6 +21,7 @@ from eth.utils.bn128 import ( validate_point, + FQP_point_to_FQ2_point, ) from eth.utils.padding import ( pad32, @@ -31,7 +32,7 @@ ) -ZERO = (bn128.FQ2.one(), bn128.FQ2.one(), bn128.FQ2.zero()) +ZERO = bn128.Z2 EXPONENT = bn128.FQ12.one() @@ -74,7 +75,7 @@ def _ecpairing(data: bytes) -> bool: @curry -def _process_point(data_buffer: bytes, exponent: int) -> int: +def _process_point(data_buffer: bytes, exponent: int) -> bn128.FQP: x1, y1, x2_i, x2_r, y2_i, y2_r = _extract_point(data_buffer) p1 = validate_point(x1, y1) @@ -85,17 +86,16 @@ def _process_point(data_buffer: bytes, exponent: int) -> int: fq2_x = bn128.FQ2([x2_r, x2_i]) fq2_y = bn128.FQ2([y2_r, y2_i]) + p2 = ZERO if (fq2_x, fq2_y) != (bn128.FQ2.zero(), bn128.FQ2.zero()): p2 = (fq2_x, fq2_y, bn128.FQ2.one()) if not bn128.is_on_curve(p2, bn128.b2): raise ValidationError("point is not on curve") - else: - p2 = ZERO if bn128.multiply(p2, bn128.curve_order)[-1] != bn128.FQ2.zero(): raise ValidationError("TODO: what case is this?????") - return exponent * bn128.pairing(p2, p1, final_exponentiate=False) + return exponent * bn128.pairing(FQP_point_to_FQ2_point(p2), p1, final_exponentiate=False) def _extract_point(data_slice: bytes) -> Tuple[int, int, int, int, int, int]: diff --git a/eth/utils/bls.py b/eth/utils/bls.py index e87ebcbba9..c8dfffc158 100644 --- a/eth/utils/bls.py +++ b/eth/utils/bls.py @@ -2,6 +2,7 @@ Dict, Iterable, Tuple, + Union, ) @@ -16,6 +17,7 @@ FQ, FQ2, FQ12, + FQP, pairing, normalize, field_modulus, @@ -26,6 +28,9 @@ final_exponentiate ) from eth.utils.blake import blake +from eth.utils.bn128 import ( + FQP_point_to_FQ2_point, +) CACHE = {} # type: Dict[bytes, Tuple[FQ2, FQ2, FQ2]] @@ -38,7 +43,7 @@ assert HEX_ROOT ** 16 == FQ2([1, 0]) -def compress_G1(pt: Tuple[FQ2, FQ2, FQ2]) -> int: +def compress_G1(pt: Tuple[FQ, FQ, FQ]) -> int: x, y = normalize(pt) return x.n + 2**255 * (y.n % 2) @@ -55,11 +60,11 @@ def decompress_G1(p: int) -> Tuple[FQ, FQ, FQ]: return (FQ(x), FQ(y), FQ(1)) -def sqrt_fq2(x: FQ2) -> FQ2: +def sqrt_fq2(x: FQP) -> FQ2: y = x ** ((field_modulus ** 2 + 15) // 32) while y**2 != x: y *= HEX_ROOT - return y + return FQ2(y.coeffs) def hash_to_G2(m: bytes) -> Tuple[FQ2, FQ2, FQ2]: @@ -79,18 +84,22 @@ def hash_to_G2(m: bytes) -> Tuple[FQ2, FQ2, FQ2]: if xcb ** ((field_modulus ** 2 - 1) // 2) == FQ2([1, 0]): break y = sqrt_fq2(xcb) - o = multiply((x, y, FQ2([1, 0])), 2 * field_modulus - curve_order) + + o = FQP_point_to_FQ2_point(multiply((x, y, FQ2([1, 0])), 2 * field_modulus - curve_order)) CACHE[m] = o return o -def compress_G2(pt: Tuple[FQ2, FQ2, FQ2]) -> Tuple[int, int]: +def compress_G2(pt: Tuple[FQP, FQP, FQP]) -> Tuple[int, int]: assert is_on_curve(pt, b2) x, y = normalize(pt) - return (x.coeffs[0] + 2**255 * (y.coeffs[0] % 2), x.coeffs[1]) + return ( + int(x.coeffs[0] + 2**255 * (y.coeffs[0] % 2)), + int(x.coeffs[1]) + ) -def decompress_G2(p: bytes) -> Tuple[FQ2, FQ2, FQ2]: +def decompress_G2(p: bytes) -> Tuple[FQP, FQP, FQP]: x1 = p[0] % 2**255 y1_mod_2 = p[0] // 2**255 x2 = p[1] @@ -99,7 +108,7 @@ def decompress_G2(p: bytes) -> Tuple[FQ2, FQ2, FQ2]: return FQ2([1, 0]), FQ2([1, 0]), FQ2([0, 0]) y = sqrt_fq2(x**3 + b2) if y.coeffs[0] % 2 != y1_mod_2: - y = y * -1 + y = FQ2((y * -1).coeffs) assert is_on_curve((x, y, FQ2([1, 0])), b2) return x, y, FQ2([1, 0]) @@ -114,8 +123,8 @@ def privtopub(k: int) -> int: def verify(m: bytes, pub: int, sig: bytes) -> bool: final_exponentiation = final_exponentiate( - pairing(decompress_G2(sig), G1, False) * - pairing(hash_to_G2(m), neg(decompress_G1(pub)), False) + pairing(FQP_point_to_FQ2_point(decompress_G2(sig)), G1, False) * + pairing(FQP_point_to_FQ2_point(hash_to_G2(m)), neg(decompress_G1(pub)), False) ) return final_exponentiation == FQ12.one() @@ -123,7 +132,7 @@ def verify(m: bytes, pub: int, sig: bytes) -> bool: def aggregate_sigs(sigs: Iterable[bytes]) -> Tuple[int, int]: o = Z2 for s in sigs: - o = add(o, decompress_G2(s)) + o = FQP_point_to_FQ2_point(add(o, decompress_G2(s))) return compress_G2(o) diff --git a/eth/utils/bn128.py b/eth/utils/bn128.py index b5b2766928..a1a31a4092 100644 --- a/eth/utils/bn128.py +++ b/eth/utils/bn128.py @@ -1,6 +1,10 @@ from py_ecc import ( optimized_bn128 as bn128, ) +from py_ecc.optimized_bn128 import ( + FQP, + FQ2, +) from eth_utils import ( ValidationError, @@ -25,3 +29,14 @@ def validate_point(x: int, y: int) -> Tuple[bn128.FQ, bn128.FQ, bn128.FQ]: p1 = (FQ(1), FQ(1), FQ(0)) return p1 + + +def FQP_point_to_FQ2_point(pt: Tuple[FQP, FQP, FQP]) -> Tuple[FQ2, FQ2, FQ2]: + """ + Transform FQP to FQ2 for type hinting. + """ + return ( + FQ2(pt[0].coeffs), + FQ2(pt[1].coeffs), + FQ2(pt[2].coeffs), + ) diff --git a/setup.py b/setup.py index c40313d1e1..551059c73c 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ "eth-utils>=1.3.0b0,<2.0.0", "lru-dict>=1.1.6", "mypy_extensions>=0.4.1,<1.0.0", - "py-ecc==1.4.3", + "py-ecc>=1.4.7,<2.0.0", "pyethash>=0.1.27,<1.0.0", "rlp>=1.0.3,<2.0.0", "trie>=1.3.5,<2.0.0",