Skip to content

Commit 8bab6bc

Browse files
committed
refactor for explicit typing
1 parent e854af8 commit 8bab6bc

File tree

8 files changed

+271
-226
lines changed

8 files changed

+271
-226
lines changed

src/dilithium_py/dilithium/dilithium.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from ..modules.modules import ModuleDilithium
2+
from ..modules.modules import Module
33

44
try:
55
from xoflib import shake256
@@ -19,7 +19,7 @@ def __init__(self, parameter_set):
1919
self.gamma_2 = parameter_set["gamma_2"]
2020
self.beta = self.tau * self.eta
2121

22-
self.M = ModuleDilithium()
22+
self.M = Module()
2323
self.R = self.M.ring
2424

2525
# Use system randomness by default, for deterministic randomness
@@ -51,8 +51,8 @@ def set_drbg_seed(self, seed):
5151
)
5252

5353
"""
54-
H() uses Shake256 to hash data to 32 and 64 bytes in a
55-
few places in the code
54+
H() uses Shake256 to hash data to 32 and 64 bytes in a
55+
few places in the code
5656
"""
5757

5858
@staticmethod
@@ -67,7 +67,7 @@ def _expand_matrix_from_seed(self, rho):
6767
Helper function which generates a element of size
6868
k x l from a seed `rho`.
6969
"""
70-
A_data = [[0 for _ in range(self.l)] for _ in range(self.k)]
70+
A_data = [[self.R.zero() for _ in range(self.l)] for _ in range(self.k)]
7171
for i in range(self.k):
7272
for j in range(self.l):
7373
A_data[i][j] = self.R.rejection_sample_ntt_poly(rho, i, j)
@@ -124,7 +124,7 @@ def _pack_sig(self, c_tilde, z, h):
124124

125125
def _unpack_pk(self, pk_bytes):
126126
rho, t1_bytes = pk_bytes[:32], pk_bytes[32:]
127-
t1 = self.M.bit_unpack_t1(t1_bytes, self.k, 1)
127+
t1 = self.M.bit_unpack_t1(t1_bytes, self.k)
128128
return rho, t1
129129

130130
def _unpack_sk(self, sk_bytes):
@@ -154,9 +154,9 @@ def _unpack_sk(self, sk_bytes):
154154
t0_bytes = sk_vec_bytes[-t0_len:]
155155

156156
# Unpack bytes to vectors
157-
s1 = self.M.bit_unpack_s(s1_bytes, self.l, 1, self.eta)
158-
s2 = self.M.bit_unpack_s(s2_bytes, self.k, 1, self.eta)
159-
t0 = self.M.bit_unpack_t0(t0_bytes, self.k, 1)
157+
s1 = self.M.bit_unpack_s(s1_bytes, self.l, self.eta)
158+
s2 = self.M.bit_unpack_s(s2_bytes, self.k, self.eta)
159+
t0 = self.M.bit_unpack_t0(t0_bytes, self.k)
160160

161161
return rho, K, tr, s1, s2, t0
162162

@@ -179,7 +179,7 @@ def _unpack_sig(self, sig_bytes):
179179
z_bytes = sig_bytes[32 : -(self.k + self.omega)]
180180
h_bytes = sig_bytes[-(self.k + self.omega) :]
181181

182-
z = self.M.bit_unpack_z(z_bytes, self.l, 1, self.gamma_1)
182+
z = self.M.bit_unpack_z(z_bytes, self.l, self.gamma_1)
183183
h = self._unpack_h(h_bytes)
184184
return c_tilde, z, h
185185

src/dilithium_py/ml_dsa/ml_dsa.py

Lines changed: 68 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from ..modules.modules import ModuleDilithium
2+
from ..modules.modules import Matrix, Module, Vector
33

44
try:
55
from xoflib import shake256
@@ -20,7 +20,7 @@ def __init__(self, parameter_set: dict):
2020
self.beta = self.tau * self.eta
2121
self.c_tilde_bytes = parameter_set["c_tilde_bytes"]
2222

23-
self.M = ModuleDilithium()
23+
self.M = Module()
2424
self.R = self.M.ring
2525
self.oid = parameter_set["oid"] if "oid" in parameter_set else None
2626

@@ -58,13 +58,13 @@ def set_drbg_seed(self, seed: bytes):
5858
"""
5959

