Skip to content

Commit e854af8

Browse files
committed
start typing ml-dsa
1 parent c5338e3 commit e854af8

File tree

2 files changed

+32
-22
lines changed

2 files changed

+32
-22
lines changed

src/dilithium_py/ml_dsa/ml_dsa.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
class ML_DSA:
11-
def __init__(self, parameter_set):
11+
def __init__(self, parameter_set: dict):
1212
self.d = parameter_set["d"]
1313
self.k = parameter_set["k"]
1414
self.l = parameter_set["l"]
@@ -28,7 +28,7 @@ def __init__(self, parameter_set):
2828
# use the method `set_drbg_seed()`
2929
self.random_bytes = os.urandom
3030

31-
def set_drbg_seed(self, seed):
31+
def set_drbg_seed(self, seed: bytes):
3232
"""
3333
Change entropy source to a DRBG and seed it with provided value.
3434
@@ -69,7 +69,7 @@ def _expand_matrix_from_seed(self, rho):
6969
Helper function which generates a element of size
7070
k x l from a seed `rho`.
7171
"""
72-
A_data = [[0 for _ in range(self.l)] for _ in range(self.k)]
72+
A_data = [[self.R.zero() for _ in range(self.l)] for _ in range(self.k)]
7373
for i in range(self.k):
7474
for j in range(self.l):
7575
A_data[i][j] = self.R.rejection_sample_ntt_poly(rho, i, j)
@@ -357,16 +357,16 @@ def _verify_internal(self, pk_bytes, m, sig_bytes):
357357

358358
return c_tilde == self._h(mu + w_prime_bytes, self.c_tilde_bytes)
359359

360-
def keygen(self):
360+
def keygen(self) -> tuple[bytes, bytes]:
361361
"""
362362
Generates a public-private key pair following
363363
Algorithm 1 (FIPS 204)
364364
"""
365365
zeta = self.random_bytes(32)
366366
pk, sk = self._keygen_internal(zeta)
367-
return pk, sk
367+
return (pk, sk)
368368

369-
def key_derive(self, seed):
369+
def key_derive(self, seed: bytes) -> tuple[bytes, bytes]:
370370
"""
371371
Derive a verification key and corresponding signing key
372372
following the approach from Section 6.1 (FIPS 204)
@@ -383,7 +383,9 @@ def key_derive(self, seed):
383383
pk, sk = self._keygen_internal(seed)
384384
return (pk, sk)
385385

386-
def sign(self, sk_bytes, m, ctx=b"", deterministic=False):
386+
def sign(
387+
self, sk: bytes, m: bytes, ctx: bytes = b"", deterministic: bool = False
388+
) -> bytes:
387389
"""
388390
Generates an ML-DSA signature following
389391
Algorithm 2 (FIPS 204)
@@ -402,10 +404,10 @@ def sign(self, sk_bytes, m, ctx=b"", deterministic=False):
402404
m_prime = bytes([0]) + bytes([len(ctx)]) + ctx + m
403405

404406
# Compute the signature of m_prime
405-
sig_bytes = self._sign_internal(sk_bytes, m_prime, rnd)
407+
sig_bytes = self._sign_internal(sk, m_prime, rnd)
406408
return sig_bytes
407409

