Skip to content

Commit 2c64245

Browse files
committed
add support for external mu following lamps ietf draft
1 parent 96d1cd4 commit 2c64245

File tree

2 files changed

+96
-10
lines changed

2 files changed

+96
-10
lines changed

src/dilithium_py/ml_dsa/ml_dsa.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,13 @@ def _keygen_internal(self, zeta):
215215

216216
return pk, sk
217217

218-
def _sign_internal(self, sk_bytes, m, rnd):
218+
def _sign_internal(self, sk_bytes, m, rnd, external_mu=None):
219219
"""
220220
Deterministic algorithm to generate a signature for a formatted message
221221
M' following Algorithm 7 (FIPS 204)
222+
223+
Optionally allows for a pre-hashed message using `prehash_external_mu()`
224+
When `external_mu` is not `None`, then the message `m` must be `None`
222225
"""
223226
# unpack the secret key
224227
rho, K, tr, s1, s2, t0 = self._unpack_sk(sk_bytes)
@@ -232,7 +235,13 @@ def _sign_internal(self, sk_bytes, m, rnd):
232235
A_hat = self._expand_matrix_from_seed(rho)
233236

234237
# Set seeds and nonce (kappa)
235-
mu = self._h(tr + m, 64)
238+
if external_mu is None:
239+
mu = self._h(tr + m, 64)
240+
else:
241+
# NOTE: when using external mu, the validation of the length
242+
# of external_mu is handled by the function sign_external_mu
243+
assert m is None, "Signing using external mu, message will be ignored"
244+
mu = external_mu
236245
rho_prime = self._h(K + rnd + mu, 64)
237246

238247
kappa = 0
@@ -362,3 +371,52 @@ def verify(self, pk_bytes, m, sig_bytes, ctx=b""):
362371
m_prime = bytes([0]) + bytes([len(ctx)]) + ctx + m
363372

364373
return self._verify_internal(pk_bytes, m_prime, sig_bytes)
374+
375+
"""
376+
The following external mu functions are not in FIPS 204, but are in
377+
Appendix D of the following IETF draft and are included for experimentation
378+
for researchers and engineers
379+
380+
https://datatracker.ietf.org/doc/html/draft-ietf-lamps-dilithium-certificates-07
381+
"""
382+
383+
def prehash_external_mu(self, pk_bytes, m, ctx=b""):
384+
"""
385+
Prehash the message `m` with context `ctx` together with
386+
the public key for use with `sign_external_mu()`
387+
"""
388+
# Ensure the length of the context is as expected
389+
if len(ctx) > 255:
390+
raise ValueError(
391+
f"ctx bytes must have length at most 255, ctx has length {len(ctx) = }"
392+
)
393+
394+
# Format the message using the context
395+
m_prime = bytes([0]) + bytes([len(ctx)]) + ctx + m
396+
397+
# Compute mu by hashing the public key into the message
398+
tr = self._h(pk_bytes, 64)
399+
mu = self._h(tr + m_prime, 64)
400+
401+
return mu
402+
403+
def sign_external_mu(self, sk_bytes, external_mu, deterministic=False):
404+
"""
405+
Generates an ML-DSA signature of a message m given the prehash
406+
of the message `m` with an optional context
407+
"""
408+
# Ensure the length of the context is as expected
409+
if len(external_mu) != 64:
410+
raise ValueError(
411+
f"mu bytes must have length 64, mu has length {len(external_mu) = }"
412+
)
413+
414+
if deterministic:
415+
rnd = bytes([0] * 32)
416+
else:
417+
rnd = self.random_bytes(32)
418+
419+
# Compute the signature given external mu, we explicitly set the message
420+
# to None
421+
sig_bytes = self._sign_internal(sk_bytes, None, rnd, external_mu)
422+
return sig_bytes

tests/test_ml_dsa.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,39 @@ class TestMLDSA(unittest.TestCase):
1313
def generic_test_ml_dsa(self, ML_DSA, count=5):
1414
for _ in range(count):
1515
msg = b"Signed by ML_DSA" + os.urandom(16)
16+
ctx = os.urandom(128)
1617

