1+ from typing import Tuple , Optional , Any
12import hashlib
23import binascii
34
1718# represented by the None keyword.
1819G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798 , 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8 )
1920
21+ Point = Tuple [int , int ]
22+
2023# This implementation can be sped up by storing the midstate after hashing
2124# tag_hash instead of rehashing it all the time.
22- def tagged_hash (tag , msg ) :
25+ def tagged_hash (tag : str , msg : bytes ) -> bytes :
2326 tag_hash = hashlib .sha256 (tag .encode ()).digest ()
2427 return hashlib .sha256 (tag_hash + tag_hash + msg ).digest ()
2528
26- def is_infinity (P ) :
29+ def is_infinity (P : Optional [ Point ]) -> bool :
2730 return P is None
2831
29- def x (P ) :
32+ def x (P : Point ) -> int :
3033 return P [0 ]
3134
32- def y (P ) :
35+ def y (P : Point ) -> int :
3336 return P [1 ]
3437
35- def point_add (P1 , P2 ) :
38+ def point_add (P1 : Optional [ Point ] , P2 : Optional [ Point ]) -> Optional [ Point ] :
3639 if P1 is None :
3740 return P2
3841 if P2 is None :
@@ -46,24 +49,24 @@ def point_add(P1, P2):
4649 x3 = (lam * lam - x (P1 ) - x (P2 )) % p
4750 return (x3 , (lam * (x (P1 ) - x3 ) - y (P1 )) % p )
4851
49- def point_mul (P , n ) :
52+ def point_mul (P : Optional [ Point ] , n : int ) -> Optional [ Point ] :
5053 R = None
5154 for i in range (256 ):
5255 if (n >> i ) & 1 :
5356 R = point_add (R , P )
5457 P = point_add (P , P )
5558 return R
5659
57- def bytes_from_int (x ) :
60+ def bytes_from_int (x : int ) -> bytes :
5861 return x .to_bytes (32 , byteorder = "big" )
5962
60- def bytes_from_point (P ) :
63+ def bytes_from_point (P : Point ) -> bytes :
6164 return bytes_from_int (x (P ))
6265
63- def xor_bytes (b0 , b1 ) :
66+ def xor_bytes (b0 : bytes , b1 : bytes ) -> bytes :
6467 return bytes (x ^ y for (x , y ) in zip (b0 , b1 ))
6568
66- def lift_x_square_y (b ) :
69+ def lift_x_square_y (b : bytes ) -> Optional [ Point ] :
6770 x = int_from_bytes (b )
6871 if x >= p :
6972 return None
@@ -73,36 +76,40 @@ def lift_x_square_y(b):
7376 return None
7477 return (x , y )
7578
76- def lift_x_even_y (b ) :
79+ def lift_x_even_y (b : bytes ) -> Optional [ Point ] :
7780 P = lift_x_square_y (b )
7881 if P is None :
7982 return None
8083 else :
8184 return (x (P ), y (P ) if y (P ) % 2 == 0 else p - y (P ))
8285
83- def int_from_bytes (b ) :
86+ def int_from_bytes (b : bytes ) -> int :
8487 return int .from_bytes (b , byteorder = "big" )
8588
86- def hash_sha256 (b ) :
89+ def hash_sha256 (b : bytes ) -> bytes :
8790 return hashlib .sha256 (b ).digest ()
8891
89- def is_square (x ) :
90- return pow (x , (p - 1 ) // 2 , p ) == 1
92+ def is_square (x : int ) -> bool :
93+ return int ( pow (x , (p - 1 ) // 2 , p ) ) == 1
9194
92- def has_square_y (P ):
93- return (not is_infinity (P )) and is_square (y (P ))
95+ def has_square_y (P : Optional [Point ]) -> bool :
96+ infinity = is_infinity (P )
97+ if infinity : return False
98+ assert P is not None
99+ return is_square (y (P ))
94100
95- def has_even_y (P ) :
101+ def has_even_y (P : Point ) -> bool :
96102 return y (P ) % 2 == 0
97103
98- def pubkey_gen (seckey ) :
104+ def pubkey_gen (seckey : bytes ) -> bytes :
99105 d0 = int_from_bytes (seckey )
100106 if not (1 <= d0 <= n - 1 ):
101107 raise ValueError ('The secret key must be an integer in the range 1..n-1.' )
102108 P = point_mul (G , d0 )
109+ assert P is not None
103110 return bytes_from_point (P )
104111
105- def schnorr_sign (msg , seckey , aux_rand ) :
112+ def schnorr_sign (msg : bytes , seckey : bytes , aux_rand : bytes ) -> bytes :
106113 if len (msg ) != 32 :
107114 raise ValueError ('The message must be a 32-byte array.' )
108115 d0 = int_from_bytes (seckey )
@@ -111,12 +118,14 @@ def schnorr_sign(msg, seckey, aux_rand):
111118 if len (aux_rand ) != 32 :
112119 raise ValueError ('aux_rand must be 32 bytes instead of %i.' % len (aux_rand ))
113120 P = point_mul (G , d0 )
121+ assert P is not None
114122 d = d0 if has_even_y (P ) else n - d0
115123 t = xor_bytes (bytes_from_int (d ), tagged_hash ("BIP340/aux" , aux_rand ))
116124 k0 = int_from_bytes (tagged_hash ("BIP340/nonce" , t + bytes_from_point (P ) + msg )) % n
117125 if k0 == 0 :
118126 raise RuntimeError ('Failure. This happens only with negligible probability.' )
119127 R = point_mul (G , k0 )
128+ assert R is not None
120129 k = n - k0 if not has_square_y (R ) else k0
121130 e = int_from_bytes (tagged_hash ("BIP340/challenge" , bytes_from_point (R ) + bytes_from_point (P ) + msg )) % n
122131 sig = bytes_from_point (R ) + bytes_from_int ((k + e * d ) % n )
@@ -125,7 +134,7 @@ def schnorr_sign(msg, seckey, aux_rand):
125134 raise RuntimeError ('The created signature does not pass verification.' )
126135 return sig
127136
128- def schnorr_verify (msg , pubkey , sig ) :
137+ def schnorr_verify (msg : bytes , pubkey : bytes , sig : bytes ) -> bool :
129138 if len (msg ) != 32 :
130139 raise ValueError ('The message must be a 32-byte array.' )
131140 if len (pubkey ) != 32 :
@@ -153,26 +162,26 @@ def schnorr_verify(msg, pubkey, sig):
153162import os
154163import sys
155164
156- def test_vectors ():
165+ def test_vectors () -> bool :
157166 all_passed = True
158167 with open (os .path .join (sys .path [0 ], 'test-vectors.csv' ), newline = '' ) as csvfile :
159168 reader = csv .reader (csvfile )
160169 reader .__next__ ()
161170 for row in reader :
162- (index , seckey , pubkey , aux_rand , msg , sig , result , comment ) = row
163- pubkey = bytes .fromhex (pubkey )
164- msg = bytes .fromhex (msg )
165- sig = bytes .fromhex (sig )
166- result = result == 'TRUE'
171+ (index , seckey_hex , pubkey_hex , aux_rand_hex , msg_hex , sig_hex , result_str , comment ) = row
172+ pubkey = bytes .fromhex (pubkey_hex )
173+ msg = bytes .fromhex (msg_hex )
174+ sig = bytes .fromhex (sig_hex )
175+ result = result_str == 'TRUE'
167176 print ('\n Test vector' , ('#' + index ).rjust (3 , ' ' ) + ':' )
168- if seckey != '' :
169- seckey = bytes .fromhex (seckey )
177+ if seckey_hex != '' :
178+ seckey = bytes .fromhex (seckey_hex )
170179 pubkey_actual = pubkey_gen (seckey )
171180 if pubkey != pubkey_actual :
172181 print (' * Failed key generation.' )
173182 print (' Expected key:' , pubkey .hex ().upper ())
174183 print (' Actual key:' , pubkey_actual .hex ().upper ())
175- aux_rand = bytes .fromhex (aux_rand )
184+ aux_rand = bytes .fromhex (aux_rand_hex )
176185 try :
177186 sig_actual = schnorr_sign (msg , seckey , aux_rand )
178187 if sig == sig_actual :
@@ -207,7 +216,7 @@ def test_vectors():
207216#
208217import inspect
209218
210- def pretty (v ) :
219+ def pretty (v : Any ) -> Any :
211220 if isinstance (v , bytes ):
212221 return '0x' + v .hex ()
213222 if isinstance (v , int ):
@@ -216,9 +225,12 @@ def pretty(v):
216225 return tuple (map (pretty , v ))
217226 return v
218227
219- def debug_print_vars ():
228+ def debug_print_vars () -> None :
220229 if DEBUG :
221- frame = inspect .currentframe ().f_back
230+ current_frame = inspect .currentframe ()
231+ assert current_frame is not None
232+ frame = current_frame .f_back
233+ assert frame is not None
222234 print (' Variables in function ' , frame .f_code .co_name , ' at line ' , frame .f_lineno , ':' , sep = '' )
223235 for var_name , var_val in frame .f_locals .items ():
224236 print (' ' + var_name .rjust (11 , ' ' ), '==' , pretty (var_val ))
0 commit comments