|
| 1 | +import hashlib |
| 2 | +import binascii |
| 3 | + |
| 4 | +p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F |
| 5 | +n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 |
| 6 | + |
| 7 | +# Points are tuples of X and Y coordinates and the point at infinity is |
| 8 | +# represented by the None keyword. |
| 9 | +G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) |
| 10 | + |
| 11 | +# This implementation can be sped up by storing the midstate after hashing |
| 12 | +# tag_hash instead of rehashing it all the time. |
| 13 | +def tagged_hash(tag, msg): |
| 14 | + tag_hash = hashlib.sha256(tag.encode()).digest() |
| 15 | + return hashlib.sha256(tag_hash + tag_hash + msg).digest() |
| 16 | + |
| 17 | +def is_infinity(P): |
| 18 | + return P is None |
| 19 | + |
| 20 | +def x(P): |
| 21 | + return P[0] |
| 22 | + |
| 23 | +def y(P): |
| 24 | + return P[1] |
| 25 | + |
| 26 | +def point_add(P1, P2): |
| 27 | + if (P1 is None): |
| 28 | + return P2 |
| 29 | + if (P2 is None): |
| 30 | + return P1 |
| 31 | + if (x(P1) == x(P2) and y(P1) != y(P2)): |
| 32 | + return None |
| 33 | + if (P1 == P2): |
| 34 | + lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p |
| 35 | + else: |
| 36 | + lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p |
| 37 | + x3 = (lam * lam - x(P1) - x(P2)) % p |
| 38 | + return (x3, (lam * (x(P1) - x3) - y(P1)) % p) |
| 39 | + |
| 40 | +def point_mul(P, n): |
| 41 | + R = None |
| 42 | + for i in range(256): |
| 43 | + if ((n >> i) & 1): |
| 44 | + R = point_add(R, P) |
| 45 | + P = point_add(P, P) |
| 46 | + return R |
| 47 | + |
| 48 | +def bytes_from_int(x): |
| 49 | + return x.to_bytes(32, byteorder="big") |
| 50 | + |
| 51 | +def bytes_from_point(P): |
| 52 | + return bytes_from_int(x(P)) |
| 53 | + |
| 54 | +def point_from_bytes(b): |
| 55 | + x = int_from_bytes(b) |
| 56 | + if x >= p: |
| 57 | + return None |
| 58 | + y_sq = (pow(x, 3, p) + 7) % p |
| 59 | + y = pow(y_sq, (p + 1) // 4, p) |
| 60 | + if pow(y, 2, p) != y_sq: |
| 61 | + return None |
| 62 | + return [x, y] |
| 63 | + |
| 64 | +def int_from_bytes(b): |
| 65 | + return int.from_bytes(b, byteorder="big") |
| 66 | + |
| 67 | +def hash_sha256(b): |
| 68 | + return hashlib.sha256(b).digest() |
| 69 | + |
| 70 | +def is_square(x): |
| 71 | + return pow(x, (p - 1) // 2, p) == 1 |
| 72 | + |
| 73 | +def has_square_y(P): |
| 74 | + return not is_infinity(P) and is_square(y(P)) |
| 75 | + |
| 76 | +def pubkey_gen(seckey): |
| 77 | + x = int_from_bytes(seckey) |
| 78 | + if not (1 <= x <= n - 1): |
| 79 | + raise ValueError('The secret key must be an integer in the range 1..n-1.') |
| 80 | + P = point_mul(G, x) |
| 81 | + return bytes_from_point(P) |
| 82 | + |
| 83 | +def schnorr_sign(msg, seckey0): |
| 84 | + if len(msg) != 32: |
| 85 | + raise ValueError('The message must be a 32-byte array.') |
| 86 | + seckey0 = int_from_bytes(seckey0) |
| 87 | + if not (1 <= seckey0 <= n - 1): |
| 88 | + raise ValueError('The secret key must be an integer in the range 1..n-1.') |
| 89 | + P = point_mul(G, seckey0) |
| 90 | + seckey = seckey0 if has_square_y(P) else n - seckey0 |
| 91 | + k0 = int_from_bytes(tagged_hash("BIPSchnorrDerive", bytes_from_int(seckey) + msg)) % n |
| 92 | + if k0 == 0: |
| 93 | + raise RuntimeError('Failure. This happens only with negligible probability.') |
| 94 | + R = point_mul(G, k0) |
| 95 | + k = n - k0 if not has_square_y(R) else k0 |
| 96 | + e = int_from_bytes(tagged_hash("BIPSchnorr", bytes_from_point(R) + bytes_from_point(P) + msg)) % n |
| 97 | + return bytes_from_point(R) + bytes_from_int((k + e * seckey) % n) |
| 98 | + |
| 99 | +def schnorr_verify(msg, pubkey, sig): |
| 100 | + if len(msg) != 32: |
| 101 | + raise ValueError('The message must be a 32-byte array.') |
| 102 | + if len(pubkey) != 32: |
| 103 | + raise ValueError('The public key must be a 32-byte array.') |
| 104 | + if len(sig) != 64: |
| 105 | + raise ValueError('The signature must be a 64-byte array.') |
| 106 | + P = point_from_bytes(pubkey) |
| 107 | + if (P is None): |
| 108 | + return False |
| 109 | + r = int_from_bytes(sig[0:32]) |
| 110 | + s = int_from_bytes(sig[32:64]) |
| 111 | + if (r >= p or s >= n): |
| 112 | + return False |
| 113 | + e = int_from_bytes(tagged_hash("BIPSchnorr", sig[0:32] + pubkey + msg)) % n |
| 114 | + R = point_add(point_mul(G, s), point_mul(P, n - e)) |
| 115 | + if R is None or not has_square_y(R) or x(R) != r: |
| 116 | + return False |
| 117 | + return True |
| 118 | + |
| 119 | +# |
| 120 | +# The following code is only used to verify the test vectors. |
| 121 | +# |
| 122 | +import csv |
| 123 | + |
| 124 | +def test_vectors(): |
| 125 | + all_passed = True |
| 126 | + with open('test-vectors.csv', newline='') as csvfile: |
| 127 | + reader = csv.reader(csvfile) |
| 128 | + reader.__next__() |
| 129 | + for row in reader: |
| 130 | + (index, seckey, pubkey, msg, sig, result, comment) = row |
| 131 | + pubkey = bytes.fromhex(pubkey) |
| 132 | + msg = bytes.fromhex(msg) |
| 133 | + sig = bytes.fromhex(sig) |
| 134 | + result = result == 'TRUE' |
| 135 | + print('\nTest vector #%-3i: ' % int(index)) |
| 136 | + if seckey != '': |
| 137 | + seckey = bytes.fromhex(seckey) |
| 138 | + pubkey_actual = pubkey_gen(seckey) |
| 139 | + if pubkey != pubkey_actual: |
| 140 | + print(' * Failed key generation.') |
| 141 | + print(' Expected key:', pubkey.hex().upper()) |
| 142 | + print(' Actual key:', pubkey_actual.hex().upper()) |
| 143 | + sig_actual = schnorr_sign(msg, seckey) |
| 144 | + if sig == sig_actual: |
| 145 | + print(' * Passed signing test.') |
| 146 | + else: |
| 147 | + print(' * Failed signing test.') |
| 148 | + print(' Expected signature:', sig.hex().upper()) |
| 149 | + print(' Actual signature:', sig_actual.hex().upper()) |
| 150 | + all_passed = False |
| 151 | + result_actual = schnorr_verify(msg, pubkey, sig) |
| 152 | + if result == result_actual: |
| 153 | + print(' * Passed verification test.') |
| 154 | + else: |
| 155 | + print(' * Failed verification test.') |
| 156 | + print(' Expected verification result:', result) |
| 157 | + print(' Actual verification result:', result_actual) |
| 158 | + if comment: |
| 159 | + print(' Comment:', comment) |
| 160 | + all_passed = False |
| 161 | + print() |
| 162 | + if all_passed: |
| 163 | + print('All test vectors passed.') |
| 164 | + else: |
| 165 | + print('Some test vectors failed.') |
| 166 | + return all_passed |
| 167 | + |
| 168 | +if __name__ == '__main__': |
| 169 | + test_vectors() |
0 commit comments