Skip to content

Commit 81ec8ee

Browse files
fselmoSamWilsngurukamath
authored
Optimized bls12 381 (#1268)
* Implement bls12_381 optimized for Prague - Change all methods to use the optimized ``py_ecc`` classes for bls12_381_g1, bls12_381_g2, and bls12_381_pairing. * Add some entries to whitelist.txt * Remove optimized from bls imports * De-pythonize conditionals * consistent naming convention --------- Co-authored-by: Sam Wilson <[email protected]> Co-authored-by: Guruprasad Kamath <[email protected]>
1 parent 3aabf4b commit 81ec8ee

File tree

5 files changed

+145
-139
lines changed

5 files changed

+145
-139
lines changed

src/ethereum/prague/vm/precompiled_contracts/bls12_381/__init__.py

Lines changed: 80 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,29 @@
1111
1212
Precompile for BLS12-381 curve operations.
1313
"""
14-
from typing import Tuple, Union
14+
15+
from typing import Tuple
1516

1617
from ethereum_types.bytes import Bytes
1718
from ethereum_types.numeric import U256, Uint
18-
from py_ecc.bls12_381.bls12_381_curve import (
19+
from py_ecc.optimized_bls12_381.optimized_curve import (
1920
FQ,
2021
FQ2,
2122
b,
2223
b2,
2324
curve_order,
25+
is_inf,
2426
is_on_curve,
25-
multiply,
2627
)
27-
from py_ecc.optimized_bls12_381.optimized_curve import FQ as OPTIMIZED_FQ
28-
from py_ecc.optimized_bls12_381.optimized_curve import FQ2 as OPTIMIZED_FQ2
29-
from py_ecc.typing import Point2D
28+
from py_ecc.optimized_bls12_381.optimized_curve import (
29+
multiply as bls12_multiply,
30+
)
31+
from py_ecc.optimized_bls12_381.optimized_curve import normalize
32+
from py_ecc.typing import Optimized_Point3D as Point3D
3033

3134
from ....vm.memory import buffer_read
3235
from ...exceptions import InvalidParameter
3336

34-
P = FQ.field_modulus
35-
3637
G1_K_DISCOUNT = [
3738
1000,
3839
949,
@@ -300,7 +301,9 @@
300301
MULTIPLIER = Uint(1000)
301302

302303

303-
def bytes_to_G1(data: Bytes) -> Point2D:
304+
def bytes_to_g1(
305+
data: Bytes,
306+
) -> Point3D[FQ]:
304307
"""
305308
Decode 128 bytes to a G1 point. Does not perform sub-group check.
306309
@@ -311,7 +314,7 @@ def bytes_to_G1(data: Bytes) -> Point2D:
311314
312315
Returns
313316
-------
314-
point : Point2D
317+
point : Point3D[FQ]
315318
The G1 point.
316319
317320
Raises
@@ -322,52 +325,49 @@ def bytes_to_G1(data: Bytes) -> Point2D:
322325
if len(data) != 128:
323326
raise InvalidParameter("Input should be 128 bytes long")
324327

325-
x = int.from_bytes(data[:64], "big")
326-
y = int.from_bytes(data[64:], "big")
328+
x = bytes_to_fq(data[:64])
329+
y = bytes_to_fq(data[64:])
327330

328-
if x >= P:
329-
raise InvalidParameter("Invalid field element")
330-
if y >= P:
331-
raise InvalidParameter("Invalid field element")
331+
if x >= FQ.field_modulus:
332+
raise InvalidParameter("x >= field modulus")
333+
if y >= FQ.field_modulus:
334+
raise InvalidParameter("y >= field modulus")
332335

336+
z = 1
333337
if x == 0 and y == 0:
334-
return None
335-
336-
point = (FQ(x), FQ(y))
338+
z = 0
339+
point = FQ(x), FQ(y), FQ(z)
337340

338-
# Check if the point is on the curve
339341
if not is_on_curve(point, b):
340342
raise InvalidParameter("Point is not on curve")
341343

342344
return point
343345

344346

345-
def G1_to_bytes(point: Point2D) -> Bytes:
347+
def g1_to_bytes(
348+
g1_point: Point3D[FQ],
349+
) -> Bytes:
346350
"""
347351
Encode a G1 point to 128 bytes.
348352
349353
Parameters
350354
----------
351-
point :
355+
g1_point :
352356
The G1 point to encode.
353357
354358
Returns
355359
-------
356360
data : Bytes
357361
The encoded data.
358362
"""
359-
if point is None:
360-
return b"\x00" * 128
361-
362-
x, y = point
363+
g1_normalized = normalize(g1_point)
364+
x, y = g1_normalized
365+
return b"".join([int(x).to_bytes(64, "big"), int(y).to_bytes(64, "big")])
363366

364-
x_bytes = int(x).to_bytes(64, "big")
365-
y_bytes = int(y).to_bytes(64, "big")
366367

367-
return x_bytes + y_bytes
368-
369-
370-
def decode_G1_scalar_pair(data: Bytes) -> Tuple[Point2D, int]:
368+
def decode_g1_scalar_pair(
369+
data: Bytes,
370+
) -> Tuple[Point3D[FQ], int]:
371371
"""
372372
Decode 160 bytes to a G1 point and a scalar.
373373
@@ -378,7 +378,7 @@ def decode_G1_scalar_pair(data: Bytes) -> Tuple[Point2D, int]:
378378
379379
Returns
380380
-------
381-
point : Tuple[Point2D, int]
381+
point : Tuple[Point3D[FQ], int]
382382
The G1 point and the scalar.
383383
384384
Raises
@@ -389,31 +389,27 @@ def decode_G1_scalar_pair(data: Bytes) -> Tuple[Point2D, int]:
389389
if len(data) != 160:
390390
InvalidParameter("Input should be 160 bytes long")
391391

392-
p = bytes_to_G1(buffer_read(data, U256(0), U256(128)))
393-
if multiply(p, curve_order) is not None:
392+
point = bytes_to_g1(data[:128])
393+
if not is_inf(bls12_multiply(point, curve_order)):
394394
raise InvalidParameter("Sub-group check failed.")
395395

396396
m = int.from_bytes(buffer_read(data, U256(128), U256(32)), "big")
397397

398-
return p, m
398+
return point, m
399399

400400

401-
def bytes_to_FQ(
402-
data: Bytes, optimized: bool = False
403-
) -> Union[FQ, OPTIMIZED_FQ]:
401+
def bytes_to_fq(data: Bytes) -> FQ:
404402
"""
405403
Decode 64 bytes to a FQ element.
406404
407405
Parameters
408406
----------
409407
data :
410408
The bytes data to decode.
411-
optimized :
412-
Whether to use the optimized FQ implementation.
413409
414410
Returns
415411
-------
416-
fq : Union[FQ, OPTIMIZED_FQ]
412+
fq : FQ
417413
The FQ element.
418414
419415
Raises
@@ -426,31 +422,24 @@ def bytes_to_FQ(
426422

427423
c = int.from_bytes(data[:64], "big")
428424

429-
if c >= P:
425+
if c >= FQ.field_modulus:
430426
raise InvalidParameter("Invalid field element")
431427

432-
if optimized:
433-
return OPTIMIZED_FQ(c)
434-
else:
435-
return FQ(c)
428+
return FQ(c)
436429

437430

438-
def bytes_to_FQ2(
439-
data: Bytes, optimized: bool = False
440-
) -> Union[FQ2, OPTIMIZED_FQ2]:
431+
def bytes_to_fq2(data: Bytes) -> FQ2:
441432
"""
442-
Decode 128 bytes to a FQ2 element.
433+
Decode 128 bytes to an FQ2 element.
443434
444435
Parameters
445436
----------
446437
data :
447438
The bytes data to decode.
448-
optimized :
449-
Whether to use the optimized FQ2 implementation.
450439
451440
Returns
452441
-------
453-
fq2 : Union[FQ2, OPTIMIZED_FQ2]
442+
fq2 : FQ2
454443
The FQ2 element.
455444
456445
Raises
@@ -463,18 +452,17 @@ def bytes_to_FQ2(
463452
c_0 = int.from_bytes(data[:64], "big")
464453
c_1 = int.from_bytes(data[64:], "big")
465454

466-
if c_0 >= P:
455+
if c_0 >= FQ.field_modulus:
467456
raise InvalidParameter("Invalid field element")
468-
if c_1 >= P:
457+
if c_1 >= FQ.field_modulus:
469458
raise InvalidParameter("Invalid field element")
470459

471-
if optimized:
472-
return OPTIMIZED_FQ2((c_0, c_1))
473-
else:
474-
return FQ2((c_0, c_1))
460+
return FQ2((c_0, c_1))
475461

476462

477-
def bytes_to_G2(data: Bytes) -> Point2D:
463+
def bytes_to_g2(
464+
data: Bytes,
465+
) -> Point3D[FQ2]:
478466
"""
479467
Decode 256 bytes to a G2 point. Does not perform sub-group check.
480468
@@ -485,7 +473,7 @@ def bytes_to_G2(data: Bytes) -> Point2D:
485473
486474
Returns
487475
-------
488-
point : Point2D
476+
point : Point3D[FQ2]
489477
The G2 point.
490478
491479
Raises
@@ -496,14 +484,14 @@ def bytes_to_G2(data: Bytes) -> Point2D:
496484
if len(data) != 256:
497485
raise InvalidParameter("G2 should be 256 bytes long")
498486

499-
x = bytes_to_FQ2(data[:128])
500-
y = bytes_to_FQ2(data[128:])
487+
x = bytes_to_fq2(data[:128])
488+
y = bytes_to_fq2(data[128:])
501489

502-
assert isinstance(x, FQ2) and isinstance(y, FQ2)
490+
z = (1, 0)
503491
if x == FQ2((0, 0)) and y == FQ2((0, 0)):
504-
return None
492+
z = (0, 0)
505493

506-
point = (x, y)
494+
point = x, y, FQ2(z)
507495

508496
# Check if the point is on the curve
509497
if not is_on_curve(point, b2):
@@ -526,33 +514,38 @@ def FQ2_to_bytes(fq2: FQ2) -> Bytes:
526514
data : Bytes
527515
The encoded data.
528516
"""
529-
c_0, c_1 = fq2.coeffs
530-
return int(c_0).to_bytes(64, "big") + int(c_1).to_bytes(64, "big")
531-
532-
533-
def G2_to_bytes(point: Point2D) -> Bytes:
517+
coord0, coord1 = fq2.coeffs
518+
return b"".join(
519+
[
520+
int(coord0).to_bytes(64, "big"),
521+
int(coord1).to_bytes(64, "big"),
522+
]
523+
)
524+
525+
526+
def g2_to_bytes(
527+
g2_point: Point3D[FQ2],
528+
) -> Bytes:
534529
"""
535530
Encode a G2 point to 256 bytes.
536531
537532
Parameters
538533
----------
539-
point :
534+
g2_point :
540535
The G2 point to encode.
541536
542537
Returns
543538
-------
544539
data : Bytes
545540
The encoded data.
546541
"""
547-
if point is None:
548-
return b"\x00" * 256
542+
x_coords, y_coords = normalize(g2_point)
543+
return b"".join([FQ2_to_bytes(x_coords), FQ2_to_bytes(y_coords)])
549544

550-
x, y = point
551545

552-
return FQ2_to_bytes(x) + FQ2_to_bytes(y)
553-
554-
555-
def decode_G2_scalar_pair(data: Bytes) -> Tuple[Point2D, int]:
546+
def decode_g2_scalar_pair(
547+
data: Bytes,
548+
) -> Tuple[Point3D[FQ2], int]:
556549
"""
557550
Decode 288 bytes to a G2 point and a scalar.
558551
@@ -563,7 +556,7 @@ def decode_G2_scalar_pair(data: Bytes) -> Tuple[Point2D, int]:
563556
564557
Returns
565558
-------
566-
point : Tuple[Point2D, int]
559+
point : Tuple[Point3D[FQ2], int]
567560
The G2 point and the scalar.
568561
569562
Raises
@@ -574,10 +567,11 @@ def decode_G2_scalar_pair(data: Bytes) -> Tuple[Point2D, int]:
574567
if len(data) != 288:
575568
InvalidParameter("Input should be 288 bytes long")
576569

577-
p = bytes_to_G2(buffer_read(data, U256(0), U256(256)))
578-
if multiply(p, curve_order) is not None:
579-
raise InvalidParameter("Sub-group check failed.")
570+
point = bytes_to_g2(data[:256])
571+
572+
if not is_inf(bls12_multiply(point, curve_order)):
573+
raise InvalidParameter("Point failed sub-group check.")
580574

581-
m = int.from_bytes(buffer_read(data, U256(256), U256(32)), "big")
575+
n = int.from_bytes(data[256 : 256 + 32], "big")
582576

583-
return p, m
577+
return point, n

0 commit comments

Comments
 (0)