1
1
import os
2
- from ..modules .modules import ModuleDilithium
2
+ from ..modules .modules import Matrix , Module , Vector
3
3
4
4
try :
5
5
from xoflib import shake256
@@ -20,7 +20,7 @@ def __init__(self, parameter_set: dict):
20
20
self .beta = self .tau * self .eta
21
21
self .c_tilde_bytes = parameter_set ["c_tilde_bytes" ]
22
22
23
- self .M = ModuleDilithium ()
23
+ self .M = Module ()
24
24
self .R = self .M .ring
25
25
self .oid = parameter_set ["oid" ] if "oid" in parameter_set else None
26
26
@@ -58,13 +58,13 @@ def set_drbg_seed(self, seed: bytes):
58
58
"""
59
59
60
60
@staticmethod
61
- def _h (input_bytes , length ) :
61
+ def _h (input : bytes , length : int ) -> bytes :
62
62
"""
63
63
H: B^* -> B^*
64
64
"""
65
- return shake256 (input_bytes ).read (length )
65
+ return shake256 (input ).read (length )
66
66
67
- def _expand_matrix_from_seed (self , rho ) :
67
+ def _expand_matrix_from_seed (self , rho : bytes ) -> Matrix :
68
68
"""
69
69
Helper function which generates a element of size
70
70
k x l from a seed `rho`.
@@ -75,7 +75,7 @@ def _expand_matrix_from_seed(self, rho):
75
75
A_data [i ][j ] = self .R .rejection_sample_ntt_poly (rho , i , j )
76
76
return self .M (A_data )
77
77
78
- def _expand_vector_from_seed (self , rho_prime ) :
78
+ def _expand_vector_from_seed (self , rho_prime : bytes ) -> tuple [ Vector , Vector ] :
79
79
s1_elements = [
80
80
self .R .rejection_bounded_poly (rho_prime , i , self .eta ) for i in range (self .l )
81
81
]
@@ -88,24 +88,32 @@ def _expand_vector_from_seed(self, rho_prime):
88
88
s2 = self .M .vector (s2_elements )
89
89
return s1 , s2
90
90
91
- def _expand_mask_vector (self , rho , mu ) :
91
+ def _expand_mask_vector (self , rho : bytes , mu : int ) -> Vector :
92
92
elements = [
93
93
self .R .sample_mask_polynomial (rho , i , mu , self .gamma_1 )
94
94
for i in range (self .l )
95
95
]
96
96
return self .M .vector (elements )
97
97
98
98
@staticmethod
99
- def _pack_pk (rho , t1 ) :
99
+ def _pack_pk (rho : bytes , t1 : Vector ) -> bytes :
100
100
return rho + t1 .bit_pack_t1 ()
101
101
102
- def _pack_sk (self , rho , K , tr , s1 , s2 , t0 ):
102
+ def _pack_sk (
103
+ self ,
104
+ rho : bytes ,
105
+ k : bytes ,
106
+ tr : bytes ,
107
+ s1 : Vector ,
108
+ s2 : Vector ,
109
+ t0 : Vector ,
110
+ ) -> bytes :
103
111
s1_bytes = s1 .bit_pack_s (self .eta )
104
112
s2_bytes = s2 .bit_pack_s (self .eta )
105
113
t0_bytes = t0 .bit_pack_t0 ()
106
- return rho + K + tr + s1_bytes + s2_bytes + t0_bytes
114
+ return rho + k + tr + s1_bytes + s2_bytes + t0_bytes
107
115
108
- def _pack_h (self , h ) :
116
+ def _pack_h (self , h : Vector ) -> bytes :
109
117
non_zero_positions = [
110
118
[i for i , c in enumerate (poly .coeffs ) if c == 1 ]
111
119
for row in h ._data
@@ -121,20 +129,20 @@ def _pack_h(self, h):
121
129
packed .extend ([0 for _ in range (padding_len )])
122
130
return bytes (packed + offsets )
123
131
124
- def _pack_sig (self , c_tilde , z , h ) :
132
+ def _pack_sig (self , c_tilde : bytes , z : Vector , h : Vector ) -> bytes :
125
133
return c_tilde + z .bit_pack_z (self .gamma_1 ) + self ._pack_h (h )
126
134
127
- def _pk_size (self ):
135
+ def _pk_size (self ) -> int :
128
136
return 32 + 32 * self .k * 10
129
137
130
- def _unpack_pk (self , pk_bytes ) :
131
- if len (pk_bytes ) != self ._pk_size ():
138
+ def _unpack_pk (self , pk : bytes ) -> tuple [ bytes , Vector ] :
139
+ if len (pk ) != self ._pk_size ():
132
140
raise ValueError ("PK packed bytes is of the wrong length" )
133
- rho , t1_bytes = pk_bytes [:32 ], pk_bytes [32 :]
134
- t1 = self .M .bit_unpack_t1 (t1_bytes , self .k , 1 )
141
+ rho , t1_bytes = pk [:32 ], pk [32 :]
142
+ t1 = self .M .bit_unpack_t1 (t1_bytes , self .k )
135
143
return rho , t1
136
144
137
- def _sk_size (self ):
145
+ def _sk_size (self ) -> int :
138
146
if self .eta == 2 :
139
147
s_bytes = 96
140
148
else :
@@ -144,22 +152,24 @@ def _sk_size(self):
144
152
t0_len = 416 * self .k
145
153
return 2 * 32 + 64 + s1_len + s2_len + t0_len
146
154
147
- def _unpack_sk (self , sk_bytes ):
155
+ def _unpack_sk (
156
+ self , sk : bytes
157
+ ) -> tuple [bytes , bytes , bytes , Vector , Vector , Vector ]:
148
158
if self .eta == 2 :
149
159
s_bytes = 96
150
160
else :
151
161
s_bytes = 128
152
162
s1_len = s_bytes * self .l
153
163
s2_len = s_bytes * self .k
154
164
t0_len = 416 * self .k
155
- if len (sk_bytes ) != self ._sk_size ():
156
- raise ValueError ("SK packed bytes is of the wrong length" )
165
+ if len (sk ) != self ._sk_size ():
166
+ raise ValueError ("sk packed bytes is of the wrong length" )
157
167
158
168
# Split bytes between seeds and vectors
159
- sk_seed_bytes , sk_vec_bytes = sk_bytes [:128 ], sk_bytes [128 :]
169
+ sk_seed_bytes , sk_vec_bytes = sk [:128 ], sk [128 :]
160
170
161
171
# Unpack seed bytes
162
- rho , K , tr = (
172
+ rho , k , tr = (
163
173
sk_seed_bytes [:32 ],
164
174
sk_seed_bytes [32 :64 ],
165
175
sk_seed_bytes [64 :128 ],
@@ -171,50 +181,55 @@ def _unpack_sk(self, sk_bytes):
171
181
t0_bytes = sk_vec_bytes [- t0_len :]
172
182
173
183
# Unpack bytes to vectors
174
- s1 = self .M .bit_unpack_s (s1_bytes , self .l , 1 , self .eta )
175
- s2 = self .M .bit_unpack_s (s2_bytes , self .k , 1 , self .eta )
176
- t0 = self .M .bit_unpack_t0 (t0_bytes , self .k , 1 )
184
+ s1 = self .M .bit_unpack_s (s1_bytes , self .l , self .eta )
185
+ s2 = self .M .bit_unpack_s (s2_bytes , self .k , self .eta )
186
+ t0 = self .M .bit_unpack_t0 (t0_bytes , self .k )
177
187
178
- return rho , K , tr , s1 , s2 , t0
188
+ return rho , k , tr , s1 , s2 , t0
179
189
180
- def _unpack_h (self , h_bytes ) :
190
+ def _unpack_h (self , h_bytes : bytes ) -> Vector :
181
191
offsets = [0 ] + list (h_bytes [- self .k :])
182
- # check offsets are monotonic increasing
192
+
193
+ # ensure offsets are monotonic increasing
183
194
if any (offsets [i ] > offsets [i + 1 ] for i in range (len (offsets ) - 1 )):
184
- raise ValueError ("Offsets in h_bytes are not monotonic increasing" )
185
- # check offset[-1] is smaller than the length of h_bytes
195
+ raise ValueError ("offsets in h_bytes are not monotonically increasing" )
196
+
197
+ # ensure offset[-1] is smaller than the length of h_bytes
186
198
if offsets [- 1 ] > self .omega :
187
- raise ValueError ("Accumulate offset of hints exceeds omega" )
188
- # check zero fields are all zeros
199
+ raise ValueError ("accumulate offset of hints exceeds omega" )
200
+
201
+ # ensure zero fields are all zeros
189
202
if any (b != 0 for b in h_bytes [offsets [- 1 ] : self .omega ]):
190
- raise ValueError ("Non -zero fields in h_bytes are not all zeros" )
203
+ raise ValueError ("non -zero fields in h_bytes are not all zeros" )
191
204
192
205
non_zero_positions = [
193
206
list (h_bytes [offsets [i ] : offsets [i + 1 ]]) for i in range (self .k )
194
207
]
195
208
196
- matrix = []
209
+ vector_coeffs = []
197
210
for poly_non_zero in non_zero_positions :
198
211
coeffs = [0 for _ in range (256 )]
199
212
for i , non_zero in enumerate (poly_non_zero ):
200
213
if i > 0 and non_zero < poly_non_zero [i - 1 ]:
201
214
raise ValueError (
202
- "Non -zero positions in h_bytes are not monotonic increasing"
215
+ "non -zero positions in h_bytes are not monotonically increasing"
203
216
)
204
217
coeffs [non_zero ] = 1
205
- matrix .append ([self .R (coeffs )])
206
- return self .M (matrix )
218
+ vector_coeffs .append (self .R (coeffs ))
207
219
208
- def _unpack_sig (self , sig_bytes ):
209
- c_tilde = sig_bytes [: self .c_tilde_bytes ]
210
- z_bytes = sig_bytes [self .c_tilde_bytes : - (self .k + self .omega )]
211
- h_bytes = sig_bytes [- (self .k + self .omega ) :]
220
+ return self .M .vector (vector_coeffs )
212
221
213
- z = self .M .bit_unpack_z (z_bytes , self .l , 1 , self .gamma_1 )
222
+ def _unpack_sig (self , sig : bytes ) -> tuple [bytes , Vector , Vector ]:
223
+ c_tilde = sig [: self .c_tilde_bytes ]
224
+ z_bytes = sig [self .c_tilde_bytes : - (self .k + self .omega )]
225
+ h_bytes = sig [- (self .k + self .omega ) :]
226
+
227
+ z = self .M .bit_unpack_z (z_bytes , self .l , self .gamma_1 )
214
228
h = self ._unpack_h (h_bytes )
229
+
215
230
return c_tilde , z , h
216
231
217
- def _keygen_internal (self , zeta ) :
232
+ def _keygen_internal (self , zeta : bytes ) -> tuple [ bytes , bytes ] :
218
233
"""
219
234
Generates a public-private key pair from a seed following
220
235
Algorithm 6 (FIPS 204)
@@ -245,7 +260,9 @@ def _keygen_internal(self, zeta):
245
260
246
261
return pk , sk
247
262
248
- def _sign_internal (self , sk_bytes , m , rnd , external_mu = False ):
263
+ def _sign_internal (
264
+ self , sk : bytes , m : bytes , rnd : bytes , external_mu : bool = False
265
+ ) -> bytes :
249
266
"""
250
267
Deterministic algorithm to generate a signature for a formatted message
251
268
M' following Algorithm 7 (FIPS 204)
@@ -254,7 +271,7 @@ def _sign_internal(self, sk_bytes, m, rnd, external_mu=False):
254
271
the pre-hashed message `mu = prehash_external_mu()`
255
272
"""
256
273
# unpack the secret key
257
- rho , K , tr , s1 , s2 , t0 = self ._unpack_sk (sk_bytes )
274
+ rho , k , tr , s1 , s2 , t0 = self ._unpack_sk (sk )
258
275
259
276
# Precompute NTT representation
260
277
s1_hat = s1 .to_ntt ()
@@ -269,7 +286,7 @@ def _sign_internal(self, sk_bytes, m, rnd, external_mu=False):
269
286
mu = m
270
287
else :
271
288
mu = self ._h (tr + m , 64 )
272
- rho_prime = self ._h (K + rnd + mu , 64 )
289
+ rho_prime = self ._h (k + rnd + mu , 64 )
273
290
274
291
kappa = 0
275
292
alpha = self .gamma_2 << 1
@@ -318,16 +335,15 @@ def _sign_internal(self, sk_bytes, m, rnd, external_mu=False):
318
335
319
336
return self ._pack_sig (c_tilde , z , h )
320
337
321
- def _verify_internal (self , pk_bytes , m , sig_bytes ) :
338
+ def _verify_internal (self , pk : bytes , m : bytes , sig : bytes ) -> bool :
322
339
"""
323
340
Internal function to verify a signature sigma for a formatted message M'
324
341
following Algorithm 8 (FIPS 204)
325
342
"""
326
- rho , t1 = self ._unpack_pk (pk_bytes )
343
+ rho , t1 = self ._unpack_pk (pk )
327
344
try :
328
- c_tilde , z , h = self ._unpack_sig (sig_bytes )
345
+ c_tilde , z , h = self ._unpack_sig (sig )
329
346
except ValueError :
330
- # verify failed if malformed input signature
331
347
return False
332
348
333
349
if h .sum_hint () > self .omega :
@@ -338,7 +354,7 @@ def _verify_internal(self, pk_bytes, m, sig_bytes):
338
354
339
355
A_hat = self ._expand_matrix_from_seed (rho )
340
356
341
- tr = self ._h (pk_bytes , 64 )
357
+ tr = self ._h (pk , 64 )
342
358
mu = self ._h (tr + m , 64 )
343
359
c = self .R .sample_in_ball (c_tilde , self .tau )
344
360
0 commit comments