1718
# Perform signature process
1819
pk, sk = ML_DSA.keygen()
19-
sig = ML_DSA.sign(sk, msg)
20-
check_verify = ML_DSA.verify(pk, msg, sig)
20+
sig = ML_DSA.sign(sk, msg, ctx=ctx)
21+
check_verify = ML_DSA.verify(pk, msg, sig, ctx=ctx)
22+
23+
# Sign with external_mu instead
24+
external_mu = ML_DSA.prehash_external_mu(pk, msg, ctx=ctx)
25+
sig_external_mu = ML_DSA.sign_external_mu(sk, external_mu)
26+
check_external_mu = ML_DSA.verify(pk, msg, sig_external_mu, ctx=ctx)
2127

2228
# Generate some fail cases
2329
pk_bad, _ = ML_DSA.keygen()
24-
check_wrong_pk = ML_DSA.verify(pk_bad, msg, sig)
25-
check_wrong_msg = ML_DSA.verify(pk, b"", sig)
30+
check_wrong_pk = ML_DSA.verify(pk_bad, msg, sig, ctx=ctx)
31+
check_wrong_msg = ML_DSA.verify(pk, b"", sig, ctx=ctx)
32+
check_no_ctx = ML_DSA.verify(pk, msg, sig)
2633

2734
# Check that signature works
2835
self.assertTrue(check_verify)
2936

37+
# Check that external_mu also works
38+
self.assertTrue(check_external_mu)
39+
3040
# Check changing the key breaks verify
3141
self.assertFalse(check_wrong_pk)
3242

3343
# Check changing the message breaks verify
3444
self.assertFalse(check_wrong_msg)
3545

46+
# Check removing the context breaks verify
47+
self.assertFalse(check_no_ctx)
48+
3649
def test_ml_dsa_44(self):
3750
self.generic_test_ml_dsa(ML_DSA_44)
3851

@@ -52,26 +65,41 @@ class TestMLDSADeterministic(unittest.TestCase):
5265
def generic_test_ml_dsa(self, ML_DSA, count=5):
5366
for _ in range(count):
5467
msg = b"Signed by ML_DSA" + os.urandom(16)
68+
ctx = os.urandom(128)
5569

5670
# Perform signature process
5771
pk, sk = ML_DSA.keygen()
58-
sig = ML_DSA.sign(sk, msg, deterministic=True)
59-
check_verify = ML_DSA.verify(pk, msg, sig)
72+
sig = ML_DSA.sign(sk, msg, ctx=ctx, deterministic=True)
73+
check_verify = ML_DSA.verify(pk, msg, sig, ctx=ctx)
74+
75+
# Sign with external_mu instead
76+
external_mu = ML_DSA.prehash_external_mu(pk, msg, ctx=ctx)
77+
sig_external_mu = ML_DSA.sign_external_mu(
78+
sk, external_mu, deterministic=True
79+
)
80+
check_external_mu = ML_DSA.verify(pk, msg, sig_external_mu, ctx=ctx)
6081

6182
# Generate some fail cases
6283
pk_bad, _ = ML_DSA.keygen()
63-
check_wrong_pk = ML_DSA.verify(pk_bad, msg, sig)
64-
check_wrong_msg = ML_DSA.verify(pk, b"", sig)
84+
check_wrong_pk = ML_DSA.verify(pk_bad, msg, sig, ctx=ctx)
85+
check_wrong_msg = ML_DSA.verify(pk, b"", sig, ctx=ctx)
86+
check_no_ctx = ML_DSA.verify(pk, msg, sig)
6587

6688
# Check that signature works
6789
self.assertTrue(check_verify)
6890

91+
# Check that external_mu also works
92+
self.assertTrue(check_external_mu)
93+
6994
# Check changing the key breaks verify
7095
self.assertFalse(check_wrong_pk)
7196

7297
# Check changing the message breaks verify
7398
self.assertFalse(check_wrong_msg)
7499

100+
# Check removing the context breaks verify
101+
self.assertFalse(check_no_ctx)
102+
75103
def test_ml_dsa_44(self):
76104
self.generic_test_ml_dsa(ML_DSA_44)
77105

0 commit comments

Comments
 (0)