408-
def verify(self, pk_bytes, m, sig_bytes, ctx=b""):
410+
def verify(self, pk: bytes, m: bytes, sig: bytes, ctx: bytes = b"") -> bool:
409411
"""
410412
Verifies a signature sigma for a message M following
411413
Algorithm 3 (FIPS 204)
@@ -418,21 +420,21 @@ def verify(self, pk_bytes, m, sig_bytes, ctx=b""):
418420
# Format the message using the context
419421
m_prime = bytes([0]) + bytes([len(ctx)]) + ctx + m
420422

421-
return self._verify_internal(pk_bytes, m_prime, sig_bytes)
423+
return self._verify_internal(pk, m_prime, sig)
422424

423425
"""
424426
The following additional function follows an outline from:
425427
https://github.com/aws/aws-lc/pull/2142
426428
which computes pk_bytes when only the sk_bytes are known.
427429
"""
428430

429-
def pk_from_sk(self, sk_bytes: bytes) -> bytes:
431+
def pk_from_sk(self, sk: bytes) -> bytes:
430432
"""
431433
Given the packed representation of a ML-DSA secret key,
432434
compute the corresponding packed public key bytes.
433435
"""
434436
# First unpack the secret key
435-
rho, K, tr, s1, s2, t0 = self._unpack_sk(sk_bytes)
437+
rho, _, tr, s1, s2, _ = self._unpack_sk(sk)
436438

437439
# Compute the matrix A from rho in NTT form
438440
A_hat = self._expand_matrix_from_seed(rho)
@@ -446,13 +448,13 @@ def pk_from_sk(self, sk_bytes: bytes) -> bytes:
446448
t1, _ = t.power_2_round(self.d)
447449

448450
# The packed public key is made from rho || t1
449-
pk_bytes = self._pack_pk(rho, t1)
451+
pk = self._pack_pk(rho, t1)
450452

451453
# Ensure the public key matches the hash within the secret key
452-
if tr != self._h(pk_bytes, 64):
454+
if tr != self._h(pk, 64):
453455
raise ValueError("malformed secret key")
454456

455-
return pk_bytes
457+
return pk
456458

457459
"""
458460
The following external mu functions are not in FIPS 204, but are in
@@ -462,7 +464,7 @@ def pk_from_sk(self, sk_bytes: bytes) -> bytes:
462464
https://datatracker.ietf.org/doc/html/draft-ietf-lamps-dilithium-certificates-07
463465
"""
464466

465-
def prehash_external_mu(self, pk_bytes, m, ctx=b""):
467+
def prehash_external_mu(self, pk: bytes, m: bytes, ctx: bytes = b"") -> bytes:
466468
"""
467469
Prehash the message `m` with context `ctx` together with
468470
the public key. For use with `sign_external_mu()`
@@ -472,22 +474,24 @@ def prehash_external_mu(self, pk_bytes, m, ctx=b""):
472474
raise ValueError(
473475
f"ctx bytes must have length at most 255, ctx has length {len(ctx) = }"
474476
)
475-
if len(pk_bytes) != self._pk_size():
477+
if len(pk) != self._pk_size():
476478
raise ValueError(
477479
f"Public key size doesn't match this ML-DSA object parameters,"
478-
f"received {len(pk_bytes) = }, expected: {self._pk_size()}"
480+
f"received {len(pk) = }, expected: {self._pk_size()}"
479481
)
480482

481483
# Format the message using the context
482484
m_prime = bytes([0]) + bytes([len(ctx)]) + ctx + m
483485

484486
# Compute mu by hashing the public key into the message
485-
tr = self._h(pk_bytes, 64)
487+
tr = self._h(pk, 64)
486488
mu = self._h(tr + m_prime, 64)
487489

488490
return mu
489491

490-
def sign_external_mu(self, sk_bytes, mu, deterministic=False):
492+
def sign_external_mu(
493+
self, sk: bytes, mu: bytes, deterministic: bool = False
494+
) -> bytes:
491495
"""
492496
Generates an ML-DSA signature of a message given the prehash
493497
mu = H(H(pk), M')
@@ -505,5 +509,5 @@ def sign_external_mu(self, sk_bytes, mu, deterministic=False):
505509

506510
# Compute the signature given external mu, we set the external_mu
507511
# to True
508-
sig_bytes = self._sign_internal(sk_bytes, mu, rnd, external_mu=True)
509-
return sig_bytes
512+
sig = self._sign_internal(sk, mu, rnd, external_mu=True)
513+
return sig

src/dilithium_py/polynomials/polynomials_generic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ def __init__(self, q, n):
1313
self.n = n
1414
self.element = Polynomial
1515

16+
def zero(self):
17+
"""
18+
Return the value `0` of the polynomial ring
19+
"""
20+
return self([0])
21+
1622
def gen(self):
1723
"""
1824
Return the generator `x` of the polynomial ring

0 commit comments

Comments
 (0)