1
1
import hashlib
2
2
import binascii
3
3
4
+ # Set DEBUG to True to get a detailed debug output including
5
+ # intermediate values during key generation, signing, and
6
+ # verification. This is implemented via calls to the
7
+ # debug_print_vars() function.
8
+ #
9
+ # If you want to print values on an individual basis, use
10
+ # the pretty() function, e.g., print(pretty(foo)).
11
+ DEBUG = False
12
+
4
13
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
5
14
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
6
15
@@ -24,13 +33,13 @@ def y(P):
24
33
return P [1 ]
25
34
26
35
def point_add (P1 , P2 ):
27
- if ( P1 is None ) :
36
+ if P1 is None :
28
37
return P2
29
- if ( P2 is None ) :
38
+ if P2 is None :
30
39
return P1
31
- if (x (P1 ) == x (P2 ) and y (P1 ) != y (P2 )):
40
+ if (x (P1 ) == x (P2 )) and ( y (P1 ) != y (P2 )):
32
41
return None
33
- if ( P1 == P2 ) :
42
+ if P1 == P2 :
34
43
lam = (3 * x (P1 ) * x (P1 ) * pow (2 * y (P1 ), p - 2 , p )) % p
35
44
else :
36
45
lam = ((y (P2 ) - y (P1 )) * pow (x (P2 ) - x (P1 ), p - 2 , p )) % p
@@ -40,7 +49,7 @@ def point_add(P1, P2):
40
49
def point_mul (P , n ):
41
50
R = None
42
51
for i in range (256 ):
43
- if (( n >> i ) & 1 ) :
52
+ if (n >> i ) & 1 :
44
53
R = point_add (R , P )
45
54
P = point_add (P , P )
46
55
return R
@@ -62,14 +71,14 @@ def lift_x_square_y(b):
62
71
y = pow (y_sq , (p + 1 ) // 4 , p )
63
72
if pow (y , 2 , p ) != y_sq :
64
73
return None
65
- return [ x , y ]
74
+ return ( x , y )
66
75
67
76
def lift_x_even_y (b ):
68
77
P = lift_x_square_y (b )
69
78
if P is None :
70
79
return None
71
80
else :
72
- return [ x (P ), y (P ) if y (P ) % 2 == 0 else p - y (P )]
81
+ return ( x (P ), y (P ) if y (P ) % 2 == 0 else p - y (P ))
73
82
74
83
def int_from_bytes (b ):
75
84
return int .from_bytes (b , byteorder = "big" )
@@ -81,38 +90,39 @@ def is_square(x):
81
90
return pow (x , (p - 1 ) // 2 , p ) == 1
82
91
83
92
def has_square_y (P ):
84
- return not is_infinity (P ) and is_square (y (P ))
93
+ return ( not is_infinity (P ) ) and is_square (y (P ))
85
94
86
95
def has_even_y (P ):
87
96
return y (P ) % 2 == 0
88
97
89
98
def pubkey_gen (seckey ):
90
- x = int_from_bytes (seckey )
91
- if not (1 <= x <= n - 1 ):
99
+ d0 = int_from_bytes (seckey )
100
+ if not (1 <= d0 <= n - 1 ):
92
101
raise ValueError ('The secret key must be an integer in the range 1..n-1.' )
93
- P = point_mul (G , x )
102
+ P = point_mul (G , d0 )
94
103
return bytes_from_point (P )
95
104
96
- def schnorr_sign (msg , seckey0 , aux_rand ):
105
+ def schnorr_sign (msg , seckey , aux_rand ):
97
106
if len (msg ) != 32 :
98
107
raise ValueError ('The message must be a 32-byte array.' )
99
- seckey0 = int_from_bytes (seckey0 )
100
- if not (1 <= seckey0 <= n - 1 ):
108
+ d0 = int_from_bytes (seckey )
109
+ if not (1 <= d0 <= n - 1 ):
101
110
raise ValueError ('The secret key must be an integer in the range 1..n-1.' )
102
111
if len (aux_rand ) != 32 :
103
112
raise ValueError ('aux_rand must be 32 bytes instead of %i.' % len (aux_rand ))
104
- P = point_mul (G , seckey0 )
105
- seckey = seckey0 if has_even_y (P ) else n - seckey0
106
- t = xor_bytes (bytes_from_int (seckey ), tagged_hash ("BIP340/aux" , aux_rand ))
113
+ P = point_mul (G , d0 )
114
+ d = d0 if has_even_y (P ) else n - d0
115
+ t = xor_bytes (bytes_from_int (d ), tagged_hash ("BIP340/aux" , aux_rand ))
107
116
k0 = int_from_bytes (tagged_hash ("BIP340/nonce" , t + bytes_from_point (P ) + msg )) % n
108
117
if k0 == 0 :
109
118
raise RuntimeError ('Failure. This happens only with negligible probability.' )
110
119
R = point_mul (G , k0 )
111
120
k = n - k0 if not has_square_y (R ) else k0
112
121
e = int_from_bytes (tagged_hash ("BIP340/challenge" , bytes_from_point (R ) + bytes_from_point (P ) + msg )) % n
113
- sig = bytes_from_point (R ) + bytes_from_int ((k + e * seckey ) % n )
122
+ sig = bytes_from_point (R ) + bytes_from_int ((k + e * d ) % n )
123
+ debug_print_vars ()
114
124
if not schnorr_verify (msg , bytes_from_point (P ), sig ):
115
- raise RuntimeError ('The signature does not pass verification.' )
125
+ raise RuntimeError ('The created signature does not pass verification.' )
116
126
return sig
117
127
118
128
def schnorr_verify (msg , pubkey , sig ):
@@ -123,26 +133,29 @@ def schnorr_verify(msg, pubkey, sig):
123
133
if len (sig ) != 64 :
124
134
raise ValueError ('The signature must be a 64-byte array.' )
125
135
P = lift_x_even_y (pubkey )
126
- if (P is None ):
127
- return False
128
136
r = int_from_bytes (sig [0 :32 ])
129
137
s = int_from_bytes (sig [32 :64 ])
130
- if (r >= p or s >= n ):
138
+ if (P is None ) or (r >= p ) or (s >= n ):
139
+ debug_print_vars ()
131
140
return False
132
141
e = int_from_bytes (tagged_hash ("BIP340/challenge" , sig [0 :32 ] + pubkey + msg )) % n
133
142
R = point_add (point_mul (G , s ), point_mul (P , n - e ))
134
- if R is None or not has_square_y (R ) or x (R ) != r :
143
+ if (R is None ) or (not has_square_y (R )) or (x (R ) != r ):
144
+ debug_print_vars ()
135
145
return False
146
+ debug_print_vars ()
136
147
return True
137
148
138
149
#
139
150
# The following code is only used to verify the test vectors.
140
151
#
141
152
import csv
153
+ import os
154
+ import sys
142
155
143
156
def test_vectors ():
144
157
all_passed = True
145
- with open ('test-vectors.csv' , newline = '' ) as csvfile :
158
+ with open (os . path . join ( sys . path [ 0 ], 'test-vectors.csv' ) , newline = '' ) as csvfile :
146
159
reader = csv .reader (csvfile )
147
160
reader .__next__ ()
148
161
for row in reader :
@@ -151,7 +164,7 @@ def test_vectors():
151
164
msg = bytes .fromhex (msg )
152
165
sig = bytes .fromhex (sig )
153
166
result = result == 'TRUE'
154
- print ('\n Test vector #%-3i: ' % int ( index ))
167
+ print ('\n Test vector' , ( '#' + index ). rjust ( 3 , ' ' ) + ':' )
155
168
if seckey != '' :
156
169
seckey = bytes .fromhex (seckey )
157
170
pubkey_actual = pubkey_gen (seckey )
@@ -160,13 +173,17 @@ def test_vectors():
160
173
print (' Expected key:' , pubkey .hex ().upper ())
161
174
print (' Actual key:' , pubkey_actual .hex ().upper ())
162
175
aux_rand = bytes .fromhex (aux_rand )
163
- sig_actual = schnorr_sign (msg , seckey , aux_rand )
164
- if sig == sig_actual :
165
- print (' * Passed signing test.' )
166
- else :
167
- print (' * Failed signing test.' )
168
- print (' Expected signature:' , sig .hex ().upper ())
169
- print (' Actual signature:' , sig_actual .hex ().upper ())
176
+ try :
177
+ sig_actual = schnorr_sign (msg , seckey , aux_rand )
178
+ if sig == sig_actual :
179
+ print (' * Passed signing test.' )
180
+ else :
181
+ print (' * Failed signing test.' )
182
+ print (' Expected signature:' , sig .hex ().upper ())
183
+ print (' Actual signature:' , sig_actual .hex ().upper ())
184
+ all_passed = False
185
+ except RuntimeError as e :
186
+ print (' * Signing test raised exception:' , e )
170
187
all_passed = False
171
188
result_actual = schnorr_verify (msg , pubkey , sig )
172
189
if result == result_actual :
@@ -185,5 +202,26 @@ def test_vectors():
185
202
print ('Some test vectors failed.' )
186
203
return all_passed
187
204
205
+ #
206
+ # The following code is only used for debugging
207
+ #
208
+ import inspect
209
+
210
+ def pretty (v ):
211
+ if isinstance (v , bytes ):
212
+ return '0x' + v .hex ()
213
+ if isinstance (v , int ):
214
+ return pretty (bytes_from_int (v ))
215
+ if isinstance (v , tuple ):
216
+ return tuple (map (pretty , v ))
217
+ return v
218
+
219
+ def debug_print_vars ():
220
+ if DEBUG :
221
+ frame = inspect .currentframe ().f_back
222
+ print (' Variables in function ' , frame .f_code .co_name , ' at line ' , frame .f_lineno , ':' , sep = '' )
223
+ for var_name , var_val in frame .f_locals .items ():
224
+ print (' ' + var_name .rjust (11 , ' ' ), '==' , pretty (var_val ))
225
+
188
226
if __name__ == '__main__' :
189
227
test_vectors ()
0 commit comments