11import hashlib
22import binascii
33
4+ # Set DEBUG to True to get a detailed debug output including
5+ # intermediate values during key generation, signing, and
6+ # verification. This is implemented via calls to the
7+ # debug_print_vars() function.
8+ #
9+ # If you want to print values on an individual basis, use
10+ # the pretty() function, e.g., print(pretty(foo)).
11+ DEBUG = False
12+
413p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
514n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
615
@@ -24,13 +33,13 @@ def y(P):
2433 return P [1 ]
2534
2635def point_add (P1 , P2 ):
27- if ( P1 is None ) :
36+ if P1 is None :
2837 return P2
29- if ( P2 is None ) :
38+ if P2 is None :
3039 return P1
31- if (x (P1 ) == x (P2 ) and y (P1 ) != y (P2 )):
40+ if (x (P1 ) == x (P2 )) and ( y (P1 ) != y (P2 )):
3241 return None
33- if ( P1 == P2 ) :
42+ if P1 == P2 :
3443 lam = (3 * x (P1 ) * x (P1 ) * pow (2 * y (P1 ), p - 2 , p )) % p
3544 else :
3645 lam = ((y (P2 ) - y (P1 )) * pow (x (P2 ) - x (P1 ), p - 2 , p )) % p
@@ -40,7 +49,7 @@ def point_add(P1, P2):
4049def point_mul (P , n ):
4150 R = None
4251 for i in range (256 ):
43- if (( n >> i ) & 1 ) :
52+ if (n >> i ) & 1 :
4453 R = point_add (R , P )
4554 P = point_add (P , P )
4655 return R
@@ -62,14 +71,14 @@ def lift_x_square_y(b):
6271 y = pow (y_sq , (p + 1 ) // 4 , p )
6372 if pow (y , 2 , p ) != y_sq :
6473 return None
65- return [ x , y ]
74+ return ( x , y )
6675
6776def lift_x_even_y (b ):
6877 P = lift_x_square_y (b )
6978 if P is None :
7079 return None
7180 else :
72- return [ x (P ), y (P ) if y (P ) % 2 == 0 else p - y (P )]
81+ return ( x (P ), y (P ) if y (P ) % 2 == 0 else p - y (P ))
7382
7483def int_from_bytes (b ):
7584 return int .from_bytes (b , byteorder = "big" )
@@ -81,38 +90,39 @@ def is_square(x):
8190 return pow (x , (p - 1 ) // 2 , p ) == 1
8291
8392def has_square_y (P ):
84- return not is_infinity (P ) and is_square (y (P ))
93+ return ( not is_infinity (P ) ) and is_square (y (P ))
8594
8695def has_even_y (P ):
8796 return y (P ) % 2 == 0
8897
8998def pubkey_gen (seckey ):
90- x = int_from_bytes (seckey )
91- if not (1 <= x <= n - 1 ):
99+ d0 = int_from_bytes (seckey )
100+ if not (1 <= d0 <= n - 1 ):
92101 raise ValueError ('The secret key must be an integer in the range 1..n-1.' )
93- P = point_mul (G , x )
102+ P = point_mul (G , d0 )
94103 return bytes_from_point (P )
95104
96- def schnorr_sign (msg , seckey0 , aux_rand ):
105+ def schnorr_sign (msg , seckey , aux_rand ):
97106 if len (msg ) != 32 :
98107 raise ValueError ('The message must be a 32-byte array.' )
99- seckey0 = int_from_bytes (seckey0 )
100- if not (1 <= seckey0 <= n - 1 ):
108+ d0 = int_from_bytes (seckey )
109+ if not (1 <= d0 <= n - 1 ):
101110 raise ValueError ('The secret key must be an integer in the range 1..n-1.' )
102111 if len (aux_rand ) != 32 :
103112 raise ValueError ('aux_rand must be 32 bytes instead of %i.' % len (aux_rand ))
104- P = point_mul (G , seckey0 )
105- seckey = seckey0 if has_even_y (P ) else n - seckey0
106- t = xor_bytes (bytes_from_int (seckey ), tagged_hash ("BIP340/aux" , aux_rand ))
113+ P = point_mul (G , d0 )
114+ d = d0 if has_even_y (P ) else n - d0
115+ t = xor_bytes (bytes_from_int (d ), tagged_hash ("BIP340/aux" , aux_rand ))
107116 k0 = int_from_bytes (tagged_hash ("BIP340/nonce" , t + bytes_from_point (P ) + msg )) % n
108117 if k0 == 0 :
109118 raise RuntimeError ('Failure. This happens only with negligible probability.' )
110119 R = point_mul (G , k0 )
111120 k = n - k0 if not has_square_y (R ) else k0
112121 e = int_from_bytes (tagged_hash ("BIP340/challenge" , bytes_from_point (R ) + bytes_from_point (P ) + msg )) % n
113- sig = bytes_from_point (R ) + bytes_from_int ((k + e * seckey ) % n )
122+ sig = bytes_from_point (R ) + bytes_from_int ((k + e * d ) % n )
123+ debug_print_vars ()
114124 if not schnorr_verify (msg , bytes_from_point (P ), sig ):
115- raise RuntimeError ('The signature does not pass verification.' )
125+ raise RuntimeError ('The created signature does not pass verification.' )
116126 return sig
117127
118128def schnorr_verify (msg , pubkey , sig ):
@@ -123,26 +133,29 @@ def schnorr_verify(msg, pubkey, sig):
123133 if len (sig ) != 64 :
124134 raise ValueError ('The signature must be a 64-byte array.' )
125135 P = lift_x_even_y (pubkey )
126- if (P is None ):
127- return False
128136 r = int_from_bytes (sig [0 :32 ])
129137 s = int_from_bytes (sig [32 :64 ])
130- if (r >= p or s >= n ):
138+ if (P is None ) or (r >= p ) or (s >= n ):
139+ debug_print_vars ()
131140 return False
132141 e = int_from_bytes (tagged_hash ("BIP340/challenge" , sig [0 :32 ] + pubkey + msg )) % n
133142 R = point_add (point_mul (G , s ), point_mul (P , n - e ))
134- if R is None or not has_square_y (R ) or x (R ) != r :
143+ if (R is None ) or (not has_square_y (R )) or (x (R ) != r ):
144+ debug_print_vars ()
135145 return False
146+ debug_print_vars ()
136147 return True
137148
138149#
139150# The following code is only used to verify the test vectors.
140151#
141152import csv
153+ import os
154+ import sys
142155
143156def test_vectors ():
144157 all_passed = True
145- with open ('test-vectors.csv' , newline = '' ) as csvfile :
158+ with open (os . path . join ( sys . path [ 0 ], 'test-vectors.csv' ) , newline = '' ) as csvfile :
146159 reader = csv .reader (csvfile )
147160 reader .__next__ ()
148161 for row in reader :
@@ -151,7 +164,7 @@ def test_vectors():
151164 msg = bytes .fromhex (msg )
152165 sig = bytes .fromhex (sig )
153166 result = result == 'TRUE'
154- print ('\n Test vector #%-3i: ' % int ( index ))
167+ print ('\n Test vector' , ( '#' + index ). rjust ( 3 , ' ' ) + ':' )
155168 if seckey != '' :
156169 seckey = bytes .fromhex (seckey )
157170 pubkey_actual = pubkey_gen (seckey )
@@ -160,13 +173,17 @@ def test_vectors():
160173 print (' Expected key:' , pubkey .hex ().upper ())
161174 print (' Actual key:' , pubkey_actual .hex ().upper ())
162175 aux_rand = bytes .fromhex (aux_rand )
163- sig_actual = schnorr_sign (msg , seckey , aux_rand )
164- if sig == sig_actual :
165- print (' * Passed signing test.' )
166- else :
167- print (' * Failed signing test.' )
168- print (' Expected signature:' , sig .hex ().upper ())
169- print (' Actual signature:' , sig_actual .hex ().upper ())
176+ try :
177+ sig_actual = schnorr_sign (msg , seckey , aux_rand )
178+ if sig == sig_actual :
179+ print (' * Passed signing test.' )
180+ else :
181+ print (' * Failed signing test.' )
182+ print (' Expected signature:' , sig .hex ().upper ())
183+ print (' Actual signature:' , sig_actual .hex ().upper ())
184+ all_passed = False
185+ except RuntimeError as e :
186+ print (' * Signing test raised exception:' , e )
170187 all_passed = False
171188 result_actual = schnorr_verify (msg , pubkey , sig )
172189 if result == result_actual :
@@ -185,5 +202,26 @@ def test_vectors():
185202 print ('Some test vectors failed.' )
186203 return all_passed
187204
205+ #
206+ # The following code is only used for debugging
207+ #
208+ import inspect
209+
210+ def pretty (v ):
211+ if isinstance (v , bytes ):
212+ return '0x' + v .hex ()
213+ if isinstance (v , int ):
214+ return pretty (bytes_from_int (v ))
215+ if isinstance (v , tuple ):
216+ return tuple (map (pretty , v ))
217+ return v
218+
219+ def debug_print_vars ():
220+ if DEBUG :
221+ frame = inspect .currentframe ().f_back
222+ print (' Variables in function ' , frame .f_code .co_name , ' at line ' , frame .f_lineno , ':' , sep = '' )
223+ for var_name , var_val in frame .f_locals .items ():
224+ print (' ' + var_name .rjust (11 , ' ' ), '==' , pretty (var_val ))
225+
188226if __name__ == '__main__' :
189227 test_vectors ()
0 commit comments