Skip to content

Commit 9d0901d

Browse files
authored
Merge pull request #20 from GiacomoPope/explicit_types
Add type hints to ML-DSA and refactor to allow this to be easy
2 parents c5338e3 + f224a6f commit 9d0901d

File tree

8 files changed

+312
-257
lines changed

8 files changed

+312
-257
lines changed

src/dilithium_py/dilithium/dilithium.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from ..modules.modules import ModuleDilithium
2+
from ..modules.modules import Module
33

44
try:
55
from xoflib import shake256
@@ -19,7 +19,7 @@ def __init__(self, parameter_set):
1919
self.gamma_2 = parameter_set["gamma_2"]
2020
self.beta = self.tau * self.eta
2121

22-
self.M = ModuleDilithium()
22+
self.M = Module()
2323
self.R = self.M.ring
2424

2525
# Use system randomness by default, for deterministic randomness
@@ -51,8 +51,8 @@ def set_drbg_seed(self, seed):
5151
)
5252

5353
"""
54-
H() uses Shake256 to hash data to 32 and 64 bytes in a
55-
few places in the code
54+
H() uses Shake256 to hash data to 32 and 64 bytes in a
55+
few places in the code
5656
"""
5757

5858
@staticmethod
@@ -67,7 +67,7 @@ def _expand_matrix_from_seed(self, rho):
6767
Helper function which generates a element of size
6868
k x l from a seed `rho`.
6969
"""
70-
A_data = [[0 for _ in range(self.l)] for _ in range(self.k)]
70+
A_data = [[self.R.zero() for _ in range(self.l)] for _ in range(self.k)]
7171
for i in range(self.k):
7272
for j in range(self.l):
7373
A_data[i][j] = self.R.rejection_sample_ntt_poly(rho, i, j)
@@ -124,7 +124,7 @@ def _pack_sig(self, c_tilde, z, h):
124124

125125
def _unpack_pk(self, pk_bytes):
126126
rho, t1_bytes = pk_bytes[:32], pk_bytes[32:]
127-
t1 = self.M.bit_unpack_t1(t1_bytes, self.k, 1)
127+
t1 = self.M.bit_unpack_t1(t1_bytes, self.k)
128128
return rho, t1
129129

130130
def _unpack_sk(self, sk_bytes):
@@ -154,9 +154,9 @@ def _unpack_sk(self, sk_bytes):
154154
t0_bytes = sk_vec_bytes[-t0_len:]
155155

156156
# Unpack bytes to vectors
157-
s1 = self.M.bit_unpack_s(s1_bytes, self.l, 1, self.eta)
158-
s2 = self.M.bit_unpack_s(s2_bytes, self.k, 1, self.eta)
159-
t0 = self.M.bit_unpack_t0(t0_bytes, self.k, 1)
157+
s1 = self.M.bit_unpack_s(s1_bytes, self.l, self.eta)
158+
s2 = self.M.bit_unpack_s(s2_bytes, self.k, self.eta)
159+
t0 = self.M.bit_unpack_t0(t0_bytes, self.k)
160160

161161
return rho, K, tr, s1, s2, t0
162162

@@ -179,7 +179,7 @@ def _unpack_sig(self, sig_bytes):
179179
z_bytes = sig_bytes[32 : -(self.k + self.omega)]
180180
h_bytes = sig_bytes[-(self.k + self.omega) :]
181181

182-
z = self.M.bit_unpack_z(z_bytes, self.l, 1, self.gamma_1)
182+
z = self.M.bit_unpack_z(z_bytes, self.l, self.gamma_1)
183183
h = self._unpack_h(h_bytes)
184184
return c_tilde, z, h
185185

0 commit comments

Comments
 (0)