6060
@staticmethod
61-
def _h(input_bytes, length):
61+
def _h(input: bytes, length: int) -> bytes:
6262
"""
6363
H: B^* -> B^*
6464
"""
65-
return shake256(input_bytes).read(length)
65+
return shake256(input).read(length)
6666

67-
def _expand_matrix_from_seed(self, rho):
67+
def _expand_matrix_from_seed(self, rho: bytes) -> Matrix:
6868
"""
6969
Helper function which generates a element of size
7070
k x l from a seed `rho`.
@@ -75,7 +75,7 @@ def _expand_matrix_from_seed(self, rho):
7575
A_data[i][j] = self.R.rejection_sample_ntt_poly(rho, i, j)
7676
return self.M(A_data)
7777

78-
def _expand_vector_from_seed(self, rho_prime):
78+
def _expand_vector_from_seed(self, rho_prime: bytes) -> tuple[Vector, Vector]:
7979
s1_elements = [
8080
self.R.rejection_bounded_poly(rho_prime, i, self.eta) for i in range(self.l)
8181
]
@@ -88,24 +88,32 @@ def _expand_vector_from_seed(self, rho_prime):
8888
s2 = self.M.vector(s2_elements)
8989
return s1, s2
9090

91-
def _expand_mask_vector(self, rho, mu):
91+
def _expand_mask_vector(self, rho: bytes, mu: int) -> Vector:
9292
elements = [
9393
self.R.sample_mask_polynomial(rho, i, mu, self.gamma_1)
9494
for i in range(self.l)
9595
]
9696
return self.M.vector(elements)
9797

9898
@staticmethod
99-
def _pack_pk(rho, t1):
99+
def _pack_pk(rho: bytes, t1: Vector) -> bytes:
100100
return rho + t1.bit_pack_t1()
101101

102-
def _pack_sk(self, rho, K, tr, s1, s2, t0):
102+
def _pack_sk(
103+
self,
104+
rho: bytes,
105+
k: bytes,
106+
tr: bytes,
107+
s1: Vector,
108+
s2: Vector,
109+
t0: Vector,
110+
) -> bytes:
103111
s1_bytes = s1.bit_pack_s(self.eta)
104112
s2_bytes = s2.bit_pack_s(self.eta)
105113
t0_bytes = t0.bit_pack_t0()
106-
return rho + K + tr + s1_bytes + s2_bytes + t0_bytes
114+
return rho + k + tr + s1_bytes + s2_bytes + t0_bytes
107115

108-
def _pack_h(self, h):
116+
def _pack_h(self, h: Vector) -> bytes:
109117
non_zero_positions = [
110118
[i for i, c in enumerate(poly.coeffs) if c == 1]
111119
for row in h._data
@@ -121,20 +129,20 @@ def _pack_h(self, h):
121129
packed.extend([0 for _ in range(padding_len)])
122130
return bytes(packed + offsets)
123131

124-
def _pack_sig(self, c_tilde, z, h):
132+
def _pack_sig(self, c_tilde: bytes, z: Vector, h: Vector) -> bytes:
125133
return c_tilde + z.bit_pack_z(self.gamma_1) + self._pack_h(h)
126134

127-
def _pk_size(self):
135+
def _pk_size(self) -> int:
128136
return 32 + 32 * self.k * 10
129137

130-
def _unpack_pk(self, pk_bytes):
131-
if len(pk_bytes) != self._pk_size():
138+
def _unpack_pk(self, pk: bytes) -> tuple[bytes, Vector]:
139+
if len(pk) != self._pk_size():
132140
raise ValueError("PK packed bytes is of the wrong length")
133-
rho, t1_bytes = pk_bytes[:32], pk_bytes[32:]
134-
t1 = self.M.bit_unpack_t1(t1_bytes, self.k, 1)
141+
rho, t1_bytes = pk[:32], pk[32:]
142+
t1 = self.M.bit_unpack_t1(t1_bytes, self.k)
135143
return rho, t1
136144

137-
def _sk_size(self):
145+
def _sk_size(self) -> int:
138146
if self.eta == 2:
139147
s_bytes = 96
140148
else:
@@ -144,22 +152,24 @@ def _sk_size(self):
144152
t0_len = 416 * self.k
145153
return 2 * 32 + 64 + s1_len + s2_len + t0_len
146154

147-
def _unpack_sk(self, sk_bytes):
155+
def _unpack_sk(
156+
self, sk: bytes
157+
) -> tuple[bytes, bytes, bytes, Vector, Vector, Vector]:
148158
if self.eta == 2:
149159
s_bytes = 96
150160
else:
151161
s_bytes = 128
152162
s1_len = s_bytes * self.l
153163
s2_len = s_bytes * self.k
154164
t0_len = 416 * self.k
155-
if len(sk_bytes) != self._sk_size():
156-
raise ValueError("SK packed bytes is of the wrong length")
165+
if len(sk) != self._sk_size():
166+
raise ValueError("sk packed bytes is of the wrong length")
157167

158168
# Split bytes between seeds and vectors
159-
sk_seed_bytes, sk_vec_bytes = sk_bytes[:128], sk_bytes[128:]
169+
sk_seed_bytes, sk_vec_bytes = sk[:128], sk[128:]
160170

161171
# Unpack seed bytes
162-
rho, K, tr = (
172+
rho, k, tr = (
163173
sk_seed_bytes[:32],
164174
sk_seed_bytes[32:64],
165175
sk_seed_bytes[64:128],
@@ -171,50 +181,55 @@ def _unpack_sk(self, sk_bytes):
171181
t0_bytes = sk_vec_bytes[-t0_len:]
172182

173183
# Unpack bytes to vectors
174-
s1 = self.M.bit_unpack_s(s1_bytes, self.l, 1, self.eta)
175-
s2 = self.M.bit_unpack_s(s2_bytes, self.k, 1, self.eta)
176-
t0 = self.M.bit_unpack_t0(t0_bytes, self.k, 1)
184+
s1 = self.M.bit_unpack_s(s1_bytes, self.l, self.eta)
185+
s2 = self.M.bit_unpack_s(s2_bytes, self.k, self.eta)
186+
t0 = self.M.bit_unpack_t0(t0_bytes, self.k)
177187

178-
return rho, K, tr, s1, s2, t0
188+
return rho, k, tr, s1, s2, t0
179189

180-
def _unpack_h(self, h_bytes):
190+
def _unpack_h(self, h_bytes: bytes) -> Vector:
181191
offsets = [0] + list(h_bytes[-self.k :])
182-
# check offsets are monotonic increasing
192+
193+
# ensure offsets are monotonic increasing
183194
if any(offsets[i] > offsets[i + 1] for i in range(len(offsets) - 1)):
184-
raise ValueError("Offsets in h_bytes are not monotonic increasing")
185-
# check offset[-1] is smaller than the length of h_bytes
195+
raise ValueError("offsets in h_bytes are not monotonically increasing")
196+
197+
# ensure offset[-1] is smaller than the length of h_bytes
186198
if offsets[-1] > self.omega:
187-
raise ValueError("Accumulate offset of hints exceeds omega")
188-
# check zero fields are all zeros
199+
raise ValueError("accumulate offset of hints exceeds omega")
200+
201+
# ensure zero fields are all zeros
189202
if any(b != 0 for b in h_bytes[offsets[-1] : self.omega]):
190-
raise ValueError("Non-zero fields in h_bytes are not all zeros")
203+
raise ValueError("non-zero fields in h_bytes are not all zeros")
191204

192205
non_zero_positions = [
193206
list(h_bytes[offsets[i] : offsets[i + 1]]) for i in range(self.k)
194207
]
195208

196-
matrix = []
209+
vector_coeffs = []
197210
for poly_non_zero in non_zero_positions:
198211
coeffs = [0 for _ in range(256)]
199212
for i, non_zero in enumerate(poly_non_zero):
200213
if i > 0 and non_zero < poly_non_zero[i - 1]:
201214
raise ValueError(
202-
"Non-zero positions in h_bytes are not monotonic increasing"
215+
"non-zero positions in h_bytes are not monotonically increasing"
203216
)
204217
coeffs[non_zero] = 1
205-
matrix.append([self.R(coeffs)])
206-
return self.M(matrix)
218+
vector_coeffs.append(self.R(coeffs))
207219

208-
def _unpack_sig(self, sig_bytes):
209-
c_tilde = sig_bytes[: self.c_tilde_bytes]
210-
z_bytes = sig_bytes[self.c_tilde_bytes : -(self.k + self.omega)]
211-
h_bytes = sig_bytes[-(self.k + self.omega) :]
220+
return self.M.vector(vector_coeffs)
212221

213-
z = self.M.bit_unpack_z(z_bytes, self.l, 1, self.gamma_1)
222+
def _unpack_sig(self, sig: bytes) -> tuple[bytes, Vector, Vector]:
223+
c_tilde = sig[: self.c_tilde_bytes]
224+
z_bytes = sig[self.c_tilde_bytes : -(self.k + self.omega)]
225+
h_bytes = sig[-(self.k + self.omega) :]
226+
227+
z = self.M.bit_unpack_z(z_bytes, self.l, self.gamma_1)
214228
h = self._unpack_h(h_bytes)
229+
215230
return c_tilde, z, h
216231

217-
def _keygen_internal(self, zeta):
232+
def _keygen_internal(self, zeta: bytes) -> tuple[bytes, bytes]:
218233
"""
219234
Generates a public-private key pair from a seed following
220235
Algorithm 6 (FIPS 204)
@@ -245,7 +260,9 @@ def _keygen_internal(self, zeta):
245260

