1
+ from typing import Tuple , Optional , Any
1
2
import hashlib
2
3
import binascii
3
4
17
18
# represented by the None keyword.
18
19
G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798 , 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8 )
19
20
21
+ Point = Tuple [int , int ]
22
+
20
23
# This implementation can be sped up by storing the midstate after hashing
21
24
# tag_hash instead of rehashing it all the time.
22
- def tagged_hash (tag , msg ) :
25
+ def tagged_hash (tag : str , msg : bytes ) -> bytes :
23
26
tag_hash = hashlib .sha256 (tag .encode ()).digest ()
24
27
return hashlib .sha256 (tag_hash + tag_hash + msg ).digest ()
25
28
26
- def is_infinity (P ) :
29
+ def is_infinity (P : Optional [ Point ]) -> bool :
27
30
return P is None
28
31
29
- def x (P ) :
32
+ def x (P : Point ) -> int :
30
33
return P [0 ]
31
34
32
- def y (P ) :
35
+ def y (P : Point ) -> int :
33
36
return P [1 ]
34
37
35
- def point_add (P1 , P2 ) :
38
+ def point_add (P1 : Optional [ Point ] , P2 : Optional [ Point ]) -> Optional [ Point ] :
36
39
if P1 is None :
37
40
return P2
38
41
if P2 is None :
@@ -46,24 +49,24 @@ def point_add(P1, P2):
46
49
x3 = (lam * lam - x (P1 ) - x (P2 )) % p
47
50
return (x3 , (lam * (x (P1 ) - x3 ) - y (P1 )) % p )
48
51
49
- def point_mul (P , n ) :
52
+ def point_mul (P : Optional [ Point ] , n : int ) -> Optional [ Point ] :
50
53
R = None
51
54
for i in range (256 ):
52
55
if (n >> i ) & 1 :
53
56
R = point_add (R , P )
54
57
P = point_add (P , P )
55
58
return R
56
59
57
- def bytes_from_int (x ) :
60
+ def bytes_from_int (x : int ) -> bytes :
58
61
return x .to_bytes (32 , byteorder = "big" )
59
62
60
- def bytes_from_point (P ) :
63
+ def bytes_from_point (P : Point ) -> bytes :
61
64
return bytes_from_int (x (P ))
62
65
63
- def xor_bytes (b0 , b1 ) :
66
+ def xor_bytes (b0 : bytes , b1 : bytes ) -> bytes :
64
67
return bytes (x ^ y for (x , y ) in zip (b0 , b1 ))
65
68
66
- def lift_x_square_y (b ) :
69
+ def lift_x_square_y (b : bytes ) -> Optional [ Point ] :
67
70
x = int_from_bytes (b )
68
71
if x >= p :
69
72
return None
@@ -73,36 +76,40 @@ def lift_x_square_y(b):
73
76
return None
74
77
return (x , y )
75
78
76
- def lift_x_even_y (b ) :
79
+ def lift_x_even_y (b : bytes ) -> Optional [ Point ] :
77
80
P = lift_x_square_y (b )
78
81
if P is None :
79
82
return None
80
83
else :
81
84
return (x (P ), y (P ) if y (P ) % 2 == 0 else p - y (P ))
82
85
83
- def int_from_bytes (b ) :
86
+ def int_from_bytes (b : bytes ) -> int :
84
87
return int .from_bytes (b , byteorder = "big" )
85
88
86
- def hash_sha256 (b ) :
89
+ def hash_sha256 (b : bytes ) -> bytes :
87
90
return hashlib .sha256 (b ).digest ()
88
91
89
- def is_square (x ) :
90
- return pow (x , (p - 1 ) // 2 , p ) == 1
92
+ def is_square (x : int ) -> bool :
93
+ return int ( pow (x , (p - 1 ) // 2 , p ) ) == 1
91
94
92
- def has_square_y (P ):
93
- return (not is_infinity (P )) and is_square (y (P ))
95
+ def has_square_y (P : Optional [Point ]) -> bool :
96
+ infinity = is_infinity (P )
97
+ if infinity : return False
98
+ assert P is not None
99
+ return is_square (y (P ))
94
100
95
- def has_even_y (P ) :
101
+ def has_even_y (P : Point ) -> bool :
96
102
return y (P ) % 2 == 0
97
103
98
- def pubkey_gen (seckey ) :
104
+ def pubkey_gen (seckey : bytes ) -> bytes :
99
105
d0 = int_from_bytes (seckey )
100
106
if not (1 <= d0 <= n - 1 ):
101
107
raise ValueError ('The secret key must be an integer in the range 1..n-1.' )
102
108
P = point_mul (G , d0 )
109
+ assert P is not None
103
110
return bytes_from_point (P )
104
111
105
- def schnorr_sign (msg , seckey , aux_rand ) :
112
+ def schnorr_sign (msg : bytes , seckey : bytes , aux_rand : bytes ) -> bytes :
106
113
if len (msg ) != 32 :
107
114
raise ValueError ('The message must be a 32-byte array.' )
108
115
d0 = int_from_bytes (seckey )
@@ -111,12 +118,14 @@ def schnorr_sign(msg, seckey, aux_rand):
111
118
if len (aux_rand ) != 32 :
112
119
raise ValueError ('aux_rand must be 32 bytes instead of %i.' % len (aux_rand ))
113
120
P = point_mul (G , d0 )
121
+ assert P is not None
114
122
d = d0 if has_even_y (P ) else n - d0
115
123
t = xor_bytes (bytes_from_int (d ), tagged_hash ("BIP340/aux" , aux_rand ))
116
124
k0 = int_from_bytes (tagged_hash ("BIP340/nonce" , t + bytes_from_point (P ) + msg )) % n
117
125
if k0 == 0 :
118
126
raise RuntimeError ('Failure. This happens only with negligible probability.' )
119
127
R = point_mul (G , k0 )
128
+ assert R is not None
120
129
k = n - k0 if not has_square_y (R ) else k0
121
130
e = int_from_bytes (tagged_hash ("BIP340/challenge" , bytes_from_point (R ) + bytes_from_point (P ) + msg )) % n
122
131
sig = bytes_from_point (R ) + bytes_from_int ((k + e * d ) % n )
@@ -125,7 +134,7 @@ def schnorr_sign(msg, seckey, aux_rand):
125
134
raise RuntimeError ('The created signature does not pass verification.' )
126
135
return sig
127
136
128
- def schnorr_verify (msg , pubkey , sig ) :
137
+ def schnorr_verify (msg : bytes , pubkey : bytes , sig : bytes ) -> bool :
129
138
if len (msg ) != 32 :
130
139
raise ValueError ('The message must be a 32-byte array.' )
131
140
if len (pubkey ) != 32 :
@@ -153,26 +162,26 @@ def schnorr_verify(msg, pubkey, sig):
153
162
import os
154
163
import sys
155
164
156
- def test_vectors ():
165
+ def test_vectors () -> bool :
157
166
all_passed = True
158
167
with open (os .path .join (sys .path [0 ], 'test-vectors.csv' ), newline = '' ) as csvfile :
159
168
reader = csv .reader (csvfile )
160
169
reader .__next__ ()
161
170
for row in reader :
162
- (index , seckey , pubkey , aux_rand , msg , sig , result , comment ) = row
163
- pubkey = bytes .fromhex (pubkey )
164
- msg = bytes .fromhex (msg )
165
- sig = bytes .fromhex (sig )
166
- result = result == 'TRUE'
171
+ (index , seckey_hex , pubkey_hex , aux_rand_hex , msg_hex , sig_hex , result_str , comment ) = row
172
+ pubkey = bytes .fromhex (pubkey_hex )
173
+ msg = bytes .fromhex (msg_hex )
174
+ sig = bytes .fromhex (sig_hex )
175
+ result = result_str == 'TRUE'
167
176
print ('\n Test vector' , ('#' + index ).rjust (3 , ' ' ) + ':' )
168
- if seckey != '' :
169
- seckey = bytes .fromhex (seckey )
177
+ if seckey_hex != '' :
178
+ seckey = bytes .fromhex (seckey_hex )
170
179
pubkey_actual = pubkey_gen (seckey )
171
180
if pubkey != pubkey_actual :
172
181
print (' * Failed key generation.' )
173
182
print (' Expected key:' , pubkey .hex ().upper ())
174
183
print (' Actual key:' , pubkey_actual .hex ().upper ())
175
- aux_rand = bytes .fromhex (aux_rand )
184
+ aux_rand = bytes .fromhex (aux_rand_hex )
176
185
try :
177
186
sig_actual = schnorr_sign (msg , seckey , aux_rand )
178
187
if sig == sig_actual :
@@ -207,7 +216,7 @@ def test_vectors():
207
216
#
208
217
import inspect
209
218
210
- def pretty (v ) :
219
+ def pretty (v : Any ) -> Any :
211
220
if isinstance (v , bytes ):
212
221
return '0x' + v .hex ()
213
222
if isinstance (v , int ):
@@ -216,9 +225,12 @@ def pretty(v):
216
225
return tuple (map (pretty , v ))
217
226
return v
218
227
219
- def debug_print_vars ():
228
+ def debug_print_vars () -> None :
220
229
if DEBUG :
221
- frame = inspect .currentframe ().f_back
230
+ current_frame = inspect .currentframe ()
231
+ assert current_frame is not None
232
+ frame = current_frame .f_back
233
+ assert frame is not None
222
234
print (' Variables in function ' , frame .f_code .co_name , ' at line ' , frame .f_lineno , ':' , sep = '' )
223
235
for var_name , var_val in frame .f_locals .items ():
224
236
print (' ' + var_name .rjust (11 , ' ' ), '==' , pretty (var_val ))
0 commit comments