11
11
12
12
Precompile for BLS12-381 curve operations.
13
13
"""
14
- from typing import Tuple , Union
14
+
15
+ from typing import Tuple
15
16
16
17
from ethereum_types .bytes import Bytes
17
18
from ethereum_types .numeric import U256 , Uint
18
- from py_ecc .bls12_381 . bls12_381_curve import (
19
+ from py_ecc .optimized_bls12_381 . optimized_curve import (
19
20
FQ ,
20
21
FQ2 ,
21
22
b ,
22
23
b2 ,
23
24
curve_order ,
25
+ is_inf ,
24
26
is_on_curve ,
25
- multiply ,
26
27
)
27
- from py_ecc .optimized_bls12_381 .optimized_curve import FQ as OPTIMIZED_FQ
28
- from py_ecc .optimized_bls12_381 .optimized_curve import FQ2 as OPTIMIZED_FQ2
29
- from py_ecc .typing import Point2D
28
+ from py_ecc .optimized_bls12_381 .optimized_curve import (
29
+ multiply as bls12_multiply ,
30
+ )
31
+ from py_ecc .optimized_bls12_381 .optimized_curve import normalize
32
+ from py_ecc .typing import Optimized_Point3D as Point3D
30
33
31
34
from ....vm .memory import buffer_read
32
35
from ...exceptions import InvalidParameter
33
36
34
- P = FQ .field_modulus
35
-
36
37
G1_K_DISCOUNT = [
37
38
1000 ,
38
39
949 ,
300
301
MULTIPLIER = Uint (1000 )
301
302
302
303
303
- def bytes_to_G1 (data : Bytes ) -> Point2D :
304
+ def bytes_to_g1 (
305
+ data : Bytes ,
306
+ ) -> Point3D [FQ ]:
304
307
"""
305
308
Decode 128 bytes to a G1 point. Does not perform sub-group check.
306
309
@@ -311,7 +314,7 @@ def bytes_to_G1(data: Bytes) -> Point2D:
311
314
312
315
Returns
313
316
-------
314
- point : Point2D
317
+ point : Point3D[FQ]
315
318
The G1 point.
316
319
317
320
Raises
@@ -322,52 +325,49 @@ def bytes_to_G1(data: Bytes) -> Point2D:
322
325
if len (data ) != 128 :
323
326
raise InvalidParameter ("Input should be 128 bytes long" )
324
327
325
- x = int . from_bytes (data [:64 ], "big" )
326
- y = int . from_bytes (data [64 :], "big" )
328
+ x = bytes_to_fq (data [:64 ])
329
+ y = bytes_to_fq (data [64 :])
327
330
328
- if x >= P :
329
- raise InvalidParameter ("Invalid field element " )
330
- if y >= P :
331
- raise InvalidParameter ("Invalid field element " )
331
+ if x >= FQ . field_modulus :
332
+ raise InvalidParameter ("x >= field modulus " )
333
+ if y >= FQ . field_modulus :
334
+ raise InvalidParameter ("y >= field modulus " )
332
335
336
+ z = 1
333
337
if x == 0 and y == 0 :
334
- return None
335
-
336
- point = (FQ (x ), FQ (y ))
338
+ z = 0
339
+ point = FQ (x ), FQ (y ), FQ (z )
337
340
338
- # Check if the point is on the curve
339
341
if not is_on_curve (point , b ):
340
342
raise InvalidParameter ("Point is not on curve" )
341
343
342
344
return point
343
345
344
346
345
- def G1_to_bytes (point : Point2D ) -> Bytes :
347
+ def g1_to_bytes (
348
+ g1_point : Point3D [FQ ],
349
+ ) -> Bytes :
346
350
"""
347
351
Encode a G1 point to 128 bytes.
348
352
349
353
Parameters
350
354
----------
351
- point :
355
+ g1_point :
352
356
The G1 point to encode.
353
357
354
358
Returns
355
359
-------
356
360
data : Bytes
357
361
The encoded data.
358
362
"""
359
- if point is None :
360
- return b"\x00 " * 128
361
-
362
- x , y = point
363
+ g1_normalized = normalize (g1_point )
364
+ x , y = g1_normalized
365
+ return b"" .join ([int (x ).to_bytes (64 , "big" ), int (y ).to_bytes (64 , "big" )])
363
366
364
- x_bytes = int (x ).to_bytes (64 , "big" )
365
- y_bytes = int (y ).to_bytes (64 , "big" )
366
367
367
- return x_bytes + y_bytes
368
-
369
-
370
- def decode_G1_scalar_pair (data : Bytes ) -> Tuple [Point2D , int ]:
368
+ def decode_g1_scalar_pair (
369
+ data : Bytes ,
370
+ ) -> Tuple [Point3D [FQ ], int ]:
371
371
"""
372
372
Decode 160 bytes to a G1 point and a scalar.
373
373
@@ -378,7 +378,7 @@ def decode_G1_scalar_pair(data: Bytes) -> Tuple[Point2D, int]:
378
378
379
379
Returns
380
380
-------
381
- point : Tuple[Point2D , int]
381
+ point : Tuple[Point3D[FQ] , int]
382
382
The G1 point and the scalar.
383
383
384
384
Raises
@@ -389,31 +389,27 @@ def decode_G1_scalar_pair(data: Bytes) -> Tuple[Point2D, int]:
389
389
if len (data ) != 160 :
390
390
InvalidParameter ("Input should be 160 bytes long" )
391
391
392
- p = bytes_to_G1 ( buffer_read ( data , U256 ( 0 ), U256 ( 128 )) )
393
- if multiply ( p , curve_order ) is not None :
392
+ point = bytes_to_g1 ( data [: 128 ] )
393
+ if not is_inf ( bls12_multiply ( point , curve_order )) :
394
394
raise InvalidParameter ("Sub-group check failed." )
395
395
396
396
m = int .from_bytes (buffer_read (data , U256 (128 ), U256 (32 )), "big" )
397
397
398
- return p , m
398
+ return point , m
399
399
400
400
401
- def bytes_to_FQ (
402
- data : Bytes , optimized : bool = False
403
- ) -> Union [FQ , OPTIMIZED_FQ ]:
401
+ def bytes_to_fq (data : Bytes ) -> FQ :
404
402
"""
405
403
Decode 64 bytes to a FQ element.
406
404
407
405
Parameters
408
406
----------
409
407
data :
410
408
The bytes data to decode.
411
- optimized :
412
- Whether to use the optimized FQ implementation.
413
409
414
410
Returns
415
411
-------
416
- fq : Union[FQ, OPTIMIZED_FQ]
412
+ fq : FQ
417
413
The FQ element.
418
414
419
415
Raises
@@ -426,31 +422,24 @@ def bytes_to_FQ(
426
422
427
423
c = int .from_bytes (data [:64 ], "big" )
428
424
429
- if c >= P :
425
+ if c >= FQ . field_modulus :
430
426
raise InvalidParameter ("Invalid field element" )
431
427
432
- if optimized :
433
- return OPTIMIZED_FQ (c )
434
- else :
435
- return FQ (c )
428
+ return FQ (c )
436
429
437
430
438
- def bytes_to_FQ2 (
439
- data : Bytes , optimized : bool = False
440
- ) -> Union [FQ2 , OPTIMIZED_FQ2 ]:
431
+ def bytes_to_fq2 (data : Bytes ) -> FQ2 :
441
432
"""
442
- Decode 128 bytes to a FQ2 element.
433
+ Decode 128 bytes to an FQ2 element.
443
434
444
435
Parameters
445
436
----------
446
437
data :
447
438
The bytes data to decode.
448
- optimized :
449
- Whether to use the optimized FQ2 implementation.
450
439
451
440
Returns
452
441
-------
453
- fq2 : Union[ FQ2, OPTIMIZED_FQ2]
442
+ fq2 : FQ2
454
443
The FQ2 element.
455
444
456
445
Raises
@@ -463,18 +452,17 @@ def bytes_to_FQ2(
463
452
c_0 = int .from_bytes (data [:64 ], "big" )
464
453
c_1 = int .from_bytes (data [64 :], "big" )
465
454
466
- if c_0 >= P :
455
+ if c_0 >= FQ . field_modulus :
467
456
raise InvalidParameter ("Invalid field element" )
468
- if c_1 >= P :
457
+ if c_1 >= FQ . field_modulus :
469
458
raise InvalidParameter ("Invalid field element" )
470
459
471
- if optimized :
472
- return OPTIMIZED_FQ2 ((c_0 , c_1 ))
473
- else :
474
- return FQ2 ((c_0 , c_1 ))
460
+ return FQ2 ((c_0 , c_1 ))
475
461
476
462
477
- def bytes_to_G2 (data : Bytes ) -> Point2D :
463
+ def bytes_to_g2 (
464
+ data : Bytes ,
465
+ ) -> Point3D [FQ2 ]:
478
466
"""
479
467
Decode 256 bytes to a G2 point. Does not perform sub-group check.
480
468
@@ -485,7 +473,7 @@ def bytes_to_G2(data: Bytes) -> Point2D:
485
473
486
474
Returns
487
475
-------
488
- point : Point2D
476
+ point : Point3D[FQ2]
489
477
The G2 point.
490
478
491
479
Raises
@@ -496,14 +484,14 @@ def bytes_to_G2(data: Bytes) -> Point2D:
496
484
if len (data ) != 256 :
497
485
raise InvalidParameter ("G2 should be 256 bytes long" )
498
486
499
- x = bytes_to_FQ2 (data [:128 ])
500
- y = bytes_to_FQ2 (data [128 :])
487
+ x = bytes_to_fq2 (data [:128 ])
488
+ y = bytes_to_fq2 (data [128 :])
501
489
502
- assert isinstance ( x , FQ2 ) and isinstance ( y , FQ2 )
490
+ z = ( 1 , 0 )
503
491
if x == FQ2 ((0 , 0 )) and y == FQ2 ((0 , 0 )):
504
- return None
492
+ z = ( 0 , 0 )
505
493
506
- point = ( x , y )
494
+ point = x , y , FQ2 ( z )
507
495
508
496
# Check if the point is on the curve
509
497
if not is_on_curve (point , b2 ):
@@ -526,33 +514,38 @@ def FQ2_to_bytes(fq2: FQ2) -> Bytes:
526
514
data : Bytes
527
515
The encoded data.
528
516
"""
529
- c_0 , c_1 = fq2 .coeffs
530
- return int (c_0 ).to_bytes (64 , "big" ) + int (c_1 ).to_bytes (64 , "big" )
531
-
532
-
533
- def G2_to_bytes (point : Point2D ) -> Bytes :
517
+ coord0 , coord1 = fq2 .coeffs
518
+ return b"" .join (
519
+ [
520
+ int (coord0 ).to_bytes (64 , "big" ),
521
+ int (coord1 ).to_bytes (64 , "big" ),
522
+ ]
523
+ )
524
+
525
+
526
+ def g2_to_bytes (
527
+ g2_point : Point3D [FQ2 ],
528
+ ) -> Bytes :
534
529
"""
535
530
Encode a G2 point to 256 bytes.
536
531
537
532
Parameters
538
533
----------
539
- point :
534
+ g2_point :
540
535
The G2 point to encode.
541
536
542
537
Returns
543
538
-------
544
539
data : Bytes
545
540
The encoded data.
546
541
"""
547
- if point is None :
548
- return b"\x00 " * 256
542
+ x_coords , y_coords = normalize ( g2_point )
543
+ return b"" . join ([ FQ2_to_bytes ( x_coords ), FQ2_to_bytes ( y_coords )])
549
544
550
- x , y = point
551
545
552
- return FQ2_to_bytes (x ) + FQ2_to_bytes (y )
553
-
554
-
555
- def decode_G2_scalar_pair (data : Bytes ) -> Tuple [Point2D , int ]:
546
+ def decode_g2_scalar_pair (
547
+ data : Bytes ,
548
+ ) -> Tuple [Point3D [FQ2 ], int ]:
556
549
"""
557
550
Decode 288 bytes to a G2 point and a scalar.
558
551
@@ -563,7 +556,7 @@ def decode_G2_scalar_pair(data: Bytes) -> Tuple[Point2D, int]:
563
556
564
557
Returns
565
558
-------
566
- point : Tuple[Point2D , int]
559
+ point : Tuple[Point3D[FQ2] , int]
567
560
The G2 point and the scalar.
568
561
569
562
Raises
@@ -574,10 +567,11 @@ def decode_G2_scalar_pair(data: Bytes) -> Tuple[Point2D, int]:
574
567
if len (data ) != 288 :
575
568
InvalidParameter ("Input should be 288 bytes long" )
576
569
577
- p = bytes_to_G2 (buffer_read (data , U256 (0 ), U256 (256 )))
578
- if multiply (p , curve_order ) is not None :
579
- raise InvalidParameter ("Sub-group check failed." )
570
+ point = bytes_to_g2 (data [:256 ])
571
+
572
+ if not is_inf (bls12_multiply (point , curve_order )):
573
+ raise InvalidParameter ("Point failed sub-group check." )
580
574
581
- m = int .from_bytes (buffer_read ( data , U256 ( 256 ), U256 ( 32 )) , "big" )
575
+ n = int .from_bytes (data [ 256 : 256 + 32 ] , "big" )
582
576
583
- return p , m
577
+ return point , n
0 commit comments