246261
return pk, sk
247262

248-
def _sign_internal(self, sk_bytes, m, rnd, external_mu=False):
263+
def _sign_internal(
264+
self, sk: bytes, m: bytes, rnd: bytes, external_mu: bool = False
265+
) -> bytes:
249266
"""
250267
Deterministic algorithm to generate a signature for a formatted message
251268
M' following Algorithm 7 (FIPS 204)
@@ -254,7 +271,7 @@ def _sign_internal(self, sk_bytes, m, rnd, external_mu=False):
254271
the pre-hashed message `mu = prehash_external_mu()`
255272
"""
256273
# unpack the secret key
257-
rho, K, tr, s1, s2, t0 = self._unpack_sk(sk_bytes)
274+
rho, k, tr, s1, s2, t0 = self._unpack_sk(sk)
258275

259276
# Precompute NTT representation
260277
s1_hat = s1.to_ntt()
@@ -269,7 +286,7 @@ def _sign_internal(self, sk_bytes, m, rnd, external_mu=False):
269286
mu = m
270287
else:
271288
mu = self._h(tr + m, 64)
272-
rho_prime = self._h(K + rnd + mu, 64)
289+
rho_prime = self._h(k + rnd + mu, 64)
273290

274291
kappa = 0
275292
alpha = self.gamma_2 << 1
@@ -318,16 +335,15 @@ def _sign_internal(self, sk_bytes, m, rnd, external_mu=False):
318335

319336
return self._pack_sig(c_tilde, z, h)
320337

321-
def _verify_internal(self, pk_bytes, m, sig_bytes):
338+
def _verify_internal(self, pk: bytes, m: bytes, sig: bytes) -> bool:
322339
"""
323340
Internal function to verify a signature sigma for a formatted message M'
324341
following Algorithm 8 (FIPS 204)
325342
"""
326-
rho, t1 = self._unpack_pk(pk_bytes)
343+
rho, t1 = self._unpack_pk(pk)
327344
try:
328-
c_tilde, z, h = self._unpack_sig(sig_bytes)
345+
c_tilde, z, h = self._unpack_sig(sig)
329346
except ValueError:
330-
# verify failed if malformed input signature
331347
return False
332348

333349
if h.sum_hint() > self.omega:
@@ -338,7 +354,7 @@ def _verify_internal(self, pk_bytes, m, sig_bytes):
338354

339355
A_hat = self._expand_matrix_from_seed(rho)
340356

341-
tr = self._h(pk_bytes, 64)
357+
tr = self._h(pk, 64)
342358
mu = self._h(tr + m, 64)
343359
c = self.R.sample_in_ball(c_tilde, self.tau)
344360

0 commit comments

Comments
